"src/array/vscode:/vscode.git/clone" did not exist on "9a00cf194fcf994b2527cd927d691144f5e9c47b"
run_dist_objects.py 6.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
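"""Distributed DGL object tests: DistTensor and DistEmbedding.

An external launcher runs this script once per server and once per client
process; the DIST_DGL_TEST_MODE environment variable selects the role.
"""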
import dgl
import os
import numpy as np
import dgl.backend as F
from dgl.distributed import load_partition_book

mode = os.environ.get('DIST_DGL_TEST_MODE', "")
graph_name = os.environ.get('DIST_DGL_TEST_GRAPH_NAME', 'random_test_graph')
num_part = int(os.environ.get('DIST_DGL_TEST_NUM_PART'))
num_servers_per_machine = int(os.environ.get('DIST_DGL_TEST_NUM_SERVER'))
num_client_per_machine = int(os.environ.get('DIST_DGL_TEST_NUM_CLIENT'))
shared_workspace = os.environ.get('DIST_DGL_TEST_WORKSPACE')
graph_path = os.environ.get('DIST_DGL_TEST_GRAPH_PATH')
part_id = int(os.environ.get('DIST_DGL_TEST_PART_ID'))
net_type = os.environ.get('DIST_DGL_TEST_NET_TYPE')
ip_config = os.environ.get('DIST_DGL_TEST_IP_CONFIG', 'ip_config.txt')

os.environ['DGL_DIST_MODE'] = 'distributed'
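
# A minimal sketch of the environment a client run needs (the values below are
# hypothetical; the real launcher script exports the actual ones):
#
#   DIST_DGL_TEST_MODE=client \
#   DIST_DGL_TEST_NUM_PART=2 \
#   DIST_DGL_TEST_NUM_SERVER=1 \
#   DIST_DGL_TEST_NUM_CLIENT=1 \
#   DIST_DGL_TEST_GRAPH_PATH=/tmp/parts \
#   DIST_DGL_TEST_PART_ID=0 \
#   DIST_DGL_TEST_NET_TYPE=socket \
#   python3 run_dist_objects.py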

def zeros_init(shape, dtype):
    # Shared init_func for DistTensor/DistEmbedding: zero-filled CPU tensor.
    return F.zeros(shape, dtype=dtype, ctx=F.cpu())

def run_server(graph_name, server_id, server_count, num_clients, shared_mem, keep_alive=False):
    # server_count is the number of servers running on each machine.
    g = dgl.distributed.DistGraphServer(server_id, ip_config,
                                        server_count, num_clients,
                                        graph_path + '/{}.json'.format(graph_name),
                                        disable_shared_mem=not shared_mem,
                                        graph_format=['csc', 'coo'],
                                        keep_alive=keep_alive,
                                        net_type=net_type)
    print('start server', server_id)
    # start() enters the serving loop and blocks until the clients shut the
    # server down.
    g.start()
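
# Server processes are expected to be spawned by the launcher with their own
# ids, roughly like this (hypothetical shell sketch; the actual launcher may
# differ):
#
#   DIST_DGL_TEST_MODE=server DIST_DGL_TEST_SERVER_ID=0 python3 run_dist_objects.py &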

##########################################
############### DistTensor ###############
##########################################

def dist_tensor_test_sanity(data_shape, name=None):
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    dist_ten = dgl.distributed.DistTensor(data_shape,
                                          F.int32,
                                          init_func=zeros_init,
                                          name=name)
    # arbitrary number of rows written by each client
    stride = 3
    # Partition 2k writes its slice; after the barrier, partitions 2k and
    # 2k+1 compute the same `pos` and verify the written values.
    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
        dist_ten[pos*stride:(pos+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1)

    dgl.distributed.client_barrier()
    assert F.allclose(dist_ten[pos*stride:(pos+1)*stride],
                      F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1))


def dist_tensor_test_destroy_recreate(data_shape, name):
    # A non-persistent DistTensor is released once all references are gone,
    # so its name can be reused afterwards, even with a different shape.
    dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, name, init_func=zeros_init)
    del dist_ten

    dgl.distributed.client_barrier()

    new_shape = (data_shape[0], 4)
    dist_ten = dgl.distributed.DistTensor(new_shape, F.float32, name, init_func=zeros_init)

def dist_tensor_test_persistent(data_shape):
    dist_ten_name = 'persistent_dist_tensor'
    dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, dist_ten_name, init_func=zeros_init,
                                          persistent=True)
    del dist_ten
    # A persistent DistTensor survives the deletion above, so recreating it
    # under the same name is expected to fail.
    raised = False
    try:
        dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, dist_ten_name)
    except Exception:
        raised = True
    assert raised, 'recreating persistent DistTensor "{}" should fail'.format(dist_ten_name)
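# Note: contrast with dist_tensor_test_destroy_recreate above, where the
# non-persistent tensor is released on deletion and its name can be reused.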


def test_dist_tensor(g):
    first_type = g.ntypes[0]
    data_shape = (g.number_of_nodes(first_type), 2)
    dist_tensor_test_sanity(data_shape)
    dist_tensor_test_sanity(data_shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(data_shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(data_shape)


##########################################
############# DistEmbedding ##############
##########################################

def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
    local_rank = dgl.distributed.get_rank() % num_client_per_machine

    emb = dgl.distributed.DistEmbedding(num_nodes, 1, name=name, init_func=zeros_init)
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)

    # arbitrary number of rows touched by each client
    stride = 3

    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos*stride, (pos+1)*stride)

    # Clients on even partitions run one optimizer step on their slice.
    if part_id % 2 == 0:
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()

    dgl.distributed.client_barrier()
    # The gradient of sum(value + 1) w.r.t. value is all ones, so one step of
    # either sparse optimizer moves the touched embeddings to about -lr.
    value = emb(idx)
    assert F.allclose(value, F.ones((len(idx), 1), dtype=F.float32, ctx=F.cpu()) * -lr)

    # Embeddings beyond the written range must remain at their zero init.
    not_update_idx = F.arange(((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes)
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))
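# Note: the sparse optimizers only update rows touched by emb(idx) in the
# forward pass, which is what the zero-check on not_update_idx relies on.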


def dist_embedding_check_existing(num_nodes):
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(num_nodes, 1, name=dist_emb_name, init_func=zeros_init)
    # Creating another embedding under the same name but with a different
    # shape must fail.
    raised = False
    try:
        emb1 = dgl.distributed.DistEmbedding(num_nodes, 2, name=dist_emb_name, init_func=zeros_init)
    except Exception:
        raised = True
    assert raised, 'recreating DistEmbedding "{}" should fail'.format(dist_emb_name)

def test_dist_embedding(g):
    num_nodes = g.number_of_nodes(g.ntypes[0])
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad, name='SomeEmbedding')
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdam, name='SomeEmbedding')

    dist_embedding_check_existing(num_nodes)

if mode == "server":
    shared_mem = bool(int(os.environ.get('DIST_DGL_TEST_SHARED_MEM')))
    server_id = int(os.environ.get('DIST_DGL_TEST_SERVER_ID'))
    run_server(graph_name, server_id, server_count=num_servers_per_machine,
               num_clients=num_part*num_client_per_machine, shared_mem=shared_mem, keep_alive=False)
elif mode == "client":
    os.environ['DGL_NUM_SERVER'] = str(num_servers_per_machine)
    dgl.distributed.initialize(ip_config, net_type=net_type)

    # The partition book maps global node/edge IDs to their partitions.
    gpb, graph_name, _, _ = load_partition_book(graph_path + '/{}.json'.format(graph_name), part_id, None)
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)

    target_func_map = {"DistTensor": test_dist_tensor,
                       "DistEmbedding": test_dist_embedding,
                       }

    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
    if target not in target_func_map:
        for test_func in target_func_map.values():
            test_func(g)
    else:
        target_func_map[target](g)

else:
    print("DIST_DGL_TEST_MODE has to be either server or client")
    exit(1)