Unverified commit 485c04cf, authored by Da Zheng, committed by GitHub
Browse files

[Distributed] Fix a few bugs in distributed API (#3094)



* fix.

* fix.

* fix.

* fix.

* Fix test
Co-authored-by: Ubuntu <ubuntu@ip-172-31-71-112.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-2-66.ec2.internal>
parent 595d4e34
......@@ -995,7 +995,7 @@ class DistGraph:
def _get_ndata_names(self, ntype=None):
''' Get the names of all node data.
'''
names = self._client.data_name_list()
names = self._client.gdata_name_list()
ndata_names = []
for name in names:
name = parse_hetero_data_name(name)
......@@ -1007,7 +1007,7 @@ class DistGraph:
def _get_edata_names(self, etype=None):
''' Get the names of all edge data.
'''
names = self._client.data_name_list()
names = self._client.gdata_name_list()
edata_names = []
for name in names:
name = parse_hetero_data_name(name)
......
......@@ -150,8 +150,8 @@ class DistTensor:
self._name = str(data_name)
self._persistent = persistent
if self._name not in exist_names:
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
self._owner = True
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
else:
self._owner = False
dtype1, shape1, _ = self.kvstore.get_data_meta(self._name)
......
......@@ -1123,9 +1123,13 @@ class KVClient(object):
self._gdata_name_list.add(name)
self.barrier()
def gdata_name_list(self):
"""Get all the graph data name"""
return list(self._gdata_name_list)
def data_name_list(self):
"""Get all the data name"""
return list(self._gdata_name_list)
return list(self._data_name_list)
def get_data_meta(self, name):
"""Get meta data (data_type, data_shape, partition_policy)
......
......@@ -32,7 +32,8 @@ class DistSparseGradOptimizer(abc.ABC):
self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size()
else:
assert 'th.distributed shoud be initialized'
self._rank = 0
self._world_size = 1
def step(self):
''' The step function.
......
......@@ -16,12 +16,19 @@ class KVClient(object):
self._all_possible_part_policy = {}
self._push_handlers = {}
self._pull_handlers = {}
# Store all graph data name
self._gdata_name_list = set()
@property
def all_possible_part_policy(self):
"""Get all possible partition policies"""
return self._all_possible_part_policy
@property
def num_servers(self):
"""Get the number of servers"""
return 1
def barrier(self):
'''barrier'''
......@@ -39,11 +46,13 @@ class KVClient(object):
if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy
def init_data(self, name, shape, dtype, part_policy, init_func):
def init_data(self, name, shape, dtype, part_policy, init_func, is_gdata=True):
'''add new data to the client'''
self._data[name] = init_func(shape, dtype)
if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy
if is_gdata:
self._gdata_name_list.add(name)
def delete_data(self, name):
'''delete the data'''
......@@ -53,6 +62,10 @@ class KVClient(object):
'''get the names of all data'''
return list(self._data.keys())
def gdata_name_list(self):
'''get the names of graph data'''
return list(self._gdata_name_list)
def get_data_meta(self, name):
'''get the metadata of data'''
return F.dtype(self._data[name]), F.shape(self._data[name]), None
......
......@@ -193,7 +193,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32,
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(), 1), F.float32,
'emb1_sum', policy)
if num_clients == 1:
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
......@@ -216,12 +216,15 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
with F.no_grad():
feats = emb(nids)
if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr)
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e:
pass
except Exception as e:
print(e)
sys.exit(-1)
def check_dist_graph(g, num_clients, num_nodes, num_edges):
# Test API
......@@ -332,6 +335,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
for p in cli_ps:
p.join()
assert p.exitcode == 0
for p in serv_ps:
p.join()
......@@ -590,7 +594,6 @@ def test_dist_emb_server_client():
check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1)
check_dist_emb_server_client(True, 2, 2)
check_dist_emb_server_client(False, 2, 2)
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone():
......@@ -765,9 +768,9 @@ def prepare_dist():
if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True)
test_dist_emb_server_client()
test_server_client()
test_split()
test_split_even()
test_standalone()
test_standalone_node_emb()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment