Unverified commit 8eab08d0, authored by Chao Ma, committed by GitHub
Browse files

[KVStore] Remove Freeze flag (#1605)

* remove freeze

* update

* update

* fix lint
parent cbe4c28f
"""Define distributed kvstore""" """Define distributed kvstore"""
import os import os
import time
import random import random
import numpy as np import numpy as np
...@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request): ...@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request):
kv_store.part_policy[name].policy_str) kv_store.part_policy[name].policy_str)
if len(meta) == 0: if len(meta) == 0:
raise RuntimeError('There is no data on kvserver.') raise RuntimeError('There is no data on kvserver.')
# Freeze data init
kv_store.freeze = True
res = GetSharedDataResponse(meta) res = GetSharedDataResponse(meta)
return res return res
...@@ -451,10 +448,11 @@ class SendMetaToBackupRequest(rpc.Request): ...@@ -451,10 +448,11 @@ class SendMetaToBackupRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
assert kv_store.is_backup_server() assert kv_store.is_backup_server()
shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype) if self.name not in kv_store.data_store:
dlpack = shared_data.to_dlpack() shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype)
kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack) dlpack = shared_data.to_dlpack()
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str) kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
kv_store.part_policy[self.name] = kv_store.find_policy(self.policy_str)
res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG) res = SendMetaToBackupResponse(SEND_META_TO_BACKUP_MSG)
return res return res
...@@ -570,8 +568,6 @@ class KVServer(object): ...@@ -570,8 +568,6 @@ class KVServer(object):
# push and pull handler # push and pull handler
self._push_handler = default_push_handler self._push_handler = default_push_handler
self._pull_handler = default_pull_handler self._pull_handler = default_pull_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False
@property @property
def server_id(self): def server_id(self):
...@@ -588,16 +584,6 @@ class KVServer(object): ...@@ -588,16 +584,6 @@ class KVServer(object):
"""Set barrier count""" """Set barrier count"""
self._barrier_count = count self._barrier_count = count
@property
def freeze(self):
"""Get freeze"""
return self._freeze
@freeze.setter
def freeze(self, freeze):
"""Set freeze"""
self._freeze = freeze
@property @property
def num_clients(self): def num_clients(self):
"""Get number of clients""" """Get number of clients"""
...@@ -669,9 +655,6 @@ class KVServer(object): ...@@ -669,9 +655,6 @@ class KVServer(object):
read shared-memory when client invoking get_shared_data(). read shared-memory when client invoking get_shared_data().
""" """
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
if self._freeze:
raise RuntimeError("KVServer cannot create new data \
after client invoking get_shared_data() API.")
if self._data_store.__contains__(name): if self._data_store.__contains__(name):
raise RuntimeError("Data %s has already exists!" % name) raise RuntimeError("Data %s has already exists!" % name)
if data_tensor is not None: # Create shared-tensor if data_tensor is not None: # Create shared-tensor
...@@ -764,9 +747,6 @@ class KVClient(object): ...@@ -764,9 +747,6 @@ class KVClient(object):
# push and pull handler # push and pull handler
self._pull_handler = default_pull_handler self._pull_handler = default_pull_handler
self._push_handler = default_push_handler self._push_handler = default_push_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False
random.seed(time.time())
@property @property
def client_id(self): def client_id(self):
...@@ -858,9 +838,7 @@ class KVClient(object): ...@@ -858,9 +838,7 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert len(shape) > 0, 'shape cannot be empty' assert len(shape) > 0, 'shape cannot be empty'
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.' assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
if self._freeze: assert name not in self._data_name_list, 'data name: %s already exists.' % name
raise RuntimeError("KVClient cannot create new \
data after invoking get_shared_data() API.")
shape = list(shape) shape = list(shape)
if self._client_id == 0: if self._client_id == 0:
for machine_id in range(self._machine_count): for machine_id in range(self._machine_count):
...@@ -920,27 +898,28 @@ class KVClient(object): ...@@ -920,27 +898,28 @@ class KVClient(object):
rpc.send_request(self._main_server_id, request) rpc.send_request(self._main_server_id, request)
response = rpc.recv_response() response = rpc.recv_response()
for name, meta in response.meta.items(): for name, meta in response.meta.items():
shape, dtype, policy_str = meta if name not in self._data_name_list:
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype) shape, dtype, policy_str = meta
dlpack = shared_data.to_dlpack() shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) dlpack = shared_data.to_dlpack()
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._data_name_list.add(name) self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
# Get full data shape across servers # Get full data shape across servers
for name, meta in response.meta.items(): for name, meta in response.meta.items():
shape, _, _ = meta if name not in self._data_name_list:
data_shape = list(shape) shape, _, _ = meta
data_shape[0] = 0 data_shape = list(shape)
request = GetPartShapeRequest(name) data_shape[0] = 0
# send request to all main server nodes request = GetPartShapeRequest(name)
for machine_id in range(self._machine_count): # send request to all main server nodes
server_id = machine_id * self._group_count for machine_id in range(self._machine_count):
rpc.send_request(server_id, request) server_id = machine_id * self._group_count
# recv response from all the main server nodes rpc.send_request(server_id, request)
for _ in range(self._machine_count): # recv response from all the main server nodes
res = rpc.recv_response() for _ in range(self._machine_count):
data_shape[0] += res.shape[0] res = rpc.recv_response()
self._full_data_shape[name] = tuple(data_shape) data_shape[0] += res.shape[0]
self._full_data_shape[name] = tuple(data_shape)
# Send meta data to backup servers # Send meta data to backup servers
for name, meta in response.meta.items(): for name, meta in response.meta.items():
shape, dtype, policy_str = meta shape, dtype, policy_str = meta
...@@ -953,7 +932,7 @@ class KVClient(object): ...@@ -953,7 +932,7 @@ class KVClient(object):
for _ in range(self._group_count-1): for _ in range(self._group_count-1):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG assert response.msg == SEND_META_TO_BACKUP_MSG
self._freeze = True self._data_name_list.add(name)
def data_name_list(self): def data_name_list(self):
"""Get all the data name""" """Get all the data name"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment