# test_dist_objects.py 5.76 KB
# Newer Older
1
import multiprocessing as mp
2
import os
3
import shutil
4
import subprocess
5
import unittest
6

7
8
9
import dgl
import dgl.backend as F

10
import numpy as np
11
12
13
14
import pytest
import utils
from dgl.distributed import partition_graph

15
16
# Module-level settings; each one can be overridden through the environment
# variable named in the call and otherwise falls back to the default shown.

# Graph name handed to partition_graph when the test graph is generated.
graph_name = os.getenv("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph")
# Forwarded unchanged to client processes as DIST_DGL_TEST_OBJECT_TYPE.
target = os.getenv("DIST_DGL_TEST_OBJECT_TYPE", "")
# Forwarded unchanged to client processes as
# DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST.
blacklist = os.getenv("DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST", "")
# Root directory under which the partitioned graphs are written and later
# removed by teardown().
shared_workspace = os.getenv(
    "DIST_DGL_TEST_WORKSPACE", "/shared_workspace/dgl_dist_tensor_test/"
)
21

22
23
24
25

def _save_orig_ids(dist_graph_path, kind, type_name, ids):
    """Save one original-ID mapping into the partition directory.

    ``ids`` is converted with ``.numpy()`` and written to
    ``orig_<kind>_array_<type_name>.npy`` inside ``dist_graph_path``;
    ``kind`` is ``"nid"`` for node mappings and ``"eid"`` for edge mappings.
    """
    mapping_file = os.path.join(
        dist_graph_path, f"orig_{kind}_array_{type_name}.npy"
    )
    np.save(mapping_file, ids.numpy())


def create_graph(num_part, dist_graph_path, hetero):
    """Build a random test graph, partition it and save the ID mappings.

    Parameters
    ----------
    num_part : int
        Number of partitions requested from ``partition_graph``.
    dist_graph_path : str
        Output directory passed to ``partition_graph``; the
        ``orig_nid_array_<type>.npy`` / ``orig_eid_array_<type>.npy`` mapping
        files are written into the same directory.
    hetero : bool
        When false, a single-type random graph with 10000 nodes and 42000
        edges is generated. When true, a heterogeneous graph with node types
        ``n1``/``n2``/``n3`` and relations ``r1``/``r2``/``r3`` is generated
        from sparse random matrices.

    Every edge type stores its endpoints in the ``edge_u`` / ``edge_v`` edge
    features. The single-type graph additionally stores ``feat`` on nodes and
    edges plus ``in_degrees`` / ``out_degrees`` on nodes; the heterogeneous
    graph stores ``feat`` only on ``n1`` nodes and ``r1`` edges.
    """
    if not hetero:
        g = dgl.rand_graph(10000, 42000)
        g.ndata["feat"] = F.unsqueeze(F.arange(0, g.num_nodes()), 1)
        g.edata["feat"] = F.unsqueeze(F.arange(0, g.num_edges()), 1)
        g.ndata["in_degrees"] = g.in_degrees()
        g.ndata["out_degrees"] = g.out_degrees()

        # Remember the only node/edge type so the mappings can be keyed by it
        # after partitioning.
        homo_etype = g.etypes[0]
        homo_ntype = g.ntypes[0]
        edge_u, edge_v = g.find_edges(F.arange(0, g.num_edges(homo_etype)))
        g.edges[homo_etype].data["edge_u"] = edge_u
        g.edges[homo_etype].data["edge_v"] = edge_v
    else:
        from scipy import sparse as spsp

        num_nodes = {"n1": 10000, "n2": 10010, "n3": 10020}
        etypes = [("n1", "r1", "n2"), ("n1", "r2", "n3"), ("n2", "r3", "n3")]
        edges = {}
        for etype in etypes:
            src_ntype, _, dst_ntype = etype
            # The fixed random_state makes the generated structure
            # reproducible across runs.
            arr = spsp.random(
                num_nodes[src_ntype],
                num_nodes[dst_ntype],
                density=0.001,
                format="coo",
                random_state=100,
            )
            edges[etype] = (arr.row, arr.col)
        g = dgl.heterograph(edges, num_nodes)

        g.nodes["n1"].data["feat"] = F.unsqueeze(
            F.arange(0, g.num_nodes("n1")), 1
        )
        g.edges["r1"].data["feat"] = F.unsqueeze(
            F.arange(0, g.num_edges("r1")), 1
        )

        for _, etype, _ in etypes:
            edge_u, edge_v = g.find_edges(
                F.arange(0, g.num_edges(etype)), etype=etype
            )
            g.edges[etype].data["edge_u"] = edge_u
            g.edges[etype].data["edge_v"] = edge_v

    orig_nid, orig_eid = partition_graph(
        g, graph_name, num_part, dist_graph_path, return_mapping=True
    )
    if not hetero:
        # In the single-type case each mapping is used as one tensor; wrap it
        # in a dict keyed by its type so both cases share the save loops.
        orig_nid = {homo_ntype: orig_nid}
        orig_eid = {homo_etype: orig_eid}

    for n_type, ids in orig_nid.items():
        _save_orig_ids(dist_graph_path, "nid", n_type, ids)
    for e_type, ids in orig_eid.items():
        _save_orig_ids(dist_graph_path, "eid", e_type, ids)

97

98
99
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("net_type", ["tensorpipe", "socket"])
@pytest.mark.parametrize("num_servers", [1, 4])
@pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("shared_mem", [False, True])
def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem):
    """Launch server and client processes on every host and check they exit 0.

    One partition is used per address returned by ``utils.get_ips``. For each
    host, ``num_servers`` server processes and ``num_clients`` client
    processes are started through ``utils.execute_remote``; each runs
    ``run_dist_objects.py`` configured entirely through ``DIST_DGL_TEST_*``
    environment variables prefixed to the command line. The partitioned
    graph is generated on first use and reused while its directory exists.
    """
    if num_servers > 1 and not shared_mem:
        pytest.skip(
            "Backup servers are not supported when shared memory is disabled"
        )

    def render_envs(settings):
        # "KEY=value " for every entry, in insertion order; the trailing
        # space after each pair is part of the expected command line.
        return "".join(f"{key}={value} " for key, value in settings.items())

    ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")
    ips = utils.get_ips(ip_config)
    num_part = len(ips)

    bin_dir = os.environ.get("DIST_DGL_TEST_PY_BIN_DIR", ".")
    test_bin = os.path.join(bin_dir, "run_dist_objects.py")

    graph_dir_name = "hetero_dist_graph" if hetero else "dist_graph"
    dist_graph_path = os.path.join(shared_workspace, graph_dir_name)
    if not os.path.isdir(dist_graph_path):
        create_graph(num_part, dist_graph_path, hetero)

    # Settings shared by every launched process.
    base_envs = render_envs(
        {
            "DIST_DGL_TEST_WORKSPACE": shared_workspace,
            "DIST_DGL_TEST_NUM_PART": num_part,
            "DIST_DGL_TEST_NUM_SERVER": num_servers,
            "DIST_DGL_TEST_NUM_CLIENT": num_clients,
            "DIST_DGL_TEST_NET_TYPE": net_type,
            "DIST_DGL_TEST_GRAPH_PATH": dist_graph_path,
            "DIST_DGL_TEST_IP_CONFIG": ip_config,
        }
    )

    procs = []

    # Servers are started first; ids run consecutively across hosts.
    for part_id, ip in enumerate(ips):
        for local_idx in range(num_servers):
            server_envs = base_envs + render_envs(
                {
                    "DIST_DGL_TEST_SERVER_ID": part_id * num_servers
                    + local_idx,
                    "DIST_DGL_TEST_PART_ID": part_id,
                    "DIST_DGL_TEST_SHARED_MEM": int(shared_mem),
                    "DIST_DGL_TEST_MODE": "server",
                }
            )
            server_proc = utils.execute_remote(
                f"{server_envs} python3 {test_bin}", ip
            )
            procs.append(server_proc)

    # Then the clients, num_clients per host.
    for part_id, ip in enumerate(ips):
        client_envs = base_envs + render_envs(
            {
                "DIST_DGL_TEST_PART_ID": part_id,
                "DIST_DGL_TEST_OBJECT_TYPE": target,
                "DIST_DGL_TEST_OBJECT_TYPE_BLACKLIST": blacklist,
                "DIST_DGL_TEST_MODE": "client",
            }
        )
        for _ in range(num_clients):
            client_proc = utils.execute_remote(
                f"{client_envs} python3 {test_bin}", ip
            )
            procs.append(client_proc)

    # Wait in launch order; the first non-zero exit code fails the test.
    for proc in procs:
        proc.join()
        assert proc.exitcode == 0
165

166

167
def teardown():
    """Remove the partitioned-graph directories from the shared workspace.

    Deletes ``dist_graph`` and ``hetero_dist_graph`` under
    ``shared_workspace`` when they exist, printing each path before removal.
    """
    # NOTE(review): nothing in this file calls teardown(); presumably the test
    # runner invokes it as a module-level hook -- confirm that the pytest
    # version in use still does.
    for sub_dir in ("dist_graph", "hetero_dist_graph"):
        path = os.path.join(shared_workspace, sub_dir)
        if not os.path.exists(path):
            continue
        print(f"Removing {path}...")
        shutil.rmtree(path)