# test_tensor.py
import unittest

import backend as F

5
6
import dgl
import dgl.ndarray as nd
7
import numpy as np
8

9
10
11
12
13

@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support inplace update",
)
def test_dlpack():
    """Check zero-copy DLPack conversion between DGL NDArrays and backend tensors.

    Each sub-case hands a tensor across the DLPack boundary. It then either
    mutates it in place on one side and asserts the change is visible on the
    other side (i.e. both objects share storage), or checks that a strided
    view survives the round trip with the right values.
    """
    # Expected contents after writing 1 into row 0 of a 3x4 zero matrix.
    expected_row0_ones = np.array(
        [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
    )

    def nd2th():
        # DGL NDArray -> backend tensor; mutate the backend tensor.
        x = nd.array(np.zeros((3, 4), dtype=np.float32))
        dl = x.to_dlpack()
        y = F.zerocopy_from_dlpack(dl)
        y[0] = 1

        print(x)
        print(y)

        # The in-place write through ``y`` must be visible in ``x``.
        assert np.allclose(x.asnumpy(), expected_row0_ones)

    def th2nd():
        # Backend tensor -> DGL NDArray; mutate the backend tensor.
        x = F.zeros((3, 4))
        dl = F.zerocopy_to_dlpack(x)
        y = nd.from_dlpack(dl)
        x[0] = 1

        print(x)
        print(y)

        # The in-place write through ``x`` must be visible in ``y``.
        assert np.allclose(y.asnumpy(), expected_row0_ones)

    def th2nd_incontiguous():
        x = F.astype(F.tensor([[0, 1], [2, 3]]), F.int64)
        expected_col0 = np.array([0, 2])
        # ``y`` is the first column of ``x``, i.e. a strided (non-contiguous)
        # view. Exporting it with the raw backend ``dlpack.to_dlpack`` instead
        # of ``F.zerocopy_to_dlpack`` is known to raise an error here; this
        # case guards that the backend wrapper handles such views.
        y = x[:2, 0]
        dl = F.zerocopy_to_dlpack(y)
        z = nd.from_dlpack(dl)

        print(x)
        print(z)

        assert np.allclose(z.asnumpy(), expected_col0)

    nd2th()
    th2nd()
    th2nd_incontiguous()