# test_to_device.py
import dgl
import backend as F

def test_to_device():
    """Check that ``DGLGraph.to`` moves a featured graph to a CUDA device.

    Builds a 5-node, 2-edge graph carrying node feature ``'h'`` and edge
    feature ``'m'``. When the backend reports CUDA is available, the graph is
    moved with ``g.to(F.cuda())``. The call must return a graph (not ``None``)
    that still has its nodes, edges and feature fields. Without CUDA the
    transfer is skipped and only graph construction is exercised.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5, {'h' : F.ones((5, 2))})
    g.add_edges([0, 1], [1, 2], {'m' : F.ones((2, 2))})
    if F.is_cuda_available():
        # ``to`` is expected to return the graph rather than ``None``, so the
        # result is reassigned and checked.
        g = g.to(F.cuda())
        assert g is not None
        # The moved graph must keep its topology and its feature fields.
        assert g.number_of_nodes() == 5
        assert g.number_of_edges() == 2
        assert 'h' in g.ndata
        assert 'm' in g.edata


if __name__ == '__main__':
    # Allow running this check directly, outside the test runner.
    test_to_device()