Unverified Commit 4dc5728a authored by Serge Panev's avatar Serge Panev Committed by GitHub
Browse files

[Dist][Test] Improves DistTensor test for num_part id > 2 (#4265)


Signed-off-by: default avatarSerge Panev <spanev@nvidia.com>
Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 740cd706
...@@ -33,7 +33,6 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee ...@@ -33,7 +33,6 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee
print('start server', server_id) print('start server', server_id)
g.start() g.start()
def dist_tensor_test_sanity(data_shape, rank, name=None): def dist_tensor_test_sanity(data_shape, rank, name=None):
dist_ten = dgl.distributed.DistTensor(data_shape, dist_ten = dgl.distributed.DistTensor(data_shape,
F.int32, F.int32,
...@@ -41,14 +40,15 @@ def dist_tensor_test_sanity(data_shape, rank, name=None): ...@@ -41,14 +40,15 @@ def dist_tensor_test_sanity(data_shape, rank, name=None):
name=name) name=name)
# arbitrary value # arbitrary value
stride = 3 stride = 3
if part_id == 0: local_rank = rank % num_client_per_machine
dist_ten[rank*stride:(rank+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (rank+1) pos = (part_id // 2) * num_client_per_machine + local_rank
dgl.distributed.client_barrier() if part_id % 2 == 0:
else: dist_ten[pos*stride:(pos+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1)
dgl.distributed.client_barrier() dgl.distributed.client_barrier()
original_rank = rank % num_client_per_machine assert F.allclose(dist_ten[pos*stride:(pos+1)*stride],
assert F.allclose(dist_ten[original_rank*stride:(original_rank+1)*stride], F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1))
F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (original_rank+1))
def dist_tensor_test_destroy_recreate(data_shape, name): def dist_tensor_test_destroy_recreate(data_shape, name):
dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, name, init_func=zeros_init) dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, name, init_func=zeros_init)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment