Unverified Commit 5fc1d0c8 authored by Serge Panev, committed by GitHub

[Dist][Test] Add tests for multi-node DistEmbedding (#4256)


Signed-off-by: Serge Panev <spanev@nvidia.com>
Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 701b746b
import dgl
import torch
import os
import numpy as np
import dgl.backend as F
from dgl.distributed import load_partition_book
import time
mode = os.environ.get('DIST_DGL_TEST_MODE', "")
graph_name = os.environ.get('DIST_DGL_TEST_GRAPH_NAME', 'random_test_graph')
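# Each process running this script acts either as a graph server or as a
# trainer client, selected by DIST_DGL_TEST_MODE (see the dispatch at the
# bottom of this file). DIST_DGL_TEST_OBJECT_TYPE picks which group of checks
# a client runs ("DistTensor" or "DistEmbedding"); leaving it unset runs all
# of them. The remaining DIST_DGL_TEST_* variables (shared memory, server id,
# ip config, workspace) are expected to be exported by the launcher in the
# companion pytest file further below.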
@@ -33,14 +31,18 @@ def run_server(graph_name, server_id, server_count, num_clients, shared_mem, kee
    print('start server', server_id)
    g.start()
##########################################
############### DistTensor ###############
##########################################
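# Note: zeros_init is defined earlier in this file, outside the hunks shown in
# this diff. DGL calls init_func(shape, dtype) when it allocates the backing
# storage of a DistTensor/DistEmbedding, so a minimal zero initializer would
# look roughly like the following sketch (illustration only, not the actual
# definition):
#
#     def zeros_init(shape, dtype):
#         return F.zeros(shape, dtype, F.cpu())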

def dist_tensor_test_sanity(data_shape, name=None):
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    dist_ten = dgl.distributed.DistTensor(data_shape,
                                          F.int32,
                                          init_func=zeros_init,
                                          name=name)
    # arbitrary value
    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    if part_id % 2 == 0:
        dist_ten[pos*stride:(pos+1)*stride] = F.ones((stride, 2), dtype=F.int32, ctx=F.cpu()) * (pos+1)
@@ -71,15 +73,65 @@ def dist_tensor_test_persistent(data_shape):
pass

def test_dist_tensor(g):
    first_type = g.ntypes[0]
    data_shape = (g.number_of_nodes(first_type), 2)
    dist_tensor_test_sanity(data_shape)
    dist_tensor_test_sanity(data_shape, name="DistTensorSanity")
    dist_tensor_test_destroy_recreate(data_shape, name="DistTensorRecreate")
    dist_tensor_test_persistent(data_shape)
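# The checks above and the DistEmbedding checks below share one slicing scheme:
# pos = (part_id // 2) * num_client_per_machine + local_rank gives every client
# on partitions 2k and 2k+1 the same disjoint slice [pos*stride, (pos+1)*stride),
# and only the client on the even partition writes it. After
# dgl.distributed.client_barrier(), any client can read the slice back and
# expect the written values, while indices at or beyond
# ((num_part + 1) // 2) * num_client_per_machine * stride are never touched.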
##########################################
############# DistEmbedding ##############
##########################################
def dist_embedding_check_sanity(num_nodes, optimizer, name=None):
    local_rank = dgl.distributed.get_rank() % num_client_per_machine
    emb = dgl.distributed.DistEmbedding(num_nodes, 1, name=name, init_func=zeros_init)
    lr = 0.001
    optim = optimizer(params=[emb], lr=lr)

    stride = 3
    pos = (part_id // 2) * num_client_per_machine + local_rank
    idx = F.arange(pos*stride, (pos+1)*stride)
    if part_id % 2 == 0:
        with F.record_grad():
            value = emb(idx)
            optim.zero_grad()
            loss = F.sum(value + 1, 0)
        loss.backward()
        optim.step()

    dgl.distributed.client_barrier()
    value = emb(idx)
    assert F.allclose(value, F.ones((len(idx), 1), dtype=F.int32, ctx=F.cpu()) * -lr)

    # Rows beyond the highest updated slice must keep their zero initialization.
    not_update_idx = F.arange(((num_part + 1) // 2) * num_client_per_machine * stride, num_nodes)
    value = emb(not_update_idx)
    assert np.all(F.asnumpy(value) == np.zeros((len(not_update_idx), 1)))
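# Why -lr is the expected value: loss = sum(emb(idx) + 1) has gradient 1 for
# every updated entry, and starting from zero a single step of either sparse
# optimizer moves the entry by about -lr (Adagrad divides the gradient by the
# square root of the accumulated squared gradient, which is ~1 here; Adam's
# bias-corrected m_hat / sqrt(v_hat) is also ~1). So every client should read
# back approximately -lr after the barrier, while untouched rows stay at zero.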

def dist_embedding_check_existing(num_nodes):
    dist_emb_name = "UniqueEmb"
    emb = dgl.distributed.DistEmbedding(num_nodes, 1, name=dist_emb_name, init_func=zeros_init)
    # Creating another embedding with the same name but a different shape must fail.
    raised = False
    try:
        emb1 = dgl.distributed.DistEmbedding(num_nodes, 2, name=dist_emb_name, init_func=zeros_init)
    except Exception:
        raised = True
    assert raised, "creating a DistEmbedding with an existing name and a different shape should fail"
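# DistEmbedding (like DistTensor) identifies its backing storage by name
# cluster-wide, so recreating one under an existing name is expected to require
# a matching shape; the mismatching (num_nodes, 2) request above is therefore
# expected to be rejected.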

def test_dist_embedding(g):
    num_nodes = g.number_of_nodes(g.ntypes[0])
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad)
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdagrad, name='SomeEmbedding')
    dist_embedding_check_sanity(num_nodes, dgl.distributed.optim.SparseAdam, name='SomeEmbedding')
    dist_embedding_check_existing(num_nodes)
if mode == "server":
    shared_mem = bool(int(os.environ.get('DIST_DGL_TEST_SHARED_MEM')))
    server_id = int(os.environ.get('DIST_DGL_TEST_SERVER_ID'))
@@ -88,18 +140,21 @@ if mode == "server":
elif mode == "client":
    os.environ['DGL_NUM_SERVER'] = str(num_servers_per_machine)
    dgl.distributed.initialize(ip_config, net_type=net_type)
    global_rank = dgl.distributed.get_rank()
    gpb, graph_name, _, _ = load_partition_book(graph_path + '/{}.json'.format(graph_name), part_id, None)
    g = dgl.distributed.DistGraph(graph_name, gpb=gpb)

    target_func_map = {"DistTensor": test_dist_tensor,
                       "DistEmbedding": test_dist_embedding,
                      }
    target = os.environ.get("DIST_DGL_TEST_OBJECT_TYPE", "")
    if target not in target_func_map:
        for test_func in target_func_map.values():
            test_func(g)
    else:
        target_func_map[target](g)
else:
    print("DIST_DGL_TEST_MODE has to be either server or client")
    exit(1)
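# Example: with DIST_DGL_TEST_OBJECT_TYPE=DistEmbedding exported for the client
# processes, only test_dist_embedding(g) runs; with the variable unset (the
# default ""), every test registered in target_func_map runs.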
@@ -10,6 +1,7 @@ import dgl.backend as F
from dgl.distributed import partition_graph
graph_name = os.environ.get('DIST_DGL_TEST_GRAPH_NAME', 'random_test_graph')
target = os.environ.get('DIST_DGL_TEST_OBJECT_TYPE', '')
shared_workspace = os.environ.get('DIST_DGL_TEST_WORKSPACE')
def create_graph(num_part, dist_graph_path, hetero):
@@ -37,13 +38,12 @@ def create_graph(num_part, dist_graph_path, hetero):
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize("target", ['DistTensor'])
@pytest.mark.parametrize("net_type", ['tensorpipe', 'socket'])
@pytest.mark.parametrize("num_servers", [1, 4])
@pytest.mark.parametrize("num_clients", [1, 4])
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("shared_mem", [False, True])
def test_dist_objects(net_type, num_servers, num_clients, hetero, shared_mem):
    if not shared_mem and num_servers > 1:
        pytest.skip(f"Backup servers are not supported when shared memory is disabled")
    ip_config = os.environ.get('DIST_DGL_TEST_IP_CONFIG', 'ip_config.txt')