Unverified Commit 11974812 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Backend] use new split in MXNet (#660)

* use new split.

* fix.

* fix.

* add zero copy from numpy.
parent f3d3fdf8
...@@ -125,8 +125,17 @@ def stack(seq, dim): ...@@ -125,8 +125,17 @@ def stack(seq, dim):
return nd.stack(*seq, axis=dim) return nd.stack(*seq, axis=dim)
def split(x, sizes_or_sections, dim): def split(x, sizes_or_sections, dim):
if isinstance(sizes_or_sections, list) and len(sizes_or_sections) == 1:
assert len(x) == sizes_or_sections[0]
return [x]
if MX_VERSION.version[0] == 1 and MX_VERSION.version[1] >= 5:
if isinstance(sizes_or_sections, (np.ndarray, list)):
sizes_or_sections1 = tuple(np.cumsum(sizes_or_sections)[:-1])
return nd.split_v2(x, sizes_or_sections1, axis=dim)
if isinstance(sizes_or_sections, list) or isinstance(sizes_or_sections, np.ndarray): if isinstance(sizes_or_sections, list) or isinstance(sizes_or_sections, np.ndarray):
# TODO: fallback to numpy is unfortunate # Old MXNet doesn't support split with different section sizes.
np_arr = x.asnumpy() np_arr = x.asnumpy()
indices = np.cumsum(sizes_or_sections)[:-1] indices = np.cumsum(sizes_or_sections)[:-1]
res = np.split(np_arr, indices, axis=dim) res = np.split(np_arr, indices, axis=dim)
...@@ -249,8 +258,7 @@ def zerocopy_to_numpy(arr): ...@@ -249,8 +258,7 @@ def zerocopy_to_numpy(arr):
return arr.asnumpy() return arr.asnumpy()
def zerocopy_from_numpy(np_data): def zerocopy_from_numpy(np_data):
# NOTE: not zerocopy return mx.nd.from_numpy(np_data, zero_copy=True)
return nd.array(np_data, dtype=np_data.dtype)
def zerocopy_to_dgl_ndarray(arr): def zerocopy_to_dgl_ndarray(arr):
return dglnd.from_dlpack(arr.to_dlpack_for_read()) return dglnd.from_dlpack(arr.to_dlpack_for_read())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment