# test_generators.py
1
2
import unittest

3
4
import backend as F

5
import dgl
6
import numpy as np
7
8
9
10
11


@unittest.skipIf(
    F._default_context_str == "gpu", reason="GPU random choice not implemented"
)
def test_rand_graph():
    """Check the size and seed reproducibility of ``dgl.rand_graph``.

    Two properties are verified:

    1. A graph requested with 10000 nodes and 100000 edges reports exactly
       those counts.
    2. Re-seeding DGL's RNG with the same value before each call yields
       graphs with identical source and destination edge arrays.

    The test is skipped when the backend's default context is ``"gpu"``.
    """
    # The generated graph must have exactly the requested node/edge counts.
    g = dgl.rand_graph(10000, 100000)
    assert g.num_nodes() == 10000
    assert g.num_edges() == 100000

    # test random seed
    dgl.random.seed(42)
    g1 = dgl.rand_graph(100, 30)
    dgl.random.seed(42)
    g2 = dgl.rand_graph(100, 30)
    u1, v1 = g1.edges()
    u2, v2 = g2.edges()
    # Same seed => same endpoints, compared separately for sources and
    # destinations.
    assert F.array_equal(u1, u2)
    assert F.array_equal(v1, v2)

26
27

# Allow running this test file directly as a script, without a test runner.
if __name__ == "__main__":
    test_rand_graph()