Unverified commit 50962c43 authored by Da Zheng, committed by GitHub

Refactor distributed graph store (#1597)

* refactor graph store for new kvstore.

* fix kvstore.

* fix.

* fix lint complaints.

* add docstring.

* small fix in graph partition book.

* fix.

* remove tests

* disable sampling.

* Revert "disable sampling."

This reverts commit 70451008f61ff1481d5dadbf8bd199470aee559d.

* Revert "remove tests"

This reverts commit 1394364243bdd73b669abde6193a74e2cda5521d.
parent 8531ee6a
"""Define distributed graph."""
import socket
from collections.abc import MutableMapping
import numpy as np
from ..graph import DGLGraph
from .. import backend as F
from ..base import NID, EID
from ..contrib.dis_kvstore import KVServer, KVClient
from .kvstore import KVServer, KVClient
from ..graph_index import from_shared_mem_graph_index
from .._ffi.ndarray import empty_shared_mem
from ..frame import infer_scheme
from .partition import load_partition
from .graph_partition_book import GraphPartitionBook
from .. import ndarray as nd
from .graph_partition_book import GraphPartitionBook, PartitionPolicy, get_shared_mem_partition_book
from .. import utils
def _get_ndata_path(graph_name, ndata_name):
return "/" + graph_name + "_node_" + ndata_name
def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from .rpc_client import connect_to_server
from .server_state import ServerState
from .rpc_server import start_server
def _get_graph_path(graph_name):
return "/" + graph_name
DTYPE_DICT = F.data_type_dict
DTYPE_DICT = {DTYPE_DICT[key]:key for key in DTYPE_DICT}
def _move_data_to_shared_mem_array(arr, name):
dlpack = F.zerocopy_to_dlpack(arr)
dgl_tensor = nd.from_dlpack(dlpack)
new_arr = empty_shared_mem(name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)])
dgl_tensor.copyto(new_arr)
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
def _copy_graph_to_shared_mem(g, graph_name):
gidx = g._graph.copyto_shared_mem(_get_graph_path(graph_name))
new_g = DGLGraph(gidx)
# We should share the node/edge data to the client explicitly instead of putting them
# in the KVStore because some of the node/edge data may be duplicated.
local_node_path = _get_ndata_path(graph_name, 'local_node')
new_g.ndata['local_node'] = _move_data_to_shared_mem_array(g.ndata['local_node'],
local_node_path)
new_g.ndata['local_node'] = _to_shared_mem(g.ndata['local_node'],
local_node_path)
local_edge_path = _get_edata_path(graph_name, 'local_edge')
new_g.edata['local_edge'] = _move_data_to_shared_mem_array(g.edata['local_edge'],
local_edge_path)
new_g.ndata[NID] = _move_data_to_shared_mem_array(g.ndata[NID],
_get_ndata_path(graph_name, NID))
new_g.edata[EID] = _move_data_to_shared_mem_array(g.edata[EID],
_get_edata_path(graph_name, EID))
new_g.edata['local_edge'] = _to_shared_mem(g.edata['local_edge'], local_edge_path)
new_g.ndata[NID] = _to_shared_mem(g.ndata[NID], _get_ndata_path(graph_name, NID))
new_g.edata[EID] = _to_shared_mem(g.edata[EID], _get_edata_path(graph_name, EID))
return new_g
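# For reference, the shared-memory names produced by the helpers above follow
# a simple convention (illustrative only; 'mygraph' is a made-up graph name):
#
#   >>> _get_graph_path('mygraph')
#   '/mygraph'
#   >>> _get_ndata_path('mygraph', 'local_node')
#   '/mygraph_node_local_node'
#   >>> _get_edata_path('mygraph', 'local_edge')
#   '/mygraph_edge_local_edge'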
FIELD_DICT = {'local_node': F.int64,
@@ -124,45 +105,6 @@ def _get_graph_from_shared_mem(graph_name):
    g.edata[EID] = _get_shared_mem_edata(g, graph_name, EID)
    return g
class DistTensor:
    ''' Distributed tensor.
@@ -314,81 +256,54 @@ class DistGraphServer(KVServer):
    ----------
    server_id : int
        The server ID (starts from 0).
    ip_config : str
        Path of the IP configuration file.
    num_clients : int
        Total number of client nodes.
    graph_name : string
        The name of the graph. The server and the client need to specify the same graph name.
    conf_file : string
        The path of the config file generated by the partition tool.
    '''
    def __init__(self, server_id, ip_config, num_clients, graph_name, conf_file):
        super(DistGraphServer, self).__init__(server_id=server_id, ip_config=ip_config,
                                              num_clients=num_clients)
        self.ip_config = ip_config
        # Load graph partition data.
        self.client_g, node_feats, edge_feats, self.meta = load_partition(conf_file, server_id)
        _, _, node_map, edge_map, num_partitions = self.meta
        self.client_g = _copy_graph_to_shared_mem(self.client_g, graph_name)

        node_map = F.zerocopy_from_numpy(node_map)
        edge_map = F.zerocopy_from_numpy(edge_map)
        self.gpb = GraphPartitionBook(server_id, num_partitions, node_map, edge_map,
                                      self.client_g)
        self.gpb.shared_memory(graph_name)
        self.add_part_policy(PartitionPolicy('node', server_id, self.gpb))
        self.add_part_policy(PartitionPolicy('edge', server_id, self.gpb))

        if not self.is_backup_server():
            for name in node_feats:
                self.init_data(name=_get_ndata_name(name), policy_str='node',
                               data_tensor=node_feats[name])
            for name in edge_feats:
                self.init_data(name=_get_edata_name(name), policy_str='edge',
                               data_tensor=edge_feats[name])
        else:
            for name in node_feats:
                self.init_data(name=_get_ndata_name(name), policy_str='node')
            for name in edge_feats:
                self.init_data(name=_get_edata_name(name), policy_str='edge')
    def start(self):
        """ Start graph store server.
        """
        # start server
        server_state = ServerState(kv_store=self)
        start_server(server_id=0, ip_config=self.ip_config,
                     num_clients=self.num_clients, server_state=server_state)
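# A minimal sketch of launching a server with this class, mirroring the test
# code further below; the IP config path, graph name and partition JSON path
# are made-up examples, not fixed names:
#
#   server = DistGraphServer(server_id=0, ip_config='kv_ip_config.txt',
#                            num_clients=1, graph_name='mygraph',
#                            conf_file='/tmp/mygraph.json')
#   server.start()  # serves KVStore/RPC requests until shutdown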
def _default_init_data(shape, dtype):
    return F.zeros(shape, dtype, F.cpu())
class DistGraph:
    ''' The DistGraph client.
@@ -399,34 +314,23 @@ class DistGraph:
    Parameters
    ----------
    ip_config : str
        Path of the IP configuration file.
    graph_name : str
        The name of the graph. This name has to be the same as the one used in DistGraphServer.
    '''
    def __init__(self, ip_config, graph_name):
        connect_to_server(ip_config=ip_config)
        self._client = KVClient(ip_config)
        self._g = _get_graph_from_shared_mem(graph_name)
        self._gpb = get_shared_mem_partition_book(graph_name, self._g)
        self._client.barrier()
        self._client.map_shared_data(self._gpb)
        self._ndata = NodeDataView(self)
        self._edata = EdgeDataView(self)
        self._default_init_ndata = _default_init_data
        self._default_init_edata = _default_init_data

    def init_ndata(self, ndata_name, shape, dtype):
@@ -444,10 +348,8 @@ class DistGraph:
            The data type of the node data.
        '''
        assert shape[0] == self.number_of_nodes()
        self._client.init_data(_get_ndata_name(ndata_name), shape, dtype, 'node', self._gpb,
                               self._default_init_ndata)
        self._ndata._add(ndata_name)

    def init_edata(self, edata_name, shape, dtype):
@@ -465,10 +367,8 @@ class DistGraph:
            The data type of the edge data.
        '''
        assert shape[0] == self.number_of_edges()
        self._client.init_data(_get_edata_name(edata_name), shape, dtype, 'edge', self._gpb,
                               self._default_init_edata)
        self._edata._add(edata_name)

    def init_node_emb(self, name, shape, dtype, initializer):
@@ -525,11 +425,11 @@ class DistGraph:
    def number_of_nodes(self):
        """Return the number of nodes"""
        return self._gpb.num_nodes()

    def number_of_edges(self):
        """Return the number of edges"""
        return self._gpb.num_edges()

    def node_attr_schemes(self):
        """Return the node feature and embedding schemes."""
@@ -553,10 +453,7 @@ class DistGraph:
        int
            The rank of the current graph store.
        '''
        return self._client.client_id
    def get_partition_book(self):
        """Get the partition information.
@@ -568,20 +465,10 @@ class DistGraph:
        """
        return self._gpb
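    # A sketch of typical client-side usage, assuming the servers above are
    # already running with the same graph name and IP config file (the names
    # here are made-up examples):
    #
    #   g = DistGraph('kv_ip_config.txt', 'mygraph')
    #   print(g.number_of_nodes(), g.number_of_edges())
    #   # Create a new distributed node tensor, filled by _default_init_data.
    #   g.init_ndata('h', (g.number_of_nodes(), 16), F.float32)
    #   gpb = g.get_partition_book()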
    def _get_all_ndata_names(self):
        ''' Get the names of all node data.
        '''
        names = self._client.data_name_list()
        ndata_names = []
        for name in names:
            if _is_ndata_name(name):
@@ -592,7 +479,7 @@ class DistGraph:
    def _get_all_edata_names(self):
        ''' Get the names of all edge data.
        '''
        names = self._client.data_name_list()
        edata_names = []
        for name in names:
            if _is_edata_name(name):
......
@@ -5,6 +5,69 @@ import numpy as np
from .. import backend as F
from ..base import NID, EID
from .. import utils
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from .._ffi.ndarray import empty_shared_mem
def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
                                 num_partitions, node_map, edge_map):
    ''' Move all metadata to shared memory.

    We need this metadata to construct the graph partition book.
    '''
    meta = _to_shared_mem(F.tensor([num_nodes, num_edges, num_partitions, part_id]),
                          _get_ndata_path(graph_name, 'meta'))
    node_map = _to_shared_mem(node_map, _get_ndata_path(graph_name, 'node_map'))
    edge_map = _to_shared_mem(edge_map, _get_edata_path(graph_name, 'edge_map'))
    return meta, node_map, edge_map
def _get_shared_mem_metadata(graph_name):
    ''' Get the metadata of the graph through shared memory.

    The metadata includes the number of nodes and the number of edges. In the future,
    we can add more information, especially for heterographs.
    '''
    # The metadata is a 4-element tensor: num_nodes, num_edges, num_partitions, part_id.
    shape = (4,)
    dtype = F.int64
    dtype = DTYPE_DICT[dtype]
    data = empty_shared_mem(_get_ndata_path(graph_name, 'meta'), False, shape, dtype)
    dlpack = data.to_dlpack()
    meta = F.asnumpy(F.zerocopy_from_dlpack(dlpack))
    num_nodes, num_edges, num_partitions, part_id = meta[0], meta[1], meta[2], meta[3]
    # Load node map
    data = empty_shared_mem(_get_ndata_path(graph_name, 'node_map'), False, (num_nodes,), dtype)
    dlpack = data.to_dlpack()
    node_map = F.zerocopy_from_dlpack(dlpack)
    # Load edge map
    data = empty_shared_mem(_get_edata_path(graph_name, 'edge_map'), False, (num_edges,), dtype)
    dlpack = data.to_dlpack()
    edge_map = F.zerocopy_from_dlpack(dlpack)
    return part_id, num_partitions, node_map, edge_map
def get_shared_mem_partition_book(graph_name, graph_part):
    '''Get a graph partition book from shared memory.

    A graph partition book of a specific graph can be serialized to shared memory.
    We can reconstruct a graph partition book from shared memory.

    Parameters
    ----------
    graph_name : str
        The name of the graph.
    graph_part : DGLGraph
        The graph structure of a partition.

    Returns
    -------
    GraphPartitionBook
        A graph partition book for a particular partition.
    '''
    part_id, num_parts, node_map, edge_map = _get_shared_mem_metadata(graph_name)
    return GraphPartitionBook(part_id, num_parts, node_map, edge_map, graph_part)
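# A sketch of the shared-memory round trip between processes on one machine
# ('mygraph' is a made-up name):
#
#   # In the server process:
#   gpb = GraphPartitionBook(part_id, num_parts, node_map, edge_map, part_g)
#   gpb.shared_memory('mygraph')   # publishes meta, node_map and edge_map
#
#   # In a client process on the same machine:
#   gpb = get_shared_mem_partition_book('mygraph', part_g)
#   assert gpb.num_nodes() == len(node_map)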
class GraphPartitionBook:
    """GraphPartitionBook is used to store partition information.
@@ -78,6 +141,18 @@ class GraphPartitionBook:
        self._edge_size = len(self.partid2eids(part_id))
        self._node_size = len(self.partid2nids(part_id))
    def shared_memory(self, graph_name):
        """Move data to shared memory.

        Parameters
        ----------
        graph_name : str
            The graph name.
        """
        self._meta, self._nid2partid, self._eid2partid = _move_metadata_to_shared_mem(
            graph_name, self.num_nodes(), self.num_edges(), self._part_id,
            self._num_partitions, self._nid2partid, self._eid2partid)
    def num_partitions(self):
        """Return the number of partitions.
@@ -111,6 +186,16 @@ class GraphPartitionBook:
        """
        return self._partition_meta_data
    def num_nodes(self):
        """ The total number of nodes.
        """
        return len(self._nid2partid)

    def num_edges(self):
        """ The total number of edges.
        """
        return len(self._eid2partid)
    def nid2partid(self, nids):
        """From global node IDs to partition IDs
@@ -231,22 +316,22 @@ class GraphPartitionBook:
        return self._graph
    def get_node_size(self):
        """Get the number of nodes in the current partition.

        Return
        ------
        int
            The number of nodes in the current partition
        """
        return self._node_size

    def get_edge_size(self):
        """Get the number of edges in the current partition.

        Return
        ------
        int
            The number of edges in the current partition
        """
        return self._edge_size
......
@@ -375,7 +375,11 @@ class GetPartShapeResponse(rpc.Response):
        return self.shape

    def __setstate__(self, state):
        # When the shape has only one dimension, state is an integer.
        if isinstance(state, int):
            self.shape = (state,)
        else:
            self.shape = state
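# A hedged illustration of why the guard above is needed: when the shape has
# a single dimension, the serialized state can arrive as a bare int rather
# than a tuple, so __setstate__ normalizes it (constructor signature assumed):
#
#   resp = GetPartShapeResponse((6,))
#   resp.__setstate__(6)        # 1-D shape arrives as an int
#   assert resp.shape == (6,)
#   resp.__setstate__((6, 2))   # higher-D shapes arrive as tuples
#   assert resp.shape == (6, 2)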
class GetPartShapeRequest(rpc.Request):
    """Send data name to get the partitioned data shape from server.
......
"""Define utility functions for shared memory."""
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
from .. import ndarray as nd
DTYPE_DICT = F.data_type_dict
DTYPE_DICT = {DTYPE_DICT[key]:key for key in DTYPE_DICT}
def _get_ndata_path(graph_name, ndata_name):
return "/" + graph_name + "_node_" + ndata_name
def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name
def _to_shared_mem(arr, name):
dlpack = F.zerocopy_to_dlpack(arr)
dgl_tensor = nd.from_dlpack(dlpack)
new_arr = empty_shared_mem(name, True, F.shape(arr), DTYPE_DICT[F.dtype(arr)])
dgl_tensor.copyto(new_arr)
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
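# A minimal usage sketch for _to_shared_mem, assuming the PyTorch backend:
#
#   import torch
#   arr = torch.arange(10)
#   shared = _to_shared_mem(arr, _get_ndata_path('mygraph', 'feat'))
#   # `shared` has the same values, but its storage now lives in shared
#   # memory, so another process on the same machine can map it by opening
#   # the same name via empty_shared_mem(name, False, shape, dtype).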
@@ -16,15 +16,13 @@ import backend as F
import unittest
import pickle
def create_random_graph(n):
    arr = (spsp.random(n, n, density=0.001, format='coo') != 0).astype(np.int64)
    ig = create_graph_index(arr, readonly=True)
    return dgl.DGLGraph(ig)

def run_server(graph_name, server_id, num_clients, barrier):
    g = DistGraphServer(server_id, "kv_ip_config.txt", num_clients, graph_name,
                        '/tmp/{}.json'.format(graph_name))
    barrier.wait()
    print('start server', server_id)
@@ -32,7 +30,7 @@ def run_server(graph_name, server_id, num_clients, barrier):
def run_client(graph_name, barrier, num_nodes, num_edges):
    barrier.wait()
    g = DistGraph("kv_ip_config.txt", graph_name)

    # Test API
    assert g.number_of_nodes() == num_nodes
@@ -85,11 +83,14 @@ def run_client(graph_name, barrier, num_nodes, num_edges):
    for n in nodes:
        assert n in local_nids
    # clean up
    dgl.distributed.shutdown_servers()
    dgl.distributed.finalize_client()
    print('end')
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
def test_server_client():
    prepare_dist()
    g = create_random_graph(10000)

    # Partition the graph
@@ -121,6 +122,7 @@ def test_server_client():
    print('clients have terminated')
def test_split():
    prepare_dist()
    g = create_random_graph(10000)
    num_parts = 4
    num_hops = 2
@@ -156,6 +158,11 @@ def test_split():
    for e in edges1:
        assert e in local_eids
def prepare_dist():
    ip_config = open("kv_ip_config.txt", "w")
    ip_config.write('127.0.0.1 2500 1\n')
    ip_config.close()
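# Each line of the IP config file appears to be "<ip> <port> <num_servers>"
# (inferred from this test, not from a documented spec), so the line above
# declares one server at 127.0.0.1:2500.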
if __name__ == '__main__':
    test_split()
    test_server_client()
@@ -78,6 +78,9 @@ edge_policy = dgl.distributed.PartitionPolicy(policy_str='edge',
                                              partition_book=gpb)
data_0 = F.tensor([[1.,1.],[1.,1.],[1.,1.],[1.,1.],[1.,1.],[1.,1.]], F.float32)
data_0_1 = F.tensor([1.,2.,3.,4.,5.,6.], F.float32)
data_0_2 = F.tensor([1,2,3,4,5,6], F.int32)
data_0_3 = F.tensor([1,2,3,4,5,6], F.int64)
data_1 = F.tensor([[2.,2.],[2.,2.],[2.,2.],[2.,2.],[2.,2.],[2.,2.],[2.,2.]], F.float32)
data_2 = F.tensor([[0.,0.],[0.,0.],[0.,0.],[0.,0.],[0.,0.],[0.,0.]], F.float32)
@@ -112,6 +115,9 @@ def start_server():
    kvserver.add_part_policy(node_policy)
    kvserver.add_part_policy(edge_policy)
    kvserver.init_data('data_0', 'node', data_0)
    kvserver.init_data('data_0_1', 'node', data_0_1)
    kvserver.init_data('data_0_2', 'node', data_0_2)
    kvserver.init_data('data_0_3', 'node', data_0_3)
    # start server
    server_state = dgl.distributed.ServerState(kv_store=kvserver)
    dgl.distributed.start_server(server_id=0,
@@ -143,6 +149,9 @@ def start_client():
    name_list = kvclient.data_name_list()
    print(name_list)
    assert 'data_0' in name_list
    assert 'data_0_1' in name_list
    assert 'data_0_2' in name_list
    assert 'data_0_3' in name_list
    assert 'data_1' in name_list
    assert 'data_2' in name_list
    # Test get_meta_data
@@ -151,16 +160,37 @@ def start_client():
    assert dtype == F.dtype(data_0)
    assert shape == F.shape(data_0)
    assert policy.policy_str == 'node'

    meta = kvclient.get_data_meta('data_0_1')
    dtype, shape, policy = meta
    assert dtype == F.dtype(data_0_1)
    assert shape == F.shape(data_0_1)
    assert policy.policy_str == 'node'

    meta = kvclient.get_data_meta('data_0_2')
    dtype, shape, policy = meta
    assert dtype == F.dtype(data_0_2)
    assert shape == F.shape(data_0_2)
    assert policy.policy_str == 'node'

    meta = kvclient.get_data_meta('data_0_3')
    dtype, shape, policy = meta
    assert dtype == F.dtype(data_0_3)
    assert shape == F.shape(data_0_3)
    assert policy.policy_str == 'node'

    meta = kvclient.get_data_meta('data_1')
    dtype, shape, policy = meta
    assert dtype == F.dtype(data_1)
    assert shape == F.shape(data_1)
    assert policy.policy_str == 'edge'

    meta = kvclient.get_data_meta('data_2')
    dtype, shape, policy = meta
    assert dtype == F.dtype(data_2)
    assert shape == F.shape(data_2)
    assert policy.policy_str == 'node'

    # Test push and pull
    id_tensor = F.tensor([0,2,4], F.int64)
    data_tensor = F.tensor([[6.,6.],[6.,6.],[6.,6.]], F.float32)
@@ -217,4 +247,4 @@ def test_kv_store():
if __name__ == '__main__':
    test_partition_policy()
    test_kv_store()