Unverified Commit 7c7b60be authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] add count_nonzero() into SA_Client (#3417)

parent 2d88db5a
......@@ -87,3 +87,18 @@ class KVClient(object):
def map_shared_data(self, partition_book):
'''Mapping shared-memory tensor from server to client.'''
def count_nonzero(self, name):
"""Count nonzero value by pull request from KVServers.
Parameters
----------
name : str
data name
Returns
-------
int
the number of nonzero in this data.
"""
return F.count_nonzero(self._data[name])
......@@ -246,9 +246,11 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
# Test init node data
new_shape = (g.number_of_nodes(), 2)
g.ndata['test1'] = dgl.distributed.DistTensor(new_shape, F.int32)
test1 = dgl.distributed.DistTensor(new_shape, F.int32)
g.ndata['test1'] = test1
feats = g.ndata['test1'][nids]
assert np.all(F.asnumpy(feats) == 0)
assert test1.count_nonzero() == 0
# reference to a one that exists
test2 = dgl.distributed.DistTensor(new_shape, F.float32, 'test2', init_func=rand_init)
......@@ -618,10 +620,7 @@ def test_standalone():
dgl.distributed.initialize("kv_ip_config.txt")
dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
try:
check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
except Exception as e:
print(e)
check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support distributed DistEmbedding")
......@@ -639,10 +638,7 @@ def test_standalone_node_emb():
dgl.distributed.initialize("kv_ip_config.txt")
dist_g = DistGraph(graph_name, part_config='/tmp/dist_graph/{}.json'.format(graph_name))
try:
check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
except Exception as e:
print(e)
check_dist_emb(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
dgl.distributed.exit_client() # this is needed since there's two test here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment