Commit 993fd3f9 authored by HQ's avatar HQ Committed by VoVAllen
Browse files

[Enhancement] Add DGLGraph.to for PyTorch and MXNet backend (#600)

* add graph_to

* use backend copy_to

* add test

* fix test

* framework agnostic to() test

* disable pylint complaint

* add examples

* fix docstring

* formatting

* Format

* Update test_to_device.py
parent baa16231
...@@ -79,6 +79,7 @@ Converting from/to other format ...@@ -79,6 +79,7 @@ Converting from/to other format
DGLGraph.adjacency_matrix DGLGraph.adjacency_matrix
DGLGraph.adjacency_matrix_scipy DGLGraph.adjacency_matrix_scipy
DGLGraph.incidence_matrix DGLGraph.incidence_matrix
DGLGraph.to
Using Node/edge features Using Node/edge features
------------------------ ------------------------
......
...@@ -81,6 +81,7 @@ def copy_to(input, ctx): ...@@ -81,6 +81,7 @@ def copy_to(input, ctx):
if ctx.type == 'cpu': if ctx.type == 'cpu':
return input.cpu() return input.cpu()
elif ctx.type == 'cuda': elif ctx.type == 'cuda':
if ctx.index is not None:
th.cuda.set_device(ctx.index) th.cuda.set_device(ctx.index)
return input.cuda() return input.cuda()
else: else:
......
...@@ -3229,3 +3229,28 @@ class DGLGraph(DGLBaseGraph): ...@@ -3229,3 +3229,28 @@ class DGLGraph(DGLBaseGraph):
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(), return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()), ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes())) edata=str(self.edge_attr_schemes()))
    # pylint: disable=invalid-name
    def to(self, ctx):
        """Move all node and edge features to the given device context.

        Only the tensors stored in ``ndata`` and ``edata`` are copied (via the
        backend's ``copy_to``); the graph structure itself is not moved.
        Operates in place and returns ``None``.

        Parameters
        ----------
        ctx : framework-specific context object
            Target device, e.g. ``torch.device('cuda:0')`` for PyTorch or
            ``mx.gpu()`` for MXNet.

        Examples
        --------
        >>> import backend as F
        >>> G = dgl.DGLGraph()
        >>> G.add_nodes(5, {'h': torch.ones((5, 2))})
        >>> G.add_edges([0, 1], [1, 2], {'m' : torch.ones((2, 2))})
        >>> G.to(F.cuda())
        """
        # Reassigning an existing key is safe while iterating keys();
        # no keys are added or removed here.
        for k in self.ndata.keys():
            self.ndata[k] = F.copy_to(self.ndata[k], ctx)
        for k in self.edata.keys():
            self.edata[k] = F.copy_to(self.edata[k], ctx)
...@@ -9,6 +9,10 @@ def cuda(): ...@@ -9,6 +9,10 @@ def cuda():
"""Context object for CUDA.""" """Context object for CUDA."""
pass pass
def is_cuda_available():
    """Check whether CUDA is available.

    Backend-interface stub: each concrete backend (PyTorch, MXNet, ...)
    provides the real implementation, which returns True when a usable
    CUDA device is present and False otherwise.
    """
############################################################################### ###############################################################################
# Tensor functions on feature data # Tensor functions on feature data
# -------------------------------- # --------------------------------
......
...@@ -8,6 +8,14 @@ import mxnet.autograd as autograd ...@@ -8,6 +8,14 @@ import mxnet.autograd as autograd
def cuda(): def cuda():
return mx.gpu() return mx.gpu()
def is_cuda_available():
    """Return True if MXNet can allocate an array on a GPU, else False.

    MXNet (at this version) exposes no direct capability query, so probe by
    creating a tiny NDArray on the default GPU context.
    """
    try:
        # MXNet operations are asynchronous: force the device copy to
        # complete so a missing/unusable GPU raises here, inside the try,
        # rather than at some later synchronization point.
        nd.array([1, 2, 3], ctx=mx.gpu()).wait_to_read()
        return True
    except mx.MXNetError:
        return False
def array_equal(a, b): def array_equal(a, b):
return nd.equal(a, b).asnumpy().all() return nd.equal(a, b).asnumpy().all()
......
...@@ -5,6 +5,9 @@ import torch as th ...@@ -5,6 +5,9 @@ import torch as th
def cuda(): def cuda():
return th.device('cuda:0') return th.device('cuda:0')
def is_cuda_available():
    """Report whether PyTorch sees a usable CUDA runtime."""
    available = th.cuda.is_available()
    return available
def array_equal(a, b): def array_equal(a, b):
return th.equal(a.cpu(), b.cpu()) return th.equal(a.cpu(), b.cpu())
......
import dgl
import backend as F
def test_to_device():
    """Smoke test for DGLGraph.to(): build a tiny featured graph and move
    its node/edge features to the GPU when one is available."""
    graph = dgl.DGLGraph()
    graph.add_nodes(5, {'h': F.ones((5, 2))})
    graph.add_edges([0, 1], [1, 2], {'m': F.ones((2, 2))})
    # Only exercise the device transfer when the backend reports CUDA.
    if F.is_cuda_available():
        graph.to(F.cuda())


if __name__ == '__main__':
    test_to_device()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment