Unverified Commit e5e1a99d authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Fix bug in barrier. (#1904)

* update

* update

* fix lint

* update

* update
parent 03f251ce
...@@ -549,8 +549,15 @@ class RegisterRoleRequest(rpc.Request): ...@@ -549,8 +549,15 @@ class RegisterRoleRequest(rpc.Request):
role[self.role] = set() role[self.role] = set()
kv_store.barrier_count[self.role] = 0 kv_store.barrier_count[self.role] = 0
role[self.role].add(self.client_id) role[self.role].add(self.client_id)
res = RegisterRoleResponse(ROLE_MSG) total_count = 0
return res for key in role:
total_count += len(role[key])
if total_count == kv_store.num_clients:
res_list = []
for target_id in range(kv_store.num_clients):
res_list.append((target_id, RegisterRoleResponse(ROLE_MSG)))
return res_list
return None
############################ KVServer ############################### ############################ KVServer ###############################
......
...@@ -104,7 +104,7 @@ def test_partition_policy(): ...@@ -104,7 +104,7 @@ def test_partition_policy():
assert edge_policy.get_data_size() == len(edge_map) assert edge_policy.get_data_size() == len(edge_map)
def start_server(server_id, num_clients): def start_server(server_id, num_clients):
# Init kvserver # Init kvserver
print("Sleep 5 seconds to test client re-connect.") print("Sleep 5 seconds to test client re-connect.")
time.sleep(5) time.sleep(5)
kvserver = dgl.distributed.KVServer(server_id=server_id, kvserver = dgl.distributed.KVServer(server_id=server_id,
...@@ -151,7 +151,6 @@ def start_client(num_clients): ...@@ -151,7 +151,6 @@ def start_client(num_clients):
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt') dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient # Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
time.sleep(2)
assert dgl.distributed.get_num_client() == num_clients assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1', kvclient.init_data(name='data_1',
shape=F.shape(data_1), shape=F.shape(data_1),
...@@ -286,7 +285,6 @@ def start_client_mul_role(i, num_clients): ...@@ -286,7 +285,6 @@ def start_client_mul_role(i, num_clients):
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='trainer') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='trainer')
else: else:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='sampler') kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='sampler')
time.sleep(2)
if i == 2: # block one trainer if i == 2: # block one trainer
time.sleep(5) time.sleep(5)
kvclient.barrier() kvclient.barrier()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment