Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
4af3f8bc
Unverified
Commit
4af3f8bc
authored
Oct 18, 2018
by
Da Zheng
Committed by
GitHub
Oct 18, 2018
Browse files
update the MXNet backend. (#89)
* update mxnet. * add get_tvmtype. * remove undefined test.
parent
6cbdf37c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
11 deletions
+37
-11
python/dgl/backend/mxnet.py
python/dgl/backend/mxnet.py
+33
-3
tests/mxnet/test_basics.py
tests/mxnet/test_basics.py
+4
-7
tests/pytorch/test_basics.py
tests/pytorch/test_basics.py
+0
-1
No files found.
python/dgl/backend/mxnet.py
View file @
4af3f8bc
...
...
@@ -46,8 +46,14 @@ def from_numpy(np_data):
def
pack
(
tensors
):
return
F
.
concat
(
*
tensors
,
dim
=
0
)
def
unpack
(
x
,
indices_or_sections
=
1
):
return
th
.
split
(
x
,
indices_or_sections
)
def
unpack
(
x
,
split_sizes_or_sections
=
1
):
if
isinstance
(
split_sizes_or_sections
,
list
):
np_arr
=
x
.
asnumpy
()
indices
=
np
.
cumsum
(
split_sizes_or_sections
)
res
=
np
.
split
(
np_arr
,
indices
[:
-
1
])
return
[
tensor
(
arr
,
dtype
=
x
.
dtype
)
for
arr
in
res
]
else
:
return
F
.
split
(
x
,
split_sizes_or_sections
)
# TODO this doesn't exist for symbol.
def
shape
(
x
):
...
...
@@ -66,6 +72,9 @@ def unique(x):
return
mx
.
nd
.
array
(
tmp
,
ctx
=
x
.
context
,
dtype
=
x
.
dtype
)
def
gather_row
(
data
,
row_index
):
if
isinstance
(
row_index
,
F
.
NDArray
):
return
F
.
take
(
data
,
row_index
)
else
:
return
data
[
row_index
,]
scatter_row
=
mx
.
nd
.
contrib
.
index_copy
...
...
@@ -114,6 +123,27 @@ def get_context(x):
def
_typestr
(
arr_dtype
):
return
arr_dtype
def
get_tvmtype
(
arr
):
arr_dtype
=
arr
.
dtype
if
arr_dtype
==
np
.
float16
:
return
TVMType
(
'float16'
)
elif
arr_dtype
==
np
.
float32
:
return
TVMType
(
'float32'
)
elif
arr_dtype
==
np
.
float64
:
return
TVMType
(
'float64'
)
elif
arr_dtype
==
np
.
int16
:
return
TVMType
(
'int16'
)
elif
arr_dtype
==
np
.
int32
:
return
TVMType
(
'int32'
)
elif
arr_dtype
==
np
.
int64
:
return
TVMType
(
'int64'
)
elif
arr_dtype
==
np
.
int8
:
return
TVMType
(
'int8'
)
elif
arr_dtype
==
np
.
uint8
:
return
TVMType
(
'uint8'
)
else
:
raise
RuntimeError
(
'Unsupported data type:'
,
arr_dtype
)
def
zerocopy_to_dlpack
(
arr
):
"""Return a dlpack compatible array using zero copy."""
return
arr
.
to_dlpack_for_read
()
...
...
tests/mxnet/test_basics.py
View file @
4af3f8bc
...
...
@@ -9,7 +9,7 @@ reduce_msg_shapes = set()
def
check_eq
(
a
,
b
):
assert
a
.
shape
==
b
.
shape
assert
mx
.
sum
(
a
==
b
)
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
assert
mx
.
nd
.
sum
(
a
==
b
)
.
asnumpy
()
==
int
(
np
.
prod
(
list
(
a
.
shape
)))
def
message_func
(
src
,
edge
):
assert
len
(
src
[
'h'
].
shape
)
==
2
...
...
@@ -53,16 +53,12 @@ def test_batch_setter_getter():
assert
len
(
g
.
get_n_repr
())
==
0
g
.
set_n_repr
({
'h'
:
mx
.
nd
.
zeros
((
10
,
D
))})
# set partial nodes
# TODO we need to enable the test later.
'''
u
=
mx
.
nd
.
array
([
1
,
3
,
5
],
dtype
=
'int64'
)
g
.
set_n_repr
({
'h'
:
mx
.
nd
.
ones
((
3
,
D
))},
u
)
assert
_pfc
(
g
.
get_n_repr
()[
'h'
])
==
[
0.
,
1.
,
0.
,
1.
,
0.
,
1.
,
0.
,
0.
,
0.
,
0.
]
# get partial nodes
u
=
mx
.
nd
.
array
([
1
,
2
,
3
],
dtype
=
'int64'
)
print(g.get_n_repr(u)['h'])
assert
_pfc
(
g
.
get_n_repr
(
u
)[
'h'
])
==
[
1.
,
0.
,
1.
]
'''
'''
s, d, eid
...
...
@@ -127,9 +123,11 @@ def test_batch_setter_autograd():
with
mx
.
autograd
.
record
():
g
=
generate_graph
(
grad
=
True
)
h1
=
g
.
get_n_repr
()[
'h'
]
h1
.
attach_grad
()
# partial set
v
=
mx
.
nd
.
array
([
1
,
2
,
8
],
dtype
=
'int64'
)
hh
=
mx
.
nd
.
zeros
((
len
(
v
),
D
))
hh
.
attach_grad
()
g
.
set_n_repr
({
'h'
:
hh
},
v
)
h2
=
g
.
get_n_repr
()[
'h'
]
h2
.
backward
(
mx
.
nd
.
ones
((
10
,
D
))
*
2
)
...
...
@@ -252,8 +250,7 @@ def test_pull_0deg():
if
__name__
==
'__main__'
:
test_batch_setter_getter
()
# TODO we need to enable it after index_copy is implemented.
#test_batch_setter_autograd()
test_batch_setter_autograd
()
test_batch_send
()
test_batch_recv
()
test_update_routines
()
...
...
tests/pytorch/test_basics.py
View file @
4af3f8bc
...
...
@@ -355,5 +355,4 @@ if __name__ == '__main__':
test_update_routines
()
test_reduce_0deg
()
test_pull_0deg
()
test_send_twice
()
test_send_multigraph
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment