Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
04522a76
Unverified
Commit
04522a76
authored
Jun 15, 2020
by
Zihao Ye
Committed by
GitHub
Jun 15, 2020
Browse files
[bugfix] Quick fix of #1547 (#1600)
* upd * upd
parent
57d111f9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
9 additions
and
9 deletions
+9
-9
python/dgl/backend/backend.py
python/dgl/backend/backend.py
+1
-1
python/dgl/backend/mxnet/tensor.py
python/dgl/backend/mxnet/tensor.py
+1
-1
python/dgl/backend/pytorch/tensor.py
python/dgl/backend/pytorch/tensor.py
+2
-2
python/dgl/backend/tensorflow/tensor.py
python/dgl/backend/tensorflow/tensor.py
+1
-1
python/dgl/graph.py
python/dgl/graph.py
+3
-3
python/dgl/heterograph.py
python/dgl/heterograph.py
+1
-1
No files found.
python/dgl/backend/backend.py
View file @
04522a76
...
@@ -269,7 +269,7 @@ def asnumpy(input):
...
@@ -269,7 +269,7 @@ def asnumpy(input):
"""
"""
pass
pass
def
copy_to
(
input
,
ctx
):
def
copy_to
(
input
,
ctx
,
**
kwargs
):
"""Copy the given tensor to the context.
"""Copy the given tensor to the context.
Parameters
Parameters
...
...
python/dgl/backend/mxnet/tensor.py
View file @
04522a76
...
@@ -115,7 +115,7 @@ def astype(input, ty):
...
@@ -115,7 +115,7 @@ def astype(input, ty):
def
asnumpy
(
input
):
def
asnumpy
(
input
):
return
input
.
asnumpy
()
return
input
.
asnumpy
()
def
copy_to
(
input
,
ctx
):
def
copy_to
(
input
,
ctx
,
**
kwargs
):
return
input
.
as_in_context
(
ctx
)
return
input
.
as_in_context
(
ctx
)
def
sum
(
input
,
dim
,
keepdims
=
False
):
def
sum
(
input
,
dim
,
keepdims
=
False
):
...
...
python/dgl/backend/pytorch/tensor.py
View file @
04522a76
...
@@ -87,13 +87,13 @@ def asnumpy(input):
...
@@ -87,13 +87,13 @@ def asnumpy(input):
else
:
else
:
return
input
.
cpu
().
detach
().
numpy
()
return
input
.
cpu
().
detach
().
numpy
()
def
copy_to
(
input
,
ctx
):
def
copy_to
(
input
,
ctx
,
**
kwargs
):
if
ctx
.
type
==
'cpu'
:
if
ctx
.
type
==
'cpu'
:
return
input
.
cpu
()
return
input
.
cpu
()
elif
ctx
.
type
==
'cuda'
:
elif
ctx
.
type
==
'cuda'
:
if
ctx
.
index
is
not
None
:
if
ctx
.
index
is
not
None
:
th
.
cuda
.
set_device
(
ctx
.
index
)
th
.
cuda
.
set_device
(
ctx
.
index
)
return
input
.
cuda
()
return
input
.
cuda
(
**
kwargs
)
else
:
else
:
raise
RuntimeError
(
'Invalid context'
,
ctx
)
raise
RuntimeError
(
'Invalid context'
,
ctx
)
...
...
python/dgl/backend/tensorflow/tensor.py
View file @
04522a76
...
@@ -129,7 +129,7 @@ def asnumpy(input):
...
@@ -129,7 +129,7 @@ def asnumpy(input):
return
input
.
numpy
()
return
input
.
numpy
()
def
copy_to
(
input
,
ctx
):
def
copy_to
(
input
,
ctx
,
**
kwargs
):
with
tf
.
device
(
ctx
):
with
tf
.
device
(
ctx
):
new_tensor
=
tf
.
identity
(
input
)
new_tensor
=
tf
.
identity
(
input
)
return
new_tensor
return
new_tensor
...
...
python/dgl/graph.py
View file @
04522a76
...
@@ -3872,7 +3872,7 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3872,7 +3872,7 @@ class DGLGraph(DGLBaseGraph):
edata
=
str
(
self
.
edge_attr_schemes
()))
edata
=
str
(
self
.
edge_attr_schemes
()))
# pylint: disable=invalid-name
# pylint: disable=invalid-name
def
to
(
self
,
ctx
):
def
to
(
self
,
ctx
,
**
kwargs
):
"""Move both ndata and edata to the targeted mode (cpu/gpu)
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
Framework agnostic
...
@@ -3898,9 +3898,9 @@ class DGLGraph(DGLBaseGraph):
...
@@ -3898,9 +3898,9 @@ class DGLGraph(DGLBaseGraph):
>>> G = G.to(torch.device('cuda:0'))
>>> G = G.to(torch.device('cuda:0'))
"""
"""
for
k
in
self
.
ndata
.
keys
():
for
k
in
self
.
ndata
.
keys
():
self
.
ndata
[
k
]
=
F
.
copy_to
(
self
.
ndata
[
k
],
ctx
)
self
.
ndata
[
k
]
=
F
.
copy_to
(
self
.
ndata
[
k
],
ctx
,
**
kwargs
)
for
k
in
self
.
edata
.
keys
():
for
k
in
self
.
edata
.
keys
():
self
.
edata
[
k
]
=
F
.
copy_to
(
self
.
edata
[
k
],
ctx
)
self
.
edata
[
k
]
=
F
.
copy_to
(
self
.
edata
[
k
],
ctx
,
**
kwargs
)
return
self
return
self
# pylint: enable=invalid-name
# pylint: enable=invalid-name
...
...
python/dgl/heterograph.py
View file @
04522a76
...
@@ -4025,7 +4025,7 @@ class DGLHeteroGraph(object):
...
@@ -4025,7 +4025,7 @@ class DGLHeteroGraph(object):
edges
=
F
.
tensor
(
edges
)
edges
=
F
.
tensor
(
edges
)
return
F
.
boolean_mask
(
edges
,
e_mask
)
return
F
.
boolean_mask
(
edges
,
e_mask
)
def
to
(
self
,
ctx
):
# pylint: disable=invalid-name
def
to
(
self
,
ctx
,
**
kwargs
):
# pylint: disable=invalid-name
"""Move both ndata and edata to the targeted mode (cpu/gpu)
"""Move both ndata and edata to the targeted mode (cpu/gpu)
Framework agnostic
Framework agnostic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment