Unverified Commit e5e1a99d authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Fix bug in barrier. (#1904)

* update

* update

* fix lint

* update

* update
parent 03f251ce
......@@ -549,8 +549,15 @@ class RegisterRoleRequest(rpc.Request):
role[self.role] = set()
kv_store.barrier_count[self.role] = 0
role[self.role].add(self.client_id)
res = RegisterRoleResponse(ROLE_MSG)
return res
total_count = 0
for key in role:
total_count += len(role[key])
if total_count == kv_store.num_clients:
res_list = []
for target_id in range(kv_store.num_clients):
res_list.append((target_id, RegisterRoleResponse(ROLE_MSG)))
return res_list
return None
############################ KVServer ###############################
......
......@@ -104,7 +104,7 @@ def test_partition_policy():
assert edge_policy.get_data_size() == len(edge_map)
def start_server(server_id, num_clients):
# Init kvserver
# Init kvserver
print("Sleep 5 seconds to test client re-connect.")
time.sleep(5)
kvserver = dgl.distributed.KVServer(server_id=server_id,
......@@ -151,7 +151,6 @@ def start_client(num_clients):
dgl.distributed.connect_to_server(ip_config='kv_ip_config.txt')
# Init kvclient
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_config.txt')
time.sleep(2)
assert dgl.distributed.get_num_client() == num_clients
kvclient.init_data(name='data_1',
shape=F.shape(data_1),
......@@ -286,7 +285,6 @@ def start_client_mul_role(i, num_clients):
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='trainer')
else:
kvclient = dgl.distributed.KVClient(ip_config='kv_ip_mul_config.txt', role='sampler')
time.sleep(2)
if i == 2: # block one trainer
time.sleep(5)
kvclient.barrier()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment