# test_to_device.py
import dgl
import backend as F
3
import unittest
4

5
6

@unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
def test_to_device():
    """Smoke-test moving a graph that carries features onto the CUDA device.

    Builds a ``dgl.DGLGraph`` with 5 nodes (node feature ``'h'`` created by
    ``F.ones((5, 2))``) and two edges added via ``add_edges([0, 1], [1, 2])``
    (edge feature ``'m'`` created by ``F.ones((2, 2))``), then calls
    ``g.to(F.cuda())`` and checks that a graph object is returned.

    The test is skipped when the backend's default context string is
    ``'cpu'``.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5, {'h': F.ones((5, 2))})
    g.add_edges([0, 1], [1, 2], {'m': F.ones((2, 2))})
    # Second guard besides the skipIf above: the default context alone does
    # not prove that a CUDA device is actually usable at run time.
    if F.is_cuda_available():
        g = g.to(F.cuda())
        assert g is not None
14
15
16
17


if __name__ == '__main__':
    # unittest.skipIf wraps a plain function so that calling it directly
    # raises unittest.SkipTest when the skip condition holds. Report the skip
    # instead of ending the script with an uncaught traceback.
    try:
        test_to_device()
    except unittest.SkipTest as skip:
        print("Skipped test_to_device: %s" % skip)