# run_dist_objects.py
import os

import numpy as np

import dgl
import dgl.backend as F
from dgl.distributed import load_partition_book

mode = os.environ.get("DIST_DGL_TEST_MODE", "")
graph_name = os.environ.get("DIST_DGL_TEST_GRAPH_NAME", "random_test_graph")
num_part = int(os.environ.get("DIST_DGL_TEST_NUM_PART"))
num_servers_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_SERVER"))
num_client_per_machine = int(os.environ.get("DIST_DGL_TEST_NUM_CLIENT"))
shared_workspace = os.environ.get("DIST_DGL_TEST_WORKSPACE")
graph_path = os.environ.get("DIST_DGL_TEST_GRAPH_PATH")
part_id = int(os.environ.get("DIST_DGL_TEST_PART_ID"))
net_type = os.environ.get("DIST_DGL_TEST_NET_TYPE")
ip_config = os.environ.get("DIST_DGL_TEST_IP_CONFIG", "ip_config.txt")

os.environ["DGL_DIST_MODE"] = "distributed"
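# All configuration comes from the DIST_DGL_TEST_* environment variables above,
# which the launching script is expected to export before starting this
# process. Server mode additionally reads DIST_DGL_TEST_SHARED_MEM and
# DIST_DGL_TEST_SERVER_ID below. Illustrative (hypothetical) invocation:
#   DIST_DGL_TEST_MODE=client DIST_DGL_TEST_GRAPH_NAME=random_test_graph \
#   DIST_DGL_TEST_NUM_PART=2 DIST_DGL_TEST_NUM_SERVER=1 DIST_DGL_TEST_NUM_CLIENT=1 \
#   DIST_DGL_TEST_GRAPH_PATH=/path/to/partitions DIST_DGL_TEST_PART_ID=0 \
#   DIST_DGL_TEST_NET_TYPE=socket python3 run_dist_objects.py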


def zeros_init(shape, dtype):
    return F.zeros(shape, dtype=dtype, ctx=F.cpu())


def run_server(
    graph_name,
    server_id,
    server_count,
    num_clients,
    shared_mem,
    keep_alive=False,
):
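    """Load this machine's graph partition and start a DistGraphServer."""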
    # server_count = num_servers_per_machine
    g = dgl.distributed.DistGraphServer(
        server_id,
        ip_config,
        server_count,
        num_clients,
        graph_path + "/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
        graph_format=["csc", "coo"],
        keep_alive=keep_alive,
        net_type=net_type,
    )
    print("start server", server_id)
    g.start()


##########################################
############### DistTensor ###############
##########################################


def dist_tensor_test_sanity(data_shape, name=None):
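    """Write a slice of a DistTensor from clients on even partitions and
    verify that every client reads the same values back after a barrier."""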
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.int32, init_func=zeros_init, name=name
    )
    # arbitrary value
    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
        dist_ten[pos * stride: (pos + 1) * stride] = F.ones(
            (stride, 2), dtype=F.int32, ctx=F.cpu()
        ) * (pos + 1)

    dgl.distributed.client_barrier()
    assert F.allclose(
        dist_ten[pos * stride: (pos + 1) * stride],
        F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos + 1),
    )


def dist_tensor_test_destroy_recreate(data_shape, name):
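    """Delete a non-persistent DistTensor and recreate one with the same name
    but a different shape."""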
    dist_ten = dgl.distributed.DistTensor(
        data_shape, F.float32, name, init_func=zeros_init
    )
    del dist_ten

    dgl.distributed.client_barrier()

    new_shape = (data_shape[0], 4)
    dist_ten = dgl.distributed.DistTensor(
        new_shape, F.float32, name, init_func=zeros_init
    )


def dist_tensor_test_persistent(data_shape):
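    """Create a persistent DistTensor and exercise recreating a tensor under
    the same name after the Python handle is deleted."""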
    dist_ten_name = "persistent_dist_tensor"
    dist_ten = dgl.distributed.DistTensor(
        data_shape,
        F.float32,
        dist_ten_name,
        init_func=zeros_init,
        persistent=True,
    )
    del dist_ten
    try:
        dist_ten = dgl.distributed.DistTensor(
            data_shape, F.float32, dist_ten_name
        )
        raise Exception("")
    except BaseException:
        pass


def test_dist_tensor(g):
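    """Exercise DistTensor: sanity writes/reads, recreation after deletion,
    and persistence."""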
    first_type = g.ntypes[0]
    data_shape = (g.number_of_nodes(first_type), 2)
    dist_tensor_test_sanity(data_shape)
    dist_tensor_test_sanity(data_shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(data_shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(data_shape)


##########################################
############# DistEmbedding ##############
##########################################


def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
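    """Run one sparse-optimizer step on a DistEmbedding slice from clients on
    even partitions; updated rows should read back as -lr and untouched rows
    should stay zero."""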
    local_rank = dgl.distributed.get_rank() % num_client_per_machine

    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=name, init_func=zeros_init
    )
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)

    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos * stride, (pos + 1) * stride)

    if part_id % 2 == 0:
        # Clients on even partitions run one optimizer step. The loss
        # sum(value + 1) has gradient 1 w.r.t. every touched row, so the
        # zero-initialized rows move to roughly -lr after a single step.
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()

    dgl.distributed.client_barrier()
    value = emb(idx)
    assert F.allclose(
        value, F.ones((len(idx), 1), dtype=F.int32, ctx=F.cpu()) * -lr
    )

    # Rows beyond the ranges written by the even partitions must keep their
    # zero-initialized values.
    not_update_idx = F.arange(
        ((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes
    )
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))


def dist_embedding_check_existing(num_nodes):
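    """Creating a DistEmbedding under an existing name with a different shape
    must fail."""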
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(
        num_nodes, 1, name=dist_emb_name, init_func=zeros_init
    )
    # Creating another embedding under the same name but with a different
    # shape must not succeed.
    raised = False
    try:
        dgl.distributed.DistEmbedding(
            num_nodes, 2, name=dist_emb_name, init_func=zeros_init
        )
    except BaseException:
        raised = True
    assert raised, "creating a DistEmbedding with an existing name must fail"


def test_dist_embedding(g):
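    """Exercise DistEmbedding with SparseAdagrad and SparseAdam, plus a
    name-collision check."""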
    num_nodes = g.number_of_nodes(g.ntypes[0])
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)
    dist_embedding_check_sanity(
        num_nodes, dgl.distributed.optim.SparseAdagrad, name="SomeEmbedding"
    )
    dist_embedding_check_sanity(
        num_nodes, dgl.distributed.optim.SparseAdam, name="SomeEmbedding"
    )

    dist_embedding_check_existing(num_nodes)


##########################################
############# DistOptimizer ##############
##########################################


def dist_optimizer_check_store(g):
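    """Check SparseAdam save()/load(): optimizer state written on rank 0 must
    be restored, along with the optimizer hyperparameters."""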
    num_nodes = g.number_of_nodes(g.ntypes[0])
    rank = g.rank()
    try:
        emb = dgl.distributed.DistEmbedding(
            num_nodes, 1, name="optimizer_test", init_func=zeros_init
        )
        emb2 = dgl.distributed.DistEmbedding(
            num_nodes, 5, name="optimizer_test2", init_func=zeros_init
        )
        emb_optimizer = dgl.distributed.optim.SparseAdam([emb, emb2], lr=0.1)
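        # Rank 0 overwrites the optimizer state with known random values so
        # the save()/load() round trip can be verified below.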
        if rank == 0:
            name_to_state = {}
            for _, emb_states in emb_optimizer._state.items():
                for state in emb_states:
                    name_to_state[state.name] = F.uniform(
                        state.shape, F.float32, F.cpu(), 0, 1
                    )
                    state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ] = name_to_state[state.name]
        emb_optimizer.save("emb.pt")
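        # Re-create the optimizer with different settings, then load the saved
        # state; the restored state and hyperparameters should match.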
        new_emb_optimizer = dgl.distributed.optim.SparseAdam(
            [emb, emb2], lr=0.1, eps=2e-08, betas=(0.1, 0.222)
        )
        new_emb_optimizer.load("emb.pt")
        if rank == 0:
            for _, emb_states in new_emb_optimizer._state.items():
                for new_state in emb_states:
                    state = name_to_state[new_state.name]
                    new_state = new_state[
                        F.arange(0, num_nodes, F.int64, F.cpu())
                    ]
                    assert F.allclose(state, new_state, 0.0, 0.0)
            assert new_emb_optimizer._lr == emb_optimizer._lr
            assert new_emb_optimizer._eps == emb_optimizer._eps
            assert new_emb_optimizer._beta1 == emb_optimizer._beta1
            assert new_emb_optimizer._beta2 == emb_optimizer._beta2
        g.barrier()
    finally:
        file = f'emb.pt_{rank}'
        if os.path.exists(file):
            os.remove(file)

def test_dist_optimizer(g):
    dist_optimizer_check_store(g)


if mode == "server":
    shared_mem = bool(int(os.environ.get("DIST_DGL_TEST_SHARED_MEM")))
    server_id = int(os.environ.get("DIST_DGL_TEST_SERVER_ID"))
    run_server(
        graph_name,
        server_id,
        server_count=num_servers_per_machine,
        num_clients=num_part * num_client_per_machine,
        shared_mem=shared_mem,
        keep_alive=False,
    )
elif mode == "client":
    os.environ["DGL_NUM_SERVER"] = str(num_servers_per_machine)
    dgl.distributed.initialize(ip_config, net_type=net_type)

    gpb, graph_name, _, _ = load_partition_book(
        graph_path + "/{}.json".format(graph_name), part_id, None
    )
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)

    target_func_map = {
        "DistTensor": test_dist_tensor,
        "DistEmbedding": test_dist_embedding,
        "DistOptimizer": test_dist_optimizer,
    }

    # Run every object test unless DIST_DGL_TEST_OBJECT_TYPE selects one.
    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
    if target not in target_func_map:
        for test_func in target_func_map.values():
            test_func(g)
    else:
        target_func_map[target](g)

else:
    # Unknown DIST_DGL_TEST_MODE; fail so the launcher notices.
    exit(1)