"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "72c6bab24f398dbc583a26508dd9ee1f3dbc4fc2"
Commit 993fd3f9 authored by HQ, committed by VoVAllen

[Enhancement] Add DGLGraph.to for PyTorch and MXNet backend (#600)

* add graph_to

* use backend copy_to

* add test

* fix test

* framework agnostic to() test

* disable pylint complaint

* add examples

* fix docstring

* formatting

* Format

* Update test_to_device.py
parent baa16231
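
In short, the new DGLGraph.to(ctx) moves every node- and edge-feature tensor of a graph to the given device context. A minimal usage sketch with the PyTorch backend (assuming a CUDA-enabled build; it mirrors the test added below):

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5, {'h': th.ones((5, 2))})
>>> g.add_edges([0, 1], [1, 2], {'m': th.ones((2, 2))})
>>> if th.cuda.is_available():
...     g.to(th.device('cuda:0'))  # moves g.ndata['h'] and g.edata['m'] to GPU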
@@ -79,6 +79,7 @@ Converting from/to other format
     DGLGraph.adjacency_matrix
     DGLGraph.adjacency_matrix_scipy
     DGLGraph.incidence_matrix
+    DGLGraph.to
 
 Using Node/edge features
 ------------------------
@@ -81,7 +81,8 @@ def copy_to(input, ctx):
     if ctx.type == 'cpu':
         return input.cpu()
     elif ctx.type == 'cuda':
-        th.cuda.set_device(ctx.index)
+        if ctx.index is not None:
+            th.cuda.set_device(ctx.index)
         return input.cuda()
     else:
         raise RuntimeError('Invalid context', ctx)
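
The extra guard matters because a context created without an explicit ordinal, e.g. th.device('cuda'), has index None, and passing None to th.cuda.set_device fails; with the guard, input.cuda() simply uses the current default CUDA device. A small illustration of the two cases (not part of the diff):

>>> import torch as th
>>> th.device('cuda:1').index       # explicit ordinal: set_device(1) runs
1
>>> print(th.device('cuda').index)  # no ordinal: set_device is skipped
None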
@@ -3229,3 +3229,28 @@ class DGLGraph(DGLBaseGraph):
         return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
                           ndata=str(self.node_attr_schemes()),
                           edata=str(self.edge_attr_schemes()))
+
+    # pylint: disable=invalid-name
+    def to(self, ctx):
+        """Move both ndata and edata to the targeted device context (cpu/gpu).
+
+        Framework agnostic.
+
+        Parameters
+        ----------
+        ctx : framework-specific context object
+
+        Examples (PyTorch & MXNet)
+        --------------------------
+        >>> import backend as F
+        >>> G = dgl.DGLGraph()
+        >>> G.add_nodes(5, {'h': F.ones((5, 2))})
+        >>> G.add_edges([0, 1], [1, 2], {'m': F.ones((2, 2))})
+        >>> G.to(F.cuda())
+        """
+        for k in self.ndata.keys():
+            self.ndata[k] = F.copy_to(self.ndata[k], ctx)
+        for k in self.edata.keys():
+            self.edata[k] = F.copy_to(self.edata[k], ctx)
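
Note that to() updates the graph's feature storage in place and returns None, unlike torch.Tensor.to. The same call works with an MXNet context, since that backend's cuda() returns mx.gpu(); a hedged sketch, assuming a GPU-enabled MXNet build:

>>> import dgl
>>> import mxnet as mx
>>> from mxnet import nd
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5, {'h': nd.ones((5, 2))})
>>> g.add_edges([0, 1], [1, 2], {'m': nd.ones((2, 2))})
>>> g.to(mx.gpu())  # in place; g.ndata['h'].context becomes gpu(0)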
@@ -9,6 +9,10 @@ def cuda():
     """Context object for CUDA."""
     pass
 
+def is_cuda_available():
+    """Check whether CUDA is available."""
+    pass
+
 ###############################################################################
 # Tensor functions on feature data
 # --------------------------------
@@ -8,6 +8,14 @@ import mxnet.autograd as autograd
 def cuda():
     return mx.gpu()
 
+def is_cuda_available():
+    # TODO: Does MXNet have a convenient function to test GPU availability/compilation?
+    try:
+        a = nd.array([1, 2, 3], ctx=mx.gpu())
+        return True
+    except mx.MXNetError:
+        return False
+
 def array_equal(a, b):
     return nd.equal(a, b).asnumpy().all()
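
One possible answer to the TODO above: newer MXNet releases (1.3+, if memory serves) expose mx.context.num_gpus(), which would avoid the allocate-and-catch probe. A sketch under that assumption:

>>> import mxnet as mx
>>> def is_cuda_available():
...     # assumed: num_gpus() returns 0 on CPU-only builds instead of raising
...     return mx.context.num_gpus() > 0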
@@ -5,6 +5,9 @@ import torch as th
 def cuda():
     return th.device('cuda:0')
 
+def is_cuda_available():
+    return th.cuda.is_available()
+
 def array_equal(a, b):
     return th.equal(a.cpu(), b.cpu())
import dgl
import backend as F

def test_to_device():
    g = dgl.DGLGraph()
    g.add_nodes(5, {'h' : F.ones((5, 2))})
    g.add_edges([0, 1], [1, 2], {'m' : F.ones((2, 2))})
    if F.is_cuda_available():
        g.to(F.cuda())

if __name__ == '__main__':
    test_to_device()