Unverified Commit 04522a76 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Quick fix of #1547 (#1600)

* upd

* upd
parent 57d111f9
...@@ -269,7 +269,7 @@ def asnumpy(input): ...@@ -269,7 +269,7 @@ def asnumpy(input):
""" """
pass pass
def copy_to(input, ctx): def copy_to(input, ctx, **kwargs):
"""Copy the given tensor to the context. """Copy the given tensor to the context.
Parameters Parameters
......
...@@ -115,7 +115,7 @@ def astype(input, ty): ...@@ -115,7 +115,7 @@ def astype(input, ty):
def asnumpy(input): def asnumpy(input):
return input.asnumpy() return input.asnumpy()
def copy_to(input, ctx): def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx) return input.as_in_context(ctx)
def sum(input, dim, keepdims=False): def sum(input, dim, keepdims=False):
......
...@@ -87,13 +87,13 @@ def asnumpy(input): ...@@ -87,13 +87,13 @@ def asnumpy(input):
else: else:
return input.cpu().detach().numpy() return input.cpu().detach().numpy()
def copy_to(input, ctx): def copy_to(input, ctx, **kwargs):
if ctx.type == 'cpu': if ctx.type == 'cpu':
return input.cpu() return input.cpu()
elif ctx.type == 'cuda': elif ctx.type == 'cuda':
if ctx.index is not None: if ctx.index is not None:
th.cuda.set_device(ctx.index) th.cuda.set_device(ctx.index)
return input.cuda() return input.cuda(**kwargs)
else: else:
raise RuntimeError('Invalid context', ctx) raise RuntimeError('Invalid context', ctx)
......
...@@ -129,7 +129,7 @@ def asnumpy(input): ...@@ -129,7 +129,7 @@ def asnumpy(input):
return input.numpy() return input.numpy()
def copy_to(input, ctx): def copy_to(input, ctx, **kwargs):
with tf.device(ctx): with tf.device(ctx):
new_tensor = tf.identity(input) new_tensor = tf.identity(input)
return new_tensor return new_tensor
......
...@@ -3872,7 +3872,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3872,7 +3872,7 @@ class DGLGraph(DGLBaseGraph):
edata=str(self.edge_attr_schemes())) edata=str(self.edge_attr_schemes()))
# pylint: disable=invalid-name # pylint: disable=invalid-name
def to(self, ctx): def to(self, ctx, **kwargs):
"""Move both ndata and edata to the targeted mode (cpu/gpu) """Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic Framework agnostic
...@@ -3898,9 +3898,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -3898,9 +3898,9 @@ class DGLGraph(DGLBaseGraph):
>>> G = G.to(torch.device('cuda:0')) >>> G = G.to(torch.device('cuda:0'))
""" """
for k in self.ndata.keys(): for k in self.ndata.keys():
self.ndata[k] = F.copy_to(self.ndata[k], ctx) self.ndata[k] = F.copy_to(self.ndata[k], ctx, **kwargs)
for k in self.edata.keys(): for k in self.edata.keys():
self.edata[k] = F.copy_to(self.edata[k], ctx) self.edata[k] = F.copy_to(self.edata[k], ctx, **kwargs)
return self return self
# pylint: enable=invalid-name # pylint: enable=invalid-name
......
...@@ -4025,7 +4025,7 @@ class DGLHeteroGraph(object): ...@@ -4025,7 +4025,7 @@ class DGLHeteroGraph(object):
edges = F.tensor(edges) edges = F.tensor(edges)
return F.boolean_mask(edges, e_mask) return F.boolean_mask(edges, e_mask)
def to(self, ctx): # pylint: disable=invalid-name def to(self, ctx, **kwargs): # pylint: disable=invalid-name
"""Move both ndata and edata to the targeted mode (cpu/gpu) """Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic Framework agnostic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment