Unverified Commit 04522a76 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Quick fix of #1547 (#1600)

* upd

* upd
parent 57d111f9
......@@ -269,7 +269,7 @@ def asnumpy(input):
"""
pass
def copy_to(input, ctx):
def copy_to(input, ctx, **kwargs):
"""Copy the given tensor to the context.
Parameters
......
......@@ -115,7 +115,7 @@ def astype(input, ty):
def asnumpy(input):
return input.asnumpy()
def copy_to(input, ctx):
def copy_to(input, ctx, **kwargs):
return input.as_in_context(ctx)
def sum(input, dim, keepdims=False):
......
......@@ -87,13 +87,13 @@ def asnumpy(input):
else:
return input.cpu().detach().numpy()
def copy_to(input, ctx):
def copy_to(input, ctx, **kwargs):
if ctx.type == 'cpu':
return input.cpu()
elif ctx.type == 'cuda':
if ctx.index is not None:
th.cuda.set_device(ctx.index)
return input.cuda()
return input.cuda(**kwargs)
else:
raise RuntimeError('Invalid context', ctx)
......
......@@ -129,7 +129,7 @@ def asnumpy(input):
return input.numpy()
def copy_to(input, ctx):
def copy_to(input, ctx, **kwargs):
with tf.device(ctx):
new_tensor = tf.identity(input)
return new_tensor
......
......@@ -3872,7 +3872,7 @@ class DGLGraph(DGLBaseGraph):
edata=str(self.edge_attr_schemes()))
# pylint: disable=invalid-name
def to(self, ctx):
def to(self, ctx, **kwargs):
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
......@@ -3898,9 +3898,9 @@ class DGLGraph(DGLBaseGraph):
>>> G = G.to(torch.device('cuda:0'))
"""
for k in self.ndata.keys():
self.ndata[k] = F.copy_to(self.ndata[k], ctx)
self.ndata[k] = F.copy_to(self.ndata[k], ctx, **kwargs)
for k in self.edata.keys():
self.edata[k] = F.copy_to(self.edata[k], ctx)
self.edata[k] = F.copy_to(self.edata[k], ctx, **kwargs)
return self
# pylint: enable=invalid-name
......
......@@ -4025,7 +4025,7 @@ class DGLHeteroGraph(object):
edges = F.tensor(edges)
return F.boolean_mask(edges, e_mask)
def to(self, ctx): # pylint: disable=invalid-name
def to(self, ctx, **kwargs): # pylint: disable=invalid-name
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment