Commit 4f606312 authored by Zhang Zhi's avatar Zhang Zhi Committed by Mufei Li
Browse files

[Graph]To(device) should return updated object (#1057)

* To(device) should return updated object

* Add test case and doc

* Add another doc

* Add descriptions  and a test

* Pass test
parent 8ae9770f
......@@ -3339,6 +3339,11 @@ class DGLGraph(DGLBaseGraph):
ctx : framework-specific context object
The context to move data to.
Returns
-------
g : DGLGraph
Moved DGLGraph of the targeted mode.
Examples
--------
The following example uses PyTorch backend.
......@@ -3348,12 +3353,13 @@ class DGLGraph(DGLBaseGraph):
>>> G.add_nodes(5, {'h': torch.ones((5, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
>>> G.to(torch.device('cuda:0'))
>>> G = G.to(torch.device('cuda:0'))
"""
for k in self.ndata.keys():
self.ndata[k] = F.copy_to(self.ndata[k], ctx)
for k in self.edata.keys():
self.edata[k] = F.copy_to(self.edata[k], ctx)
return self
# pylint: enable=invalid-name
def local_var(self):
......
......@@ -3575,6 +3575,11 @@ class DGLHeteroGraph(object):
ctx : framework-specific context object
The context to move data to.
Returns
-------
g : DGLHeteroGraph
Moved DGLHeteroGraph of the targeted mode.
Examples
--------
The following example uses PyTorch backend.
......@@ -3583,7 +3588,7 @@ class DGLHeteroGraph(object):
>>> g = dgl.bipartite([(0, 0), (1, 0), (1, 2), (2, 1)], 'user', 'plays', 'game')
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [2.], [3.]])
>>> g.to(torch.device('cuda:0'))
>>> g = g.to(torch.device('cuda:0'))
"""
for i in range(len(self._node_frames)):
for k in self._node_frames[i].keys():
......@@ -3591,6 +3596,7 @@ class DGLHeteroGraph(object):
for i in range(len(self._edge_frames)):
for k in self._edge_frames[i].keys():
self._edge_frames[i][k] = F.copy_to(self._edge_frames[i][k], ctx)
return self
def local_var(self):
"""Return a heterograph object that can be used in a local function scope.
......
......@@ -592,6 +592,12 @@ def test_flatten():
assert fg.etypes == ['follows+knows']
check_mapping(g, fg)
def test_to_device():
hg = create_test_heterograph()
if F.is_cuda_available():
hg = hg.to(F.cuda())
assert hg is not None
def test_convert():
hg = create_test_heterograph()
hs = []
......@@ -1200,6 +1206,7 @@ if __name__ == '__main__':
test_view1()
test_flatten()
test_convert()
test_to_device()
test_transform()
test_subgraph()
test_apply()
......
......@@ -6,7 +6,8 @@ def test_to_device():
g.add_nodes(5, {'h' : F.ones((5, 2))})
g.add_edges([0, 1], [1, 2], {'m' : F.ones((2, 2))})
if F.is_cuda_available():
g.to(F.cuda())
g = g.to(F.cuda())
assert g is not None
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment