"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "99540747b5d1ac977be3e55eea03ae63b2d63c80"
Unverified commit 4dc5728a, authored by Serge Panev, committed by GitHub
Browse files

[Dist][Test] Improves DistTensor test for num_part id > 2 (#4265)


Signed-off-by: Serge Panev <spanev@nvidia.com>
Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 740cd706
......@@ -33,7 +33,6 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee
print('start server', server_id)
g.start()
def dist_tensor_test_sanity(data_shape, rank, name=None):
dist_ten = dgl.distributed.DistTensor(data_shape,
F.int32,
......@@ -41,14 +40,15 @@ def dist_tensor_test_sanity(data_shape, rank, name=None):
name=name)
# arbitrary value
stride = 3
if part_id == 0:
dist_ten[rank*stride:(rank+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (rank+1)
dgl.distributed.client_barrier()
else:
dgl.distributed.client_barrier()
original_rank = rank % num_client_per_machine
assert F.allclose(dist_ten[original_rank*stride:(original_rank+1)*stride],
F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (original_rank+1))
local_rank = rank % num_client_per_machine
pos = (part_id // 2) * num_client_per_machine + local_rank
if part_id % 2 == 0:
dist_ten[pos*stride:(pos+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1)
dgl.distributed.client_barrier()
assert F.allclose(dist_ten[pos*stride:(pos+1)*stride],
F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1))
def dist_tensor_test_destroy_recreate(data_shape, name):
dist_ten = dgl.distributed.DistTensor(data_shape, F.float32, name, init_func=zeros_init)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment