"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ba0d665bbbbd8587777a33d22b059ed40c0d9866"
Unverified Commit 4af3f8bc authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

update the MXNet backend. (#89)

* update mxnet.

* add get_tvmtype.

* remove undefined test.
parent 6cbdf37c
...@@ -46,8 +46,14 @@ def from_numpy(np_data): ...@@ -46,8 +46,14 @@ def from_numpy(np_data):
def pack(tensors): def pack(tensors):
return F.concat(*tensors, dim=0) return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1): def unpack(x, split_sizes_or_sections=1):
return th.split(x, indices_or_sections) if isinstance(split_sizes_or_sections, list):
np_arr = x.asnumpy()
indices = np.cumsum(split_sizes_or_sections)
res = np.split(np_arr, indices[:-1])
return [tensor(arr, dtype=x.dtype) for arr in res]
else:
return F.split(x, split_sizes_or_sections)
# TODO this doesn't exist for symbol. # TODO this doesn't exist for symbol.
def shape(x): def shape(x):
...@@ -66,7 +72,10 @@ def unique(x): ...@@ -66,7 +72,10 @@ def unique(x):
return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype) return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)
def gather_row(data, row_index): def gather_row(data, row_index):
return data[row_index,] if isinstance(row_index, F.NDArray):
return F.take(data, row_index)
else:
return data[row_index,]
scatter_row = mx.nd.contrib.index_copy scatter_row = mx.nd.contrib.index_copy
...@@ -114,6 +123,27 @@ def get_context(x): ...@@ -114,6 +123,27 @@ def get_context(x):
def _typestr(arr_dtype): def _typestr(arr_dtype):
return arr_dtype return arr_dtype
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype == np.float16:
return TVMType('float16')
elif arr_dtype == np.float32:
return TVMType('float32')
elif arr_dtype == np.float64:
return TVMType('float64')
elif arr_dtype == np.int16:
return TVMType('int16')
elif arr_dtype == np.int32:
return TVMType('int32')
elif arr_dtype == np.int64:
return TVMType('int64')
elif arr_dtype == np.int8:
return TVMType('int8')
elif arr_dtype == np.uint8:
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def zerocopy_to_dlpack(arr): def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy.""" """Return a dlpack compatible array using zero copy."""
return arr.to_dlpack_for_read() return arr.to_dlpack_for_read()
......
...@@ -9,7 +9,7 @@ reduce_msg_shapes = set() ...@@ -9,7 +9,7 @@ reduce_msg_shapes = set()
def check_eq(a, b): def check_eq(a, b):
assert a.shape == b.shape assert a.shape == b.shape
assert mx.sum(a == b) == int(np.prod(list(a.shape))) assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge): def message_func(src, edge):
assert len(src['h'].shape) == 2 assert len(src['h'].shape) == 2
...@@ -53,16 +53,12 @@ def test_batch_setter_getter(): ...@@ -53,16 +53,12 @@ def test_batch_setter_getter():
assert len(g.get_n_repr()) == 0 assert len(g.get_n_repr()) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))}) g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes # set partial nodes
# TODO we need to enable the test later.
'''
u = mx.nd.array([1, 3, 5], dtype='int64') u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u) g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.] assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes # get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64') u = mx.nd.array([1, 2, 3], dtype='int64')
print(g.get_n_repr(u)['h'])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.] assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
'''
''' '''
s, d, eid s, d, eid
...@@ -127,9 +123,11 @@ def test_batch_setter_autograd(): ...@@ -127,9 +123,11 @@ def test_batch_setter_autograd():
with mx.autograd.record(): with mx.autograd.record():
g = generate_graph(grad=True) g = generate_graph(grad=True)
h1 = g.get_n_repr()['h'] h1 = g.get_n_repr()['h']
h1.attach_grad()
# partial set # partial set
v = mx.nd.array([1, 2, 8], dtype='int64') v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D)) hh = mx.nd.zeros((len(v), D))
hh.attach_grad()
g.set_n_repr({'h' : hh}, v) g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h'] h2 = g.get_n_repr()['h']
h2.backward(mx.nd.ones((10, D)) * 2) h2.backward(mx.nd.ones((10, D)) * 2)
...@@ -252,8 +250,7 @@ def test_pull_0deg(): ...@@ -252,8 +250,7 @@ def test_pull_0deg():
if __name__ == '__main__': if __name__ == '__main__':
test_batch_setter_getter() test_batch_setter_getter()
# TODO we need to enable it after index_copy is implemented. test_batch_setter_autograd()
#test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_routines() test_update_routines()
......
...@@ -355,5 +355,4 @@ if __name__ == '__main__': ...@@ -355,5 +355,4 @@ if __name__ == '__main__':
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
test_send_twice()
test_send_multigraph() test_send_multigraph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment