"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e602ab1b56889c8f999f07aeddb55d641fba1014"
Commit 05f464f8 authored by Gan Quan's avatar Gan Quan Committed by Minjie Wang
Browse files

[Bugfix] fix args with incontiguous node and edge tensor (#153)

parent ed1854e8
...@@ -137,7 +137,7 @@ def arange(start, stop): ...@@ -137,7 +137,7 @@ def arange(start, stop):
return th.arange(start, stop, dtype=th.int64) return th.arange(start, stop, dtype=th.int64)
def zerocopy_to_dlpack(input): def zerocopy_to_dlpack(input):
return dlpack.to_dlpack(input) return dlpack.to_dlpack(input.contiguous())
def zerocopy_from_dlpack(dlpack_tensor): def zerocopy_from_dlpack(dlpack_tensor):
return dlpack.from_dlpack(dlpack_tensor) return dlpack.from_dlpack(dlpack_tensor)
......
...@@ -28,8 +28,21 @@ def test_dlpack(): ...@@ -28,8 +28,21 @@ def test_dlpack():
x[0] = 1 x[0] = 1
assert np.allclose(y.asnumpy(), ans) assert np.allclose(y.asnumpy(), ans)
def th2nd_incontiguous():
import dgl.backend as F
x = th.LongTensor([[0, 1], [2, 3]])
ans = np.array([0, 2])
y = x[:2, 0]
# Uncomment this line and comment the one below to observe error
#dl = dlpack.to_dlpack(y)
dl = F.zerocopy_to_dlpack(y)
z = nd.from_dlpack(dl)
assert np.allclose(z.asnumpy(), ans)
nd2th() nd2th()
th2nd() th2nd()
th2nd_incontiguous()
def test_index(): def test_index():
ans = np.ones((10,), dtype=np.int64) * 10 ans = np.ones((10,), dtype=np.int64) * 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment