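"""Test binary operators on DistTensor edge data.

A random graph is partitioned into a single part, loaded back as a
DistGraph in standalone mode, and two boolean edge masks stored as
DistTensors are combined with ``|``. The result is checked chunk by
chunk against ``operator.or_``.
"""
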
import operator
import os
import unittest

import backend as F

import dgl
from utils import create_random_graph, generate_ip_config, reset_envs

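# Module-level DistGraph handle, populated by setup_module().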
dist_g = None


def rand_mask(shape, dtype):
    # DistTensor init functions receive (shape, dtype); dtype is unused here
    # because the comparison already produces a boolean tensor.
    return F.randn(shape) > 0


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of the operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="MXNet support is turned off"
)
def setup_module():
    global dist_g

    reset_envs()
    os.environ["DGL_DIST_MODE"] = "standalone"

    dist_g = create_random_graph(10000)
    # Partition the graph.
    num_parts = 1
    graph_name = "dist_graph_test_3"
    dist_g.ndata["features"] = F.unsqueeze(F.arange(0, dist_g.num_nodes()), 1)
    dist_g.edata["features"] = F.unsqueeze(F.arange(0, dist_g.num_edges()), 1)
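    # partition_graph writes the parts plus a <graph_name>.json config file
    # under /tmp/dist_graph, which DistGraph loads below.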
    dgl.distributed.partition_graph(
        dist_g, graph_name, num_parts, "/tmp/dist_graph"
    )

    # Assumes generate_ip_config(file, num_machines, num_servers) from the
    # test utils; write a one-machine, one-server config so the file passed
    # to initialize() exists.
    generate_ip_config("kv_ip_config.txt", 1, 1)
    dgl.distributed.initialize("kv_ip_config.txt")
    dist_g = dgl.distributed.DistGraph(
        graph_name, part_config="/tmp/dist_graph/{}.json".format(graph_name)
    )
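    # Attach two boolean edge masks as DistTensors; rand_mask supplies the
    # random initial values.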
    dist_g.edata["mask1"] = dgl.distributed.DistTensor(
        (dist_g.num_edges(),), F.bool, init_func=rand_mask
    )
    dist_g.edata["mask2"] = dgl.distributed.DistTensor(
        (dist_g.num_edges(),), F.bool, init_func=rand_mask
    )


def check_binary_op(key1, key2, key3, op):
    # Verify key3 == op(key1, key2) elementwise, pulling the DistTensors
    # in chunks of 1000 edges rather than materializing them whole.
    for i in range(0, dist_g.num_edges(), 1000):
        i_end = min(i + 1000, dist_g.num_edges())
        assert F.array_equal(
            dist_g.edata[key3][i:i_end],
            op(dist_g.edata[key1][i:i_end], dist_g.edata[key2][i:i_end]),
        )
    # Fancy indexing on a DistTensor should accept both int32 and int64 IDs;
    # these lookups only need to succeed, not return a checked value.
    _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int32)]
    _ = dist_g.edata[key3][F.tensor([100, 20, 10], F.int64)]


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of the operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="MXNet support is turned off"
)
def test_op():
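    # "|" on two boolean DistTensors yields a new DistTensor; it must agree
    # elementwise with operator.or_ applied to the operands.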
    dist_g.edata["mask3"] = dist_g.edata["mask1"] | dist_g.edata["mask2"]
    check_binary_op("mask1", "mask2", "mask3", operator.or_)


@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support some of the operations in DistGraph",
)
@unittest.skipIf(
    dgl.backend.backend_name == "mxnet", reason="MXNet support is turned off"
)
def teardown_module():
    # Make sure the client exits properly once the tests in this process
    # have finished.
    dgl.distributed.exit_client()


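# Allow running this module as a plain script: drive the pytest-style
# setup/teardown hooks manually.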
if __name__ == "__main__":
    setup_module()
    test_op()
    teardown_module()