Unverified commit 5e34ca8b authored by xiang song(charlie.song), committed by GitHub

[Feature] Lazy copy ndata, edata to device (#1986)



* Lazy to device

* remove print

* lint

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* lint

* Fix

* Revert "Fix"

This reverts commit 615c9b8f80f5f6ee2ab43c849a22f0083deedf3b.

* Add test for frame lazy update

* disable tensorflow

* upd
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 912da18c
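The gist of the change: `DGLHeteroGraph.to()` no longer copies every feature tensor eagerly; each `Column` only records the pending target device and performs the actual `copy_to` on first access. A minimal sketch of the resulting user-facing behavior, assuming the PyTorch backend and an available CUDA device (tensor names and shapes are illustrative):

```python
import torch
import dgl

g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g.ndata['h'] = torch.zeros(4, 128)      # node features, stored on the CPU
g.edata['w'] = torch.zeros(3, 16)       # edge features, stored on the CPU

cuda_g = g.to(torch.device('cuda:0'))   # returns immediately; no feature tensor is copied yet
h = cuda_g.ndata['h']                   # first read of 'h' triggers the copy to the GPU
assert h.device.type == 'cuda'
# 'w' has not been read, so its storage still lives on the CPU until it is accessed.
```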
@@ -84,10 +84,11 @@ class Column(object):
     index : Tensor
         Index tensor
     """
-    def __init__(self, storage, scheme=None, index=None):
+    def __init__(self, storage, scheme=None, index=None, device=None):
         self.storage = storage
         self.scheme = scheme if scheme else infer_scheme(storage)
         self.index = index
+        self.device = device
 
     def __len__(self):
         """The number of features (number of rows) in this column."""
@@ -105,8 +106,18 @@ class Column(object):
     def data(self):
         """Return the feature data. Perform index selecting if needed."""
         if self.index is not None:
+            # If the index and the storage are not in the same context,
+            # copy the index to the context of the storage.
+            # Copying the index is usually cheaper than copying the data.
+            if F.context(self.storage) != F.context(self.index):
+                self.index = F.copy_to(self.index, F.context(self.storage))
             self.storage = F.gather_row(self.storage, self.index)
             self.index = None
+
+        # Move the data to the requested device.
+        if self.device is not None:
+            self.storage = F.copy_to(self.storage, self.device[0], **self.device[1])
+            self.device = None
         return self.storage
 
     @data.setter
@@ -115,6 +126,25 @@ class Column(object):
         self.index = None
         self.storage = val
 
+    def to(self, device, **kwargs):  # pylint: disable=invalid-name
+        """Return a new column with the data copied to the targeted device (cpu/gpu).
+
+        Parameters
+        ----------
+        device : Framework-specific device context object
+            The context to move data to.
+        kwargs : Key-word arguments.
+            Key-word arguments fed to the framework copy function.
+
+        Returns
+        -------
+        Column
+            A new column.
+        """
+        col = self.clone()
+        col.device = (device, kwargs)
+        return col
+
     def __getitem__(self, rowids):
         """Return the feature data given the rowids.
@@ -186,7 +216,7 @@ class Column(object):
     def clone(self):
         """Return a shallow copy of this column."""
-        return Column(self.storage, self.scheme, self.index)
+        return Column(self.storage, self.scheme, self.index, self.device)
 
     def deepclone(self):
         """Return a deepcopy of this column.
@@ -214,9 +244,9 @@ class Column(object):
             Sub-column
         """
         if self.index is None:
-            return Column(self.storage, self.scheme, rowids)
+            return Column(self.storage, self.scheme, rowids, self.device)
         else:
-            return Column(self.storage, self.scheme, F.gather_row(self.index, rowids))
+            return Column(self.storage, self.scheme, F.gather_row(self.index, rowids), self.device)
 
     @staticmethod
     def create(data):
@@ -578,5 +608,25 @@ class Frame(MutableMapping):
         subf._default_initializer = self._default_initializer
         return subf
 
+    def to(self, device, **kwargs):  # pylint: disable=invalid-name
+        """Return a new frame with the columns copied to the targeted device (cpu/gpu).
+
+        Parameters
+        ----------
+        device : Framework-specific device context object
+            The context to move data to.
+        kwargs : Key-word arguments.
+            Key-word arguments fed to the framework copy function.
+
+        Returns
+        -------
+        Frame
+            A new frame.
+        """
+        newframe = self.clone()
+        new_columns = {key: col.to(device, **kwargs) for key, col in newframe._columns.items()}
+        newframe._columns = new_columns
+        return newframe
+
     def __repr__(self):
         return repr(dict(self))
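The two deferrals compose: a subgraph records a pending row index and `to()` records a pending device, and both are resolved on the first feature read. A condensed version of what the new `test_frame_device` below checks, rewritten against the PyTorch backend (assumes a CUDA device; `_node_frames` / `_columns` are internal attributes used here only for inspection):

```python
import torch
import dgl

g = dgl.graph(([0, 1, 2], [2, 3, 1]))
g.ndata['h'] = torch.tensor([1, 1, 1, 2])         # kept on the CPU for now
g = g.to(torch.device('cuda:0'))                  # structure moves, feature copy deferred
assert g._node_frames[0]._columns['h'].storage.device.type == 'cpu'

sg = dgl.node_subgraph(g, [0, 1, 2])              # row selection is deferred as well
h = sg.ndata['h']                                 # gather + device copy both happen here
assert h.device.type == 'cuda' and h.shape[0] == 3
```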
@@ -2756,7 +2756,7 @@ class DGLHeteroGraph(object):
             Representation dict from feature name to feature tensor.
         """
         if is_all(u):
-            return dict(self._node_frames[ntid])
+            return self._node_frames[ntid]
         else:
             u = utils.prepare_tensor(self, u, 'u')
             return self._node_frames[ntid].subframe(u)
@@ -3614,14 +3614,12 @@ class DGLHeteroGraph(object):
         # TODO(minjie): handle initializer
         new_nframes = []
         for nframe in self._node_frames:
-            new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in nframe.items()}
-            new_nframes.append(Frame(new_feats, num_rows=nframe.num_rows))
+            new_nframes.append(nframe.to(device, **kwargs))
         ret._node_frames = new_nframes
 
         new_eframes = []
         for eframe in self._edge_frames:
-            new_feats = {k : F.copy_to(feat, device, **kwargs) for k, feat in eframe.items()}
-            new_eframes.append(Frame(new_feats, num_rows=eframe.num_rows))
+            new_eframes.append(eframe.to(device, **kwargs))
         ret._edge_frames = new_eframes
 
         # 2. Copy misc info
@@ -2406,6 +2406,100 @@ def test_remove_nodes(idtype):
     assert F.array_equal(u, F.tensor([1], dtype=idtype))
     assert F.array_equal(v, F.tensor([0], dtype=idtype))
 
+@parametrize_dtype
+def test_frame(idtype):
+    g = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
+    g.ndata['h'] = F.copy_to(F.tensor([0, 1, 2, 3], dtype=idtype), ctx=F.ctx())
+    g.edata['h'] = F.copy_to(F.tensor([0, 1, 2], dtype=idtype), ctx=F.ctx())
+
+    # remove nodes
+    sg = dgl.remove_nodes(g, [3])
+    # check for lazy update
+    assert F.array_equal(sg._node_frames[0]._columns['h'].storage, g.ndata['h'])
+    assert F.array_equal(sg._edge_frames[0]._columns['h'].storage, g.edata['h'])
+    assert sg.ndata['h'].shape[0] == 3
+    assert sg.edata['h'].shape[0] == 2
+    # updated after read
+    assert F.array_equal(sg._node_frames[0]._columns['h'].storage, F.tensor([0, 1, 2], dtype=idtype))
+    assert F.array_equal(sg._edge_frames[0]._columns['h'].storage, F.tensor([0, 1], dtype=idtype))
+
+    ng = dgl.add_nodes(sg, 1)
+    assert ng.ndata['h'].shape[0] == 4
+    assert F.array_equal(ng._node_frames[0]._columns['h'].storage, F.tensor([0, 1, 2, 0], dtype=idtype))
+    ng = dgl.add_edges(ng, [3], [1])
+    assert ng.edata['h'].shape[0] == 3
+    assert F.array_equal(ng._edge_frames[0]._columns['h'].storage, F.tensor([0, 1, 0], dtype=idtype))
+
+    # multi-level lazy update
+    sg = dgl.remove_nodes(g, [3])
+    assert F.array_equal(sg._node_frames[0]._columns['h'].storage, g.ndata['h'])
+    assert F.array_equal(sg._edge_frames[0]._columns['h'].storage, g.edata['h'])
+    ssg = dgl.remove_nodes(sg, [1])
+    assert F.array_equal(ssg._node_frames[0]._columns['h'].storage, g.ndata['h'])
+    assert F.array_equal(ssg._edge_frames[0]._columns['h'].storage, g.edata['h'])
+    # ssg has been materialized by the reads below
+    assert ssg.ndata['h'].shape[0] == 2
+    assert ssg.edata['h'].shape[0] == 0
+    assert F.array_equal(ssg._node_frames[0]._columns['h'].storage, F.tensor([0, 2], dtype=idtype))
+    # sg is still in lazy mode
+    assert F.array_equal(sg._node_frames[0]._columns['h'].storage, g.ndata['h'])
+    assert F.array_equal(sg._edge_frames[0]._columns['h'].storage, g.edata['h'])
+
+@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TensorFlow always creates a new tensor")
+@unittest.skipIf(F._default_context_str == 'cpu', reason="CPU does not have the context-change problem")
+@parametrize_dtype
+def test_frame_device(idtype):
+    g = dgl.graph(([0, 1, 2], [2, 3, 1]))
+    g.ndata['h'] = F.copy_to(F.tensor([1, 1, 1, 2], dtype=idtype), ctx=F.cpu())
+    g.ndata['hh'] = F.copy_to(F.ones((4, 3), dtype=idtype), ctx=F.cpu())
+    g.edata['h'] = F.copy_to(F.tensor([1, 2, 3], dtype=idtype), ctx=F.cpu())
+    g = g.to(F.ctx())
+
+    # lazy device copy
+    assert F.context(g._node_frames[0]._columns['h'].storage) == F.cpu()
+    assert F.context(g._node_frames[0]._columns['hh'].storage) == F.cpu()
+    print(g.ndata['h'])
+    assert F.context(g._node_frames[0]._columns['h'].storage) == F.ctx()
+    assert F.context(g._node_frames[0]._columns['hh'].storage) == F.cpu()
+    assert F.context(g._edge_frames[0]._columns['h'].storage) == F.cpu()
+
+    # lazy device copy in subgraph
+    sg = dgl.node_subgraph(g, [0, 1, 2])
+    assert F.context(sg._node_frames[0]._columns['h'].storage) == F.ctx()
+    assert F.context(sg._node_frames[0]._columns['hh'].storage) == F.cpu()
+    assert F.context(sg._edge_frames[0]._columns['h'].storage) == F.cpu()
+    print(sg.ndata['hh'])
+    assert F.context(sg._node_frames[0]._columns['hh'].storage) == F.ctx()
+    assert F.context(sg._edge_frames[0]._columns['h'].storage) == F.cpu()
+
+    # back to cpu
+    sg = sg.to(F.cpu())
+    assert F.context(sg._node_frames[0]._columns['h'].storage) == F.ctx()
+    assert F.context(sg._node_frames[0]._columns['hh'].storage) == F.ctx()
+    assert F.context(sg._edge_frames[0]._columns['h'].storage) == F.cpu()
+    print(sg.ndata['h'])
+    print(sg.ndata['hh'])
+    print(sg.edata['h'])
+    assert F.context(sg._node_frames[0]._columns['h'].storage) == F.cpu()
+    assert F.context(sg._node_frames[0]._columns['hh'].storage) == F.cpu()
+    assert F.context(sg._edge_frames[0]._columns['h'].storage) == F.cpu()
+
+    # set some field
+    sg = sg.to(F.ctx())
+    assert F.context(sg._node_frames[0]._columns['h'].storage) == F.cpu()
+    sg.ndata['h'][0] = 5
+    assert F.context(sg._node_frames[0]._columns['h'].storage) == F.ctx()
+    assert F.context(sg._node_frames[0]._columns['hh'].storage) == F.cpu()
+    assert F.context(sg._edge_frames[0]._columns['h'].storage) == F.cpu()
+
+    # add nodes
+    ng = dgl.add_nodes(sg, 3)
+    assert F.context(ng._node_frames[0]._columns['h'].storage) == F.ctx()
+    assert F.context(ng._node_frames[0]._columns['hh'].storage) == F.ctx()
+    assert F.context(ng._edge_frames[0]._columns['h'].storage) == F.cpu()
+
 if __name__ == '__main__':
     # test_create()
     # test_query()
@@ -2434,9 +2528,11 @@ if __name__ == '__main__':
     # test_dtype_cast()
     # test_reverse("int32")
     # test_format()
-    test_add_edges(F.int32)
-    test_add_nodes(F.int32)
-    test_remove_edges(F.int32)
-    test_remove_nodes(F.int32)
-    test_clone(F.int32)
+    #test_add_edges(F.int32)
+    #test_add_nodes(F.int32)
+    #test_remove_edges(F.int32)
+    #test_remove_nodes(F.int32)
+    #test_clone(F.int32)
+    test_frame(F.int32)
+    test_frame_device(F.int32)
     pass
@@ -420,7 +420,7 @@ def test_sage_conv2(idtype):
     sage = nn.SAGEConv((3, 3), 2, 'gcn')
     feat = (F.randn((5, 3)), F.randn((3, 3)))
     sage = sage.to(ctx)
-    h = sage(g, feat)
+    h = sage(g, (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
     assert h.shape[-1] == 2
     assert h.shape[0] == 3
     for aggre_type in ['mean', 'pool', 'lstm']: