"tests/python/common/transforms/test_transform.py" did not exist on "03024f9587d1bf9b577b56c51e745cb3af502f0a"
Unverified commit 45858d68, authored by freeliuzc and committed by GitHub
Browse files

[feature] add count_nonzero function for DistTensor (#3203)



* add count_nonzero function for DistTensor

* change the load method of local data
Co-authored-by: liuzichang04 <liuzichang04@meituan.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
parent 8b5f4f5b
...@@ -1229,17 +1229,7 @@ def _split_even_to_part(partition_book, elements): ...@@ -1229,17 +1229,7 @@ def _split_even_to_part(partition_book, elements):
# strategy. # strategy.
# TODO(zhengda) we need another way to divide the list for other partitioning strategy. # TODO(zhengda) we need another way to divide the list for other partitioning strategy.
if isinstance(elements, DistTensor): if isinstance(elements, DistTensor):
# Here we need to fetch all elements from the kvstore server. nonzero_count = elements.count_nonzero()
# I hope it's OK.
eles = F.nonzero_1d(elements[0:len(elements)])
# compute the offset of each split and ensure that the difference of each partition size
# is 1.
offsets = _even_offset(len(eles), partition_book.num_partitions())
assert offsets[-1] == len(eles)
# Get the elements that belong to the partition.
partid = partition_book.partid
part_eles = eles[offsets[partid] : offsets[partid + 1]]
else: else:
elements = F.tensor(elements) elements = F.tensor(elements)
nonzero_count = F.count_nonzero(elements) nonzero_count = F.count_nonzero(elements)
......
...@@ -232,3 +232,13 @@ class DistTensor: ...@@ -232,3 +232,13 @@ class DistTensor:
The name of the tensor. The name of the tensor.
''' '''
return self._tensor_name return self._tensor_name
def count_nonzero(self):
    '''Return how many values of this tensor are nonzero.

    The counting is delegated to the kvstore this tensor is attached to;
    the tensor is identified there by its name.

    Returns
    -------
    int
        The number of nonzero values.
    '''
    store = self.kvstore
    return store.count_nonzero(name=self.name)
...@@ -529,6 +529,42 @@ class DeleteDataRequest(rpc.Request): ...@@ -529,6 +529,42 @@ class DeleteDataRequest(rpc.Request):
res = DeleteDataResponse(DELETE_MSG) res = DeleteDataResponse(DELETE_MSG)
return res return res
# Identifier passed to rpc.register_service for the
# CountLocalNonzeroRequest / CountLocalNonzeroResponse pair.
COUNT_LOCAL_NONZERO = 901241
class CountLocalNonzeroResponse(rpc.Response):
    """Response carrying the number of nonzero values in a server's local data.

    Parameters
    ----------
    num_local_nonzero : int
        The nonzero count computed by the server that handled the request.
    """

    def __init__(self, num_local_nonzero):
        self.num_local_nonzero = num_local_nonzero

    def __getstate__(self):
        # The count is the whole state; it is serialized as a bare value.
        return self.num_local_nonzero

    def __setstate__(self, state):
        # Mirror of __getstate__: the state is the count itself.
        self.num_local_nonzero = state
class CountLocalNonzeroRequest(rpc.Request):
    """Request asking a server to count the nonzero values of its local data.

    Parameters
    ----------
    name : str
        Name of the data to count.
    """

    def __init__(self, name):
        self.name = name

    def __getstate__(self):
        # The data name is the whole state; it is serialized as a bare value.
        return self.name

    def __setstate__(self, state):
        # Mirror of __getstate__: the state is the data name itself.
        self.name = state

    def process_request(self, server_state):
        """Count the nonzero values on the server and wrap them in a response.

        Parameters
        ----------
        server_state : object
            Server state whose ``kv_store`` attribute performs the counting.

        Returns
        -------
        CountLocalNonzeroResponse
            Response holding the count returned by the server's kvstore.
        """
        local_count = server_state.kv_store.count_local_nonzero(self.name)
        return CountLocalNonzeroResponse(local_count)
############################ KVServer ############################### ############################ KVServer ###############################
def default_push_handler(target, name, id_tensor, data_tensor): def default_push_handler(target, name, id_tensor, data_tensor):
...@@ -630,6 +666,9 @@ class KVServer(object): ...@@ -630,6 +666,9 @@ class KVServer(object):
rpc.register_service(DELETE_DATA, rpc.register_service(DELETE_DATA,
DeleteDataRequest, DeleteDataRequest,
DeleteDataResponse) DeleteDataResponse)
rpc.register_service(COUNT_LOCAL_NONZERO,
CountLocalNonzeroRequest,
CountLocalNonzeroResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -758,6 +797,24 @@ class KVServer(object): ...@@ -758,6 +797,24 @@ class KVServer(object):
return policy return policy
raise RuntimeError("Cannot find policy_str: %s from kvserver." % policy_str) raise RuntimeError("Cannot find policy_str: %s from kvserver." % policy_str)
def count_local_nonzero(self, name):
    """Count the nonzero values of the data stored on this server.

    Parameters
    ----------
    name : str
        Name of the data; must be non-empty.

    Returns
    -------
    int
        The number of nonzero values in the local data.

    Raises
    ------
    RuntimeError
        If no data is stored under ``name``.
    """
    assert len(name) > 0, 'name cannot be empty.'
    if name in self._data_store:
        return F.count_nonzero(self._data_store[name])
    raise RuntimeError("Data %s has not be created!" % name)
############################ KVClient ############################### ############################ KVClient ###############################
class KVClient(object): class KVClient(object):
...@@ -815,6 +872,9 @@ class KVClient(object): ...@@ -815,6 +872,9 @@ class KVClient(object):
rpc.register_service(DELETE_DATA, rpc.register_service(DELETE_DATA,
DeleteDataRequest, DeleteDataRequest,
DeleteDataResponse) DeleteDataResponse)
rpc.register_service(COUNT_LOCAL_NONZERO,
CountLocalNonzeroRequest,
CountLocalNonzeroResponse)
# Store the tensor data with specified data name # Store the tensor data with specified data name
self._data_store = {} self._data_store = {}
# Store the partition information with specified data name # Store the partition information with specified data name
...@@ -1283,6 +1343,35 @@ class KVClient(object): ...@@ -1283,6 +1343,35 @@ class KVClient(object):
""" """
return elem.server_id return elem.server_id
def count_nonzero(self, name):
    """Count the nonzero values of a data tensor across all machines.

    The part held on this client's machine is counted directly from the
    local data store; every other machine is sent a
    CountLocalNonzeroRequest and its reply is added to the total.

    Parameters
    ----------
    name : str
        Name of the data.

    Returns
    -------
    int
        The number of nonzero values in the data.
    """
    nonzero_total = 0
    num_pending = 0
    for mid in range(self._machine_count):
        if mid != self._machine_id:
            # Remote partition: ask that machine to count its own data.
            rpc.send_request_to_machine(mid, CountLocalNonzeroRequest(name))
            num_pending += 1
            continue
        # Local partition: read the rows [0, part_size) from the local store.
        part_size = self._part_policy[name].get_part_size()
        local_rows = F.tensor(np.arange(part_size, dtype=np.int64))
        nonzero_total += F.count_nonzero(self._data_store[name][local_rows])
    # Collect one reply for every request sent above.
    while num_pending > 0:
        nonzero_total += rpc.recv_response().num_local_nonzero
        num_pending -= 1
    return nonzero_total
KVCLIENT = None KVCLIENT = None
def init_kvstore(ip_config, num_servers, role): def init_kvstore(ip_config, num_servers, role):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment