# test_index.py
import unittest

import backend as F
import numpy as np

6
7
8
9
import dgl
import dgl.ndarray as nd
from dgl.utils import toindex

10
11
12
13
14

@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support inplace update",
)
def test_dlpack():
    """Check DLPack exchange between DGL NDArrays and backend tensors.

    Three scenarios run in order:

    * ``nd2th``: NDArray -> backend tensor; write through the tensor and
      read the result back through the NDArray.
    * ``th2nd``: backend tensor -> NDArray; write through the tensor and
      read the result back through the NDArray.
    * ``th2nd_incontiguous``: a column slice of a backend tensor ->
      NDArray; only the converted values are checked.

    The first two scenarios pass only if both handles observe the same
    data after the in-place write. The whole test is skipped on the
    TensorFlow backend (see the ``skipIf`` reason).
    """
    # Expected contents after setting row 0 of a zeroed (3, 4) buffer to 1.
    first_row_ones = np.zeros((3, 4))
    first_row_ones[0] = 1.0

    def nd2th():
        """Export an NDArray via DLPack and mutate it from the backend side."""
        x = nd.array(np.zeros((3, 4), dtype=np.float32))
        dl = x.to_dlpack()
        y = F.zerocopy_from_dlpack(dl)
        # In-place write through the backend tensor; the assertion below
        # reads it back through the original NDArray.
        y[0] = 1
        print(x)
        print(y)
        assert np.allclose(x.asnumpy(), first_row_ones)

    def th2nd():
        """Export a backend tensor via DLPack and mutate the source tensor."""
        x = F.zeros((3, 4))
        dl = F.zerocopy_to_dlpack(x)
        y = nd.from_dlpack(dl)
        # Write to the source *after* the conversion; the NDArray must
        # reflect the update for the assertion to hold.
        x[0] = 1
        print(x)
        print(y)
        assert np.allclose(y.asnumpy(), first_row_ones)

    def th2nd_incontiguous():
        """Convert a sliced (column) view of a backend tensor to an NDArray."""
        x = F.astype(F.tensor([[0, 1], [2, 3]]), F.int64)
        expected = np.array([0, 2])
        # First column of x; presumably a strided (non-contiguous) view on a
        # row-major backend -- that is what this scenario is named after.
        y = x[:2, 0]
        # NOTE: per the original author's remark, calling the framework's raw
        # `dlpack.to_dlpack(y)` here instead of the backend helper fails; the
        # backend wrapper is used deliberately.
        dl = F.zerocopy_to_dlpack(y)
        z = nd.from_dlpack(dl)
        print(x)
        print(z)
        assert np.allclose(z.asnumpy(), expected)

    nd2th()
    th2nd()
    th2nd_incontiguous()