import os

import numpy as np

import dgl
import dgl.backend as F
from dgl.distributed import load_partition_book

mode = os.environ.get("DIST_DGL_TEST_MODE", "")
graph_name = os.environ.get("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph")
num_part = int(os.environ.get("DIST_DGL_TEST_NUM_PART"))
num_servers_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_SERVER"))
num_client_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENT"))
shared_workspace = os.environ.get("DIST_DGL_TEST_WORKSPACE")
graph_path = os.environ.get("DIST_DGL_TEST_GRAPH_PATH")
part_id = int(os.environ.get("DIST_DGL_TEST_PART_ID"))
net_type = os.environ.get("DIST_DGL_TEST_NET_TYPE")
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")

os.environ["DGL_DIST_MODE"] = "distributed"


def zeros_init(shape, dtype):
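    """Zero-filled CPU initializer used as init_func throughout the tests."""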
    return F.zeros(shape, dtype=dtype, ctx=F.cpu())


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
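    """Start a DistGraphServer on this machine's partition and serve until
    all clients disconnect."""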
    # server_count = num_servers_per_machine
    g = dgl.distributed.DistGraphServer(
        server_id,
        ip_config,
        server_count,
        num_clients,
        graph_path + "/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
        net_type=net_type,
    )
    print("start server", server_id)
    g.start()


##########################################
############### DistTensor ###############
##########################################
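# The tests below cover DistTensor creation and named lookup, cross-client
# write/read consistency, destroy-and-recreate of named tensors, and
# persistence semantics.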


def dist_tensor_test_sanity(data_shape, name=None):
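    """Write/read sanity check: clients on even partitions each fill a
    distinct slice of the tensor; after a barrier, every client (including
    those on odd partitions, which map to the same slices) verifies it."""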
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.int32, init_func=zeros_init, name=name
    )
    # arbitrary value
    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
        dist_ten[pos * stride : (pos + 1) * stride] = F.ones(
            (stride, 2), dtype=F.int32, ctx=F.cpu()
        ) * (pos + 1)

    dgl.distributed.client_barrier()
    assert F.allclose(
        dist_ten[pos * stride : (pos + 1) * stride],
        F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1),
    )


def dist_tensor_test_destroy_recreate(data_shape, name):
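    """Deleting a non-persistent named DistTensor releases its name, so a
    tensor with the same name but a different shape can be created."""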
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.float32, name, init_func=zeros_init
    )
    del dist_ten

    dgl.distributed.client_barrier()

    new_shape = (data_shape[0], 4)
    dist_ten = dgl.distributed.DistTensor(
        new_shape, F.float32, name, init_func=zeros_init
    )


def dist_tensor_test_persistent(data_shape):
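    """Check persistent semantics: a persistent DistTensor keeps its name
    reserved on the servers even after the local handle is deleted."""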
    dist_ten_name = "persistent_dist_tensor"
    dist_ten = dgl.distributed.DistTensor(
        data_shape,
        F.float32,
        dist_ten_name,
        init_func=zeros_init,
        persistent=True,
    )
    del dist_ten
    # Recreating the tensor under the reserved name must fail. The original
    # raise-inside-try was swallowed by its own bare except, so the check
    # could never fire; use an explicit flag instead.
    recreated = False
    try:
        dist_ten = dgl.distributed.DistTensor(
            data_shape, F.float32, dist_ten_name
        )
        recreated = True
    except Exception:
        pass
    assert not recreated, "recreating a persistent DistTensor must fail"


def test_dist_tensor(g):
    first_type = g.ntypes[0]
    data_shape = (g.number_of_nodes(first_type), 2)
    dist_tensor_test_sanity(data_shape)
    dist_tensor_test_sanity(data_shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(data_shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(data_shape)


##########################################
############# DistEmbedding ##############
##########################################
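# The tests below push one optimizer step through the sparse optimizers and
# check both updated and untouched embedding rows, plus duplicate-name
# handling.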


def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
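    """Run one sparse-optimizer step from clients on even partitions, then
    verify both the updated rows and the untouched (still zero) rows."""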
    local_rank = dgl.distributed.get_rank() % num_client_per_machine

    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=name, init_func=zeros_init
    )
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)

    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos * stride, (pos + 1) * stride)

    if part_id % 2 == 0:
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()

    dgl.distributed.client_barrier()
    value = emb(idx)
    # The update is -lr (gradient of sum(value + 1) is 1); the comparison was
    # missing its assert and used an int dtype for a fractional expectation.
    assert F.allclose(
        value, F.ones((len(idx), 1), dtype=F.float32, ctx=F.cpu()) * -lr
    )

    # Only clients on even partitions wrote above, covering the first
    # ceil(num_part / 2) machines' worth of slices; everything beyond must
    # still be zero. Integer division keeps the arange bound an int.
    not_update_idx = F.arange(
        ((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes
    )
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))


def dist_embedding_check_existing(num_nodes):
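    """A DistEmbedding name can be used only once: creating a second
    embedding under the same name must fail."""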
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=dist_emb_name, init_func=zeros_init
    )
    # Creating another DistEmbedding under the already-taken name (here with
    # a different shape) must fail. As above, the original raise-inside-try
    # was swallowed by its own bare except; use an explicit flag instead.
    created = False
    try:
        emb1 = dgl.distributed.DistEmbedding(
            num_nodes, 2, name=dist_emb_name, init_func=zeros_init
        )
        created = True
    except Exception:
        pass
    assert not created, "reusing an existing DistEmbedding name must fail"


def test_dist_embedding(g):
    num_nodes = g.number_of_nodes(g.ntypes[0])
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)
    dist_embedding_check_sanity(
        num_nodes, dgl.distributed.optim.SparseAdagrad, name="SomeEmbedding"
    )
    dist_embedding_check_sanity(
        num_nodes, dgl.distributed.optim.SparseAdam, name="SomeEmbedding"
    )

    dist_embedding_check_existing(num_nodes)


if mode == "server":
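    # Server role: serve this machine's partition to all
    # num_part * num_client_per_machine client processes.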
    shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM")))
    server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
    run_server(
        graph_name,
        server_id,
        server_count=num_servers_per_machine,
        num_clients=num_part * num_client_per_machine,
        shared_mem=shared_mem,
        keep_alive=False,
    )
elif mode == "client":
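    # Client role: connect to the servers, rebuild the partition book to get
    # a DistGraph handle, then dispatch to the requested test group.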
    os.environ["DGL_NUM_SERVER"] = str(num_servers_per_machine)
    dgl.distributed.initialize(ip_config, net_type=net_type)

    gpb, graph_name, _, _ = load_partition_book(
        graph_path + "/{}.json".format(graph_name), part_id, None
    )
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)

    target_func_map = {
        "DistTensor": test_dist_tensor,
        "DistEmbedding": test_dist_embedding,
    }

    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
    if target in target_func_map:
        target_func_map[target](g)
    else:
        # Unknown or unset object type: run every test group.
        for test_func in target_func_map.values():
            test_func(g)

else:
    print(
        "DIST_DGL_TEST_MODE has to be either 'server' or 'client', "
        "but got '{}'".format(mode)
    )
    exit(1)
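
# A hypothetical single-machine server launch (values are illustrative only,
# not prescribed by this script):
#
#   DIST_DGL_TEST_MODE=server DIST_DGL_TEST_NUM_PART=2 \
#   DIST_DGL_TEST_NUM_SERVER=1 DIST_DGL_TEST_NUM_CLIENT=1 \
#   DIST_DGL_TEST_SHARED_MEM=1 DIST_DGL_TEST_SERVER_ID=0 \
#   DIST_DGL_TEST_PART_ID=0 DIST_DGL_TEST_NET_TYPE=socket \
#   DIST_DGL_TEST_GRAPH_PATH=/path/to/partitions \
#   python3 run_dist_objects.py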