Unverified Commit 485c04cf authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[Distributed] Fix a few bugs in distributed API (#3094)



* fix.

* fix.

* fix.

* fix.

* Fix test
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-71-112.ec2.internal>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-2-66.ec2.internal>
parent 595d4e34
...@@ -995,7 +995,7 @@ class DistGraph: ...@@ -995,7 +995,7 @@ class DistGraph:
def _get_ndata_names(self, ntype=None): def _get_ndata_names(self, ntype=None):
''' Get the names of all node data. ''' Get the names of all node data.
''' '''
names = self._client.data_name_list() names = self._client.gdata_name_list()
ndata_names = [] ndata_names = []
for name in names: for name in names:
name = parse_hetero_data_name(name) name = parse_hetero_data_name(name)
...@@ -1007,7 +1007,7 @@ class DistGraph: ...@@ -1007,7 +1007,7 @@ class DistGraph:
def _get_edata_names(self, etype=None): def _get_edata_names(self, etype=None):
''' Get the names of all edge data. ''' Get the names of all edge data.
''' '''
names = self._client.data_name_list() names = self._client.gdata_name_list()
edata_names = [] edata_names = []
for name in names: for name in names:
name = parse_hetero_data_name(name) name = parse_hetero_data_name(name)
......
...@@ -150,8 +150,8 @@ class DistTensor: ...@@ -150,8 +150,8 @@ class DistTensor:
self._name = str(data_name) self._name = str(data_name)
self._persistent = persistent self._persistent = persistent
if self._name not in exist_names: if self._name not in exist_names:
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
self._owner = True self._owner = True
self.kvstore.init_data(self._name, shape, dtype, part_policy, init_func, is_gdata)
else: else:
self._owner = False self._owner = False
dtype1, shape1, _ = self.kvstore.get_data_meta(self._name) dtype1, shape1, _ = self.kvstore.get_data_meta(self._name)
......
...@@ -1123,9 +1123,13 @@ class KVClient(object): ...@@ -1123,9 +1123,13 @@ class KVClient(object):
self._gdata_name_list.add(name) self._gdata_name_list.add(name)
self.barrier() self.barrier()
def gdata_name_list(self):
"""Get all the graph data name"""
return list(self._gdata_name_list)
def data_name_list(self): def data_name_list(self):
"""Get all the data name""" """Get all the data name"""
return list(self._gdata_name_list) return list(self._data_name_list)
def get_data_meta(self, name): def get_data_meta(self, name):
"""Get meta data (data_type, data_shape, partition_policy) """Get meta data (data_type, data_shape, partition_policy)
......
...@@ -32,7 +32,8 @@ class DistSparseGradOptimizer(abc.ABC): ...@@ -32,7 +32,8 @@ class DistSparseGradOptimizer(abc.ABC):
self._rank = th.distributed.get_rank() self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size() self._world_size = th.distributed.get_world_size()
else: else:
assert 'th.distributed shoud be initialized' self._rank = 0
self._world_size = 1
def step(self): def step(self):
''' The step function. ''' The step function.
......
...@@ -16,12 +16,19 @@ class KVClient(object): ...@@ -16,12 +16,19 @@ class KVClient(object):
self._all_possible_part_policy = {} self._all_possible_part_policy = {}
self._push_handlers = {} self._push_handlers = {}
self._pull_handlers = {} self._pull_handlers = {}
# Store all graph data name
self._gdata_name_list = set()
@property @property
def all_possible_part_policy(self): def all_possible_part_policy(self):
"""Get all possible partition policies""" """Get all possible partition policies"""
return self._all_possible_part_policy return self._all_possible_part_policy
@property
def num_servers(self):
"""Get the number of servers"""
return 1
def barrier(self): def barrier(self):
'''barrier''' '''barrier'''
...@@ -39,11 +46,13 @@ class KVClient(object): ...@@ -39,11 +46,13 @@ class KVClient(object):
if part_policy.policy_str not in self._all_possible_part_policy: if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy self._all_possible_part_policy[part_policy.policy_str] = part_policy
def init_data(self, name, shape, dtype, part_policy, init_func): def init_data(self, name, shape, dtype, part_policy, init_func, is_gdata=True):
'''add new data to the client''' '''add new data to the client'''
self._data[name] = init_func(shape, dtype) self._data[name] = init_func(shape, dtype)
if part_policy.policy_str not in self._all_possible_part_policy: if part_policy.policy_str not in self._all_possible_part_policy:
self._all_possible_part_policy[part_policy.policy_str] = part_policy self._all_possible_part_policy[part_policy.policy_str] = part_policy
if is_gdata:
self._gdata_name_list.add(name)
def delete_data(self, name): def delete_data(self, name):
'''delete the data''' '''delete the data'''
...@@ -53,6 +62,10 @@ class KVClient(object): ...@@ -53,6 +62,10 @@ class KVClient(object):
'''get the names of all data''' '''get the names of all data'''
return list(self._data.keys()) return list(self._data.keys())
def gdata_name_list(self):
'''get the names of graph data'''
return list(self._gdata_name_list)
def get_data_meta(self, name): def get_data_meta(self, name):
'''get the metadata of data''' '''get the metadata of data'''
return F.dtype(self._data[name]), F.shape(self._data[name]), None return F.dtype(self._data[name]), F.shape(self._data[name]), None
......
...@@ -193,7 +193,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges): ...@@ -193,7 +193,7 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1))) assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book()) policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(),), F.float32, grad_sum = dgl.distributed.DistTensor((g.number_of_nodes(), 1), F.float32,
'emb1_sum', policy) 'emb1_sum', policy)
if num_clients == 1: if num_clients == 1:
assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients) assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
...@@ -216,12 +216,15 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges): ...@@ -216,12 +216,15 @@ def check_dist_emb(g, num_clients, num_nodes, num_edges):
with F.no_grad(): with F.no_grad():
feats = emb(nids) feats = emb(nids)
if num_clients == 1: if num_clients == 1:
assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * math.sqrt(2) * -lr) assert_almost_equal(F.asnumpy(feats), np.ones((len(nids), 1)) * 1 * -lr)
rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids)) rest = np.setdiff1d(np.arange(g.number_of_nodes()), F.asnumpy(nids))
feats1 = emb(rest) feats1 = emb(rest)
assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1))) assert np.all(F.asnumpy(feats1) == np.zeros((len(rest), 1)))
except NotImplementedError as e: except NotImplementedError as e:
pass pass
except Exception as e:
print(e)
sys.exit(-1)
def check_dist_graph(g, num_clients, num_nodes, num_edges): def check_dist_graph(g, num_clients, num_nodes, num_edges):
# Test API # Test API
...@@ -332,6 +335,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients): ...@@ -332,6 +335,7 @@ def check_dist_emb_server_client(shared_mem, num_servers, num_clients):
for p in cli_ps: for p in cli_ps:
p.join() p.join()
assert p.exitcode == 0
for p in serv_ps: for p in serv_ps:
p.join() p.join()
...@@ -590,7 +594,6 @@ def test_dist_emb_server_client(): ...@@ -590,7 +594,6 @@ def test_dist_emb_server_client():
check_dist_emb_server_client(True, 1, 1) check_dist_emb_server_client(True, 1, 1)
check_dist_emb_server_client(False, 1, 1) check_dist_emb_server_client(False, 1, 1)
check_dist_emb_server_client(True, 2, 2) check_dist_emb_server_client(True, 2, 2)
check_dist_emb_server_client(False, 2, 2)
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph") @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_standalone(): def test_standalone():
...@@ -765,9 +768,9 @@ def prepare_dist(): ...@@ -765,9 +768,9 @@ def prepare_dist():
if __name__ == '__main__': if __name__ == '__main__':
os.makedirs('/tmp/dist_graph', exist_ok=True) os.makedirs('/tmp/dist_graph', exist_ok=True)
test_dist_emb_server_client()
test_server_client() test_server_client()
test_split() test_split()
test_split_even() test_split_even()
test_standalone() test_standalone()
test_standalone_node_emb() test_standalone_node_emb()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment