"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "0f12c51646a5754122c6375103f06ac8ee8fca7d"
Unverified Commit 87b6997b authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[kvstore] support any data type for init_data() (#1465)

* support any data type for init_data

* update

* update
parent 6cd7c313
......@@ -21,7 +21,6 @@ if os.name != 'nt':
import fcntl
import struct
def read_ip_config(filename):
"""Read network configuration information of kvstore from file.
......@@ -78,6 +77,29 @@ def read_ip_config(filename):
return server_namebook
def get_type_str(dtype):
"""Get data type string
"""
if 'float16' in str(dtype):
return 'float16'
elif 'float32' in str(dtype):
return 'float32'
elif 'float64' in str(dtype):
return 'float64'
elif 'uint8' in str(dtype):
return 'uint8'
elif 'int8' in str(dtype):
return 'int8'
elif 'int16' in str(dtype):
return 'int16'
elif 'int32' in str(dtype):
return 'int32'
elif 'int64' in str(dtype):
return 'int64'
else:
raise RuntimeError('Unknown data type: %s' % str(dtype))
class KVServer(object):
"""KVServer is a lightweight key-value store service for DGL distributed training.
......@@ -184,12 +206,13 @@ class KVServer(object):
if global2local is not None: # Create shared-tensor
if isinstance(global2local, list):
global2local = F.tensor(global2local)
assert 'int64' == get_type_str(F.dtype(global2local)), 'global2local must be int64 type.'
shared_data = empty_shared_mem(name+'-g2l-', True, global2local.shape, 'int64')
dlpack = shared_data.to_dlpack()
self._data_store[name+'-g2l-'] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name+'-g2l-'][:] = global2local[:]
# write data information to temp file that can be read by other processes
self._write_data_shape(name+'-g2l-shape-'+str(self._machine_id), global2local)
self._write_data_shape_type(name+'-g2l-shape-'+str(self._machine_id), global2local)
self._open_file_list.append(name+'-g2l-shape-'+str(self._machine_id))
else: # Read shared-tensor
while True:
......@@ -198,7 +221,8 @@ class KVServer(object):
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-g2l-shape-'+str(self._machine_id))
data_shape, data_type = self._read_data_shape_type(name+'-g2l-shape-'+str(self._machine_id))
assert data_type == 'int64'
shared_data = empty_shared_mem(name+'-g2l-', False, data_shape, 'int64')
dlpack = shared_data.to_dlpack()
self._data_store[name+'-g2l-'] = F.zerocopy_from_dlpack(dlpack)
......@@ -223,11 +247,12 @@ class KVServer(object):
if partition_book is not None: # Create shared-tensor
if isinstance(partition_book, list):
partition_book = F.tensor(partition_book)
assert 'int64' == get_type_str(F.dtype(partition_book)), 'partition_book must be int64 type.'
shared_data = empty_shared_mem(name+'-part-', True, partition_book.shape, 'int64')
dlpack = shared_data.to_dlpack()
self._data_store[name+'-part-'] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name+'-part-'][:] = partition_book[:]
self._write_data_shape(name+'-part-shape-'+str(self._machine_id), partition_book)
self._write_data_shape_type(name+'-part-shape-'+str(self._machine_id), partition_book)
self._open_file_list.append(name+'-part-shape-'+str(self._machine_id))
else: # Read shared-tensor
while True:
......@@ -236,7 +261,8 @@ class KVServer(object):
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-part-shape-'+str(self._machine_id))
data_shape, data_type = self._read_data_shape_type(name+'-part-shape-'+str(self._machine_id))
assert data_type == 'int64'
shared_data = empty_shared_mem(name+'-part-', False, data_shape, 'int64')
dlpack = shared_data.to_dlpack()
self._data_store[name+'-part-'] = F.zerocopy_from_dlpack(dlpack)
......@@ -259,11 +285,12 @@ class KVServer(object):
assert len(name) > 0, 'name cannot be empty.'
if data_tensor is not None: # Create shared-tensor
shared_data = empty_shared_mem(name+'-data-', True, data_tensor.shape, 'float32')
data_type = get_type_str(F.dtype(data_tensor))
shared_data = empty_shared_mem(name+'-data-', True, data_tensor.shape, data_type)
dlpack = shared_data.to_dlpack()
self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name+'-data-'][:] = data_tensor[:]
self._write_data_shape(name+'-data-shape-'+str(self._machine_id), data_tensor)
self._write_data_shape_type(name+'-data-shape-'+str(self._machine_id), data_tensor)
self._open_file_list.append(name+'-data-shape-'+str(self._machine_id))
else: # Read shared-tensor
while True:
......@@ -271,8 +298,8 @@ class KVServer(object):
break
else:
time.sleep(2) # wait until the file been created
data_shape = self._read_data_shape(name+'-data-shape-'+str(self._machine_id))
shared_data = empty_shared_mem(name+'-data-', False, data_shape, 'float32')
data_shape, data_type = self._read_data_shape_type(name+'-data-shape-'+str(self._machine_id))
shared_data = empty_shared_mem(name+'-data-', False, data_shape, data_type)
dlpack = shared_data.to_dlpack()
self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
......@@ -479,7 +506,7 @@ class KVServer(object):
----------
name : str
tensor name
dtype : str
dtype : dtype
data type
Returns
......@@ -491,17 +518,11 @@ class KVServer(object):
str_data = name
str_data += '/'
if 'float32' in str(dtype):
str_data += 'float32'
elif 'int64' in str(dtype):
str_data += 'int64'
else:
raise RuntimeError('We can only process int64 and float32 shared-memory tensor now.')
str_data += get_type_str(dtype)
return str_data
def _write_data_shape(self, filename, data):
def _write_data_shape_type(self, filename, data):
"""Write data shape to a temp file.
Parameters
......@@ -518,6 +539,8 @@ class KVServer(object):
shape = F.shape(data)
str_data = ''
str_data += get_type_str(F.dtype(data))
str_data += '|'
f = open(filename, "a");
for s in shape:
str_data += str(s)
......@@ -526,7 +549,7 @@ class KVServer(object):
f.close()
def _read_data_shape(self, filename):
def _read_data_shape_type(self, filename):
"""Read data shape from a tmp file.
Parameters
......@@ -544,12 +567,13 @@ class KVServer(object):
f = open(filename, "r")
str_data = f.read()
data_list = str_data.split('|')
data_type = data_list[0]
data_shape = []
for i in range(len(data_list)-1):
for i in range(1, len(data_list)-1):
data_shape.append(int(data_list[i]))
f.close()
return data_shape
return data_shape, data_type
def _default_push_handler(self, name, ID, data, target):
......@@ -721,7 +745,8 @@ class KVClient(object):
break
else:
time.sleep(2) # wait until the file been created
shape = self._read_data_shape(tensor_name+'shape-'+str(self._machine_id))
shape, data_type = self._read_data_shape_type(tensor_name+'shape-'+str(self._machine_id))
assert data_type == dtype
shared_data = empty_shared_mem(tensor_name, False, shape, dtype)
dlpack = shared_data.to_dlpack()
self._data_store[tensor_name] = F.zerocopy_from_dlpack(dlpack)
......@@ -1124,7 +1149,7 @@ class KVClient(object):
f.close()
def _read_data_shape(self, filename):
def _read_data_shape_type(self, filename):
"""Read data shape from a tmp file.
Parameters
......@@ -1142,13 +1167,13 @@ class KVClient(object):
f = open(filename, "r")
str_data = f.read()
data_list = str_data.split('|')
data_type = data_list[0]
data_shape = []
for i in range(len(data_list)-1):
for i in range(1, len(data_list)-1):
data_shape.append(int(data_list[i]))
f.close()
return data_shape
return data_shape, data_type
def _takeId(self, elem):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment