Unverified Commit 8eab08d0 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] Remove Freeze flag (#1605)

* remove freeze

* update

* update

* fix lint
parent cbe4c28f
"""Define distributed kvstore""" """Define distributed kvstore"""
import os import os
import time
import random import random
import numpy as np import numpy as np
...@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request): ...@@ -356,8 +355,6 @@ class GetSharedDataRequest(rpc.Request):
kv_store.part_policy[name].policy_str) kv_store.part_policy[name].policy_str)
if len(meta) == 0: if len(meta) == 0:
raise RuntimeError('There is no data on kvserver.') raise RuntimeError('There is no data on kvserver.')
# Freeze data init
kv_store.freeze = True
res = GetSharedDataResponse(meta) res = GetSharedDataResponse(meta)
return res return res
...@@ -451,6 +448,7 @@ class SendMetaToBackupRequest(rpc.Request): ...@@ -451,6 +448,7 @@ class SendMetaToBackupRequest(rpc.Request):
def process_request(self, server_state): def process_request(self, server_state):
kv_store = server_state.kv_store kv_store = server_state.kv_store
assert kv_store.is_backup_server() assert kv_store.is_backup_server()
if self.name not in kv_store.data_store:
shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype) shared_data = empty_shared_mem(self.name+'-kvdata-', False, self.shape, self.dtype)
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack) kv_store.data_store[self.name] = F.zerocopy_from_dlpack(dlpack)
...@@ -570,8 +568,6 @@ class KVServer(object): ...@@ -570,8 +568,6 @@ class KVServer(object):
# push and pull handler # push and pull handler
self._push_handler = default_push_handler self._push_handler = default_push_handler
self._pull_handler = default_pull_handler self._pull_handler = default_pull_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False
@property @property
def server_id(self): def server_id(self):
...@@ -588,16 +584,6 @@ class KVServer(object): ...@@ -588,16 +584,6 @@ class KVServer(object):
"""Set barrier count""" """Set barrier count"""
self._barrier_count = count self._barrier_count = count
@property
def freeze(self):
"""Get freeze"""
return self._freeze
@freeze.setter
def freeze(self, freeze):
"""Set freeze"""
self._freeze = freeze
@property @property
def num_clients(self): def num_clients(self):
"""Get number of clients""" """Get number of clients"""
...@@ -669,9 +655,6 @@ class KVServer(object): ...@@ -669,9 +655,6 @@ class KVServer(object):
read shared-memory when client invoking get_shared_data(). read shared-memory when client invoking get_shared_data().
""" """
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
if self._freeze:
raise RuntimeError("KVServer cannot create new data \
after client invoking get_shared_data() API.")
if self._data_store.__contains__(name): if self._data_store.__contains__(name):
raise RuntimeError("Data %s has already exists!" % name) raise RuntimeError("Data %s has already exists!" % name)
if data_tensor is not None: # Create shared-tensor if data_tensor is not None: # Create shared-tensor
...@@ -764,9 +747,6 @@ class KVClient(object): ...@@ -764,9 +747,6 @@ class KVClient(object):
# push and pull handler # push and pull handler
self._pull_handler = default_pull_handler self._pull_handler = default_pull_handler
self._push_handler = default_push_handler self._push_handler = default_push_handler
# We cannot create new data on kvstore when freeze == True
self._freeze = False
random.seed(time.time())
@property @property
def client_id(self): def client_id(self):
...@@ -858,9 +838,7 @@ class KVClient(object): ...@@ -858,9 +838,7 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert len(shape) > 0, 'shape cannot be empty' assert len(shape) > 0, 'shape cannot be empty'
assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.' assert policy_str in ('edge', 'node'), 'policy_str must be \'edge\' or \'node\'.'
if self._freeze: assert name not in self._data_name_list, 'data name: %s already exists.' % name
raise RuntimeError("KVClient cannot create new \
data after invoking get_shared_data() API.")
shape = list(shape) shape = list(shape)
if self._client_id == 0: if self._client_id == 0:
for machine_id in range(self._machine_count): for machine_id in range(self._machine_count):
...@@ -920,14 +898,15 @@ class KVClient(object): ...@@ -920,14 +898,15 @@ class KVClient(object):
rpc.send_request(self._main_server_id, request) rpc.send_request(self._main_server_id, request)
response = rpc.recv_response() response = rpc.recv_response()
for name, meta in response.meta.items(): for name, meta in response.meta.items():
if name not in self._data_name_list:
shape, dtype, policy_str = meta shape, dtype, policy_str = meta
shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype) shared_data = empty_shared_mem(name+'-kvdata-', False, shape, dtype)
dlpack = shared_data.to_dlpack() dlpack = shared_data.to_dlpack()
self._data_store[name] = F.zerocopy_from_dlpack(dlpack) self._data_store[name] = F.zerocopy_from_dlpack(dlpack)
self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book) self._part_policy[name] = PartitionPolicy(policy_str, self._part_id, partition_book)
self._data_name_list.add(name)
# Get full data shape across servers # Get full data shape across servers
for name, meta in response.meta.items(): for name, meta in response.meta.items():
if name not in self._data_name_list:
shape, _, _ = meta shape, _, _ = meta
data_shape = list(shape) data_shape = list(shape)
data_shape[0] = 0 data_shape[0] = 0
...@@ -953,7 +932,7 @@ class KVClient(object): ...@@ -953,7 +932,7 @@ class KVClient(object):
for _ in range(self._group_count-1): for _ in range(self._group_count-1):
response = rpc.recv_response() response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG assert response.msg == SEND_META_TO_BACKUP_MSG
self._freeze = True self._data_name_list.add(name)
def data_name_list(self): def data_name_list(self):
"""Get all the data name""" """Get all the data name"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment