"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9d2fc6b53528bf0aa19a49f5b98848abac98275c"
Commit 4f606312 authored by Zhang Zhi's avatar Zhang Zhi Committed by Mufei Li
Browse files

[Graph]To(device) should return updated object (#1057)

* To(device) should return updated object

* Add test case and doc

* Add another doc

* Add descriptions  and a test

* Pass test
parent 8ae9770f
...@@ -3339,6 +3339,11 @@ class DGLGraph(DGLBaseGraph): ...@@ -3339,6 +3339,11 @@ class DGLGraph(DGLBaseGraph):
ctx : framework-specific context object ctx : framework-specific context object
The context to move data to. The context to move data to.
Returns
-------
g : DGLGraph
Moved DGLGraph of the targeted mode.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -3348,12 +3353,13 @@ class DGLGraph(DGLBaseGraph): ...@@ -3348,12 +3353,13 @@ class DGLGraph(DGLBaseGraph):
>>> G.add_nodes(5, {'h': torch.ones((5, 2))}) >>> G.add_nodes(5, {'h': torch.ones((5, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))}) >>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))}) >>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.to(torch.device('cuda:0')) >>> G = G.to(torch.device('cuda:0'))
""" """
for k in self.ndata.keys(): for k in self.ndata.keys():
self.ndata[k] = F.copy_to(self.ndata[k], ctx) self.ndata[k] = F.copy_to(self.ndata[k], ctx)
for k in self.edata.keys(): for k in self.edata.keys():
self.edata[k] = F.copy_to(self.edata[k], ctx) self.edata[k] = F.copy_to(self.edata[k], ctx)
return self
# pylint: enable=invalid-name # pylint: enable=invalid-name
def local_var(self): def local_var(self):
......
...@@ -3575,6 +3575,11 @@ class DGLHeteroGraph(object): ...@@ -3575,6 +3575,11 @@ class DGLHeteroGraph(object):
ctx : framework-specific context object ctx : framework-specific context object
The context to move data to. The context to move data to.
Returns
-------
g : DGLHeteroGraph
Moved DGLHeteroGraph of the targeted mode.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -3583,7 +3588,7 @@ class DGLHeteroGraph(object): ...@@ -3583,7 +3588,7 @@ class DGLHeteroGraph(object):
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game') >>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]]) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]]) >>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
>>> g.to(torch.device('cuda:0')) >>> g = g.to(torch.device('cuda:0'))
""" """
for i in range(len(self._node_frames)): for i in range(len(self._node_frames)):
for k in self._node_frames[i].keys(): for k in self._node_frames[i].keys():
...@@ -3591,6 +3596,7 @@ class DGLHeteroGraph(object): ...@@ -3591,6 +3596,7 @@ class DGLHeteroGraph(object):
for i in range(len(self._edge_frames)): for i in range(len(self._edge_frames)):
for k in self._edge_frames[i].keys(): for k in self._edge_frames[i].keys():
self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx) self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
return self
def local_var(self): def local_var(self):
"""Return a heterograph object that can be used in a local function scope. """Return a heterograph object that can be used in a local function scope.
......
...@@ -592,6 +592,12 @@ def test_flatten(): ...@@ -592,6 +592,12 @@ def test_flatten():
assert fg.etypes == ['follows+knows'] assert fg.etypes == ['follows+knows']
check_mapping(g, fg) check_mapping(g, fg)
def test_to_device():
hg = create_test_heterograph()
if F.is_cuda_available():
hg = hg.to(F.cuda())
assert hg is not None
def test_convert(): def test_convert():
hg = create_test_heterograph() hg = create_test_heterograph()
hs = [] hs = []
...@@ -1200,6 +1206,7 @@ if __name__ == '__main__': ...@@ -1200,6 +1206,7 @@ if __name__ == '__main__':
test_view1() test_view1()
test_flatten() test_flatten()
test_convert() test_convert()
test_to_device()
test_transform() test_transform()
test_subgraph() test_subgraph()
test_apply() test_apply()
......
...@@ -6,7 +6,8 @@ def test_to_device(): ...@@ -6,7 +6,8 @@ def test_to_device():
g.add_nodes(5, {'h' : F.ones((5, 2))}) g.add_nodes(5, {'h' : F.ones((5, 2))})
g.add_edges([0, 1], [1, 2], {'m' : F.ones((2, 2))}) g.add_edges([0, 1], [1, 2], {'m' : F.ones((2, 2))})
if F.is_cuda_available(): if F.is_cuda_available():
g.to(F.cuda()) g = g.to(F.cuda())
assert g is not None
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment