Unverified Commit 4af3f8bc authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

update the MXNet backend. (#89)

* update mxnet.

* add get_tvmtype.

* remove undefined test.
parent 6cbdf37c
......@@ -46,8 +46,14 @@ def from_numpy(np_data):
def pack(tensors):
return F.concat(*tensors, dim=0)
def unpack(x, indices_or_sections=1):
return th.split(x, indices_or_sections)
def unpack(x, split_sizes_or_sections=1):
if isinstance(split_sizes_or_sections, list):
np_arr = x.asnumpy()
indices = np.cumsum(split_sizes_or_sections)
res = np.split(np_arr, indices[:-1])
return [tensor(arr, dtype=x.dtype) for arr in res]
else:
return F.split(x, split_sizes_or_sections)
# TODO this doesn't exist for symbol.
def shape(x):
......@@ -66,6 +72,9 @@ def unique(x):
return mx.nd.array(tmp, ctx=x.context, dtype=x.dtype)
def gather_row(data, row_index):
if isinstance(row_index, F.NDArray):
return F.take(data, row_index)
else:
return data[row_index,]
scatter_row = mx.nd.contrib.index_copy
......@@ -114,6 +123,27 @@ def get_context(x):
def _typestr(arr_dtype):
return arr_dtype
def get_tvmtype(arr):
arr_dtype = arr.dtype
if arr_dtype == np.float16:
return TVMType('float16')
elif arr_dtype == np.float32:
return TVMType('float32')
elif arr_dtype == np.float64:
return TVMType('float64')
elif arr_dtype == np.int16:
return TVMType('int16')
elif arr_dtype == np.int32:
return TVMType('int32')
elif arr_dtype == np.int64:
return TVMType('int64')
elif arr_dtype == np.int8:
return TVMType('int8')
elif arr_dtype == np.uint8:
return TVMType('uint8')
else:
raise RuntimeError('Unsupported data type:', arr_dtype)
def zerocopy_to_dlpack(arr):
"""Return a dlpack compatible array using zero copy."""
return arr.to_dlpack_for_read()
......
......@@ -9,7 +9,7 @@ reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
assert mx.sum(a == b) == int(np.prod(list(a.shape)))
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
......@@ -53,16 +53,12 @@ def test_batch_setter_getter():
assert len(g.get_n_repr()) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes
# TODO we need to enable the test later.
'''
u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64')
print(g.get_n_repr(u)['h'])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
'''
'''
s, d, eid
......@@ -127,9 +123,11 @@ def test_batch_setter_autograd():
with mx.autograd.record():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
h1.attach_grad()
# partial set
v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D))
hh.attach_grad()
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
h2.backward(mx.nd.ones((10, D)) * 2)
......@@ -252,8 +250,7 @@ def test_pull_0deg():
if __name__ == '__main__':
test_batch_setter_getter()
# TODO we need to enable it after index_copy is implemented.
#test_batch_setter_autograd()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_routines()
......
......@@ -355,5 +355,4 @@ if __name__ == '__main__':
test_update_routines()
test_reduce_0deg()
test_pull_0deg()
test_send_twice()
test_send_multigraph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment