Unverified Commit ff9f67ee authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

Send any shape of tensor rather than matrix (#942)

parent ae3102d3
...@@ -19,6 +19,7 @@ def start_client(args): ...@@ -19,6 +19,7 @@ def start_client(args):
# Initialize data on server # Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')
tensor_id = mx.nd.array([0, 1, 2], dtype='int64') tensor_id = mx.nd.array([0, 1, 2], dtype='int64')
tensor_data = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) tensor_data = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
...@@ -26,11 +27,13 @@ def start_client(args): ...@@ -26,11 +27,13 @@ def start_client(args):
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, mx.nd.array([2., 2., 2.]))
tensor_id = mx.nd.array([6, 7, 8], dtype='int64') tensor_id = mx.nd.array([6, 7, 8], dtype='int64')
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, mx.nd.array([3., 3., 3.]))
client.barrier() client.barrier()
...@@ -39,16 +42,21 @@ def start_client(args): ...@@ -39,16 +42,21 @@ def start_client(args):
new_tensor_0 = client.pull('embed_0', tensor_id) new_tensor_0 = client.pull('embed_0', tensor_id)
tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9,10], dtype='int64') tensor_id = mx.nd.array([0,1,2,3,4,5,6,7,8,9,10], dtype='int64')
new_tensor_1 = client.pull('embed_1', tensor_id) new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)
client.push_all('embed_0', new_tensor_0) client.push_all('embed_0', new_tensor_0)
client.push_all('embed_1', new_tensor_1) client.push_all('embed_1', new_tensor_1)
client.push_all('embed_2', new_tensor_2)
new_tensor_2 = client.pull_all('embed_0') new_tensor_3 = client.pull_all('embed_0')
new_tensor_3 = client.pull_all('embed_1') new_tensor_4 = client.pull_all('embed_1')
new_tensor_5 = client.pull_all('embed_2')
print("embed_0: ") print("embed_0: ")
print(new_tensor_2)
print("embed_1: ")
print(new_tensor_3) print(new_tensor_3)
print("embed_1: ")
print(new_tensor_4)
print("embed_2: ")
print(new_tensor_5)
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
...@@ -4,6 +4,7 @@ import dgl ...@@ -4,6 +4,7 @@ import dgl
import torch import torch
import time import time
import argparse import argparse
import torch as th
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt')
...@@ -19,6 +20,7 @@ def start_client(args): ...@@ -19,6 +20,7 @@ def start_client(args):
# Initialize data on server # Initialize data on server
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')
tensor_id = torch.tensor([0, 1, 2]) tensor_id = torch.tensor([0, 1, 2])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
...@@ -26,11 +28,13 @@ def start_client(args): ...@@ -26,11 +28,13 @@ def start_client(args):
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, th.tensor([2., 2., 2.]))
tensor_id = torch.tensor([6, 7, 8]) tensor_id = torch.tensor([6, 7, 8])
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, th.tensor([3., 3., 3.]))
client.barrier() client.barrier()
...@@ -39,16 +43,21 @@ def start_client(args): ...@@ -39,16 +43,21 @@ def start_client(args):
new_tensor_0 = client.pull('embed_0', tensor_id) new_tensor_0 = client.pull('embed_0', tensor_id)
tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9,10]) tensor_id = torch.tensor([0,1,2,3,4,5,6,7,8,9,10])
new_tensor_1 = client.pull('embed_1', tensor_id) new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)
client.push_all('embed_0', new_tensor_0) client.push_all('embed_0', new_tensor_0)
client.push_all('embed_1', new_tensor_1) client.push_all('embed_1', new_tensor_1)
client.push_all('embed_2', new_tensor_2)
new_tensor_2 = client.pull_all('embed_0') new_tensor_3 = client.pull_all('embed_0')
new_tensor_3 = client.pull_all('embed_1') new_tensor_4 = client.pull_all('embed_1')
new_tensor_5 = client.pull_all('embed_2')
print("embed_0:") print("embed_0:")
print(new_tensor_2)
print("embed_1:")
print(new_tensor_3) print(new_tensor_3)
print("embed_1:")
print(new_tensor_4)
print("embed_2:")
print(new_tensor_5)
# Shut-down all the servers # Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
......
...@@ -13,6 +13,16 @@ import numpy as np ...@@ -13,6 +13,16 @@ import numpy as np
def ReadNetworkConfigure(filename): def ReadNetworkConfigure(filename):
"""Read networking configuration from file. """Read networking configuration from file.
The config file is like:
server 172.31.40.143:50050 0
client 172.31.40.143:50051 0
client 172.31.36.140:50051 1
client 172.31.47.147:50051 2
client 172.31.30.180:50051 3
Here we have 1 server node and 4 client nodes.
Parameters Parameters
---------- ----------
filename : str filename : str
...@@ -251,14 +261,16 @@ class KVServer(object): ...@@ -251,14 +261,16 @@ class KVServer(object):
class KVClient(object): class KVClient(object):
"""KVClient is used to push/pull tensors to/from KVServer on DGL trainer. """KVClient is used to push/pull tensors to/from KVServer on DGL trainer.
There are three operations supported by KVClient: There are five operations supported by KVClient:
* init_data(name, shape, low, high): initialize tensor on KVServer * init_data(name, shape, init_type, low, high): initialize tensor on KVServer
* push(name, id, data): push data to KVServer * push(name, id, data): push sparse data to KVServer given specified IDs
* pull(name, id): pull data from KVServer * pull(name, id): pull sparse data from KVServer given specified IDs
* push_all(name, data): push dense data to KVServer
* pull_all(name): pull sense data from KVServer
* shut_down(): shut down all KVServer nodes * shut_down(): shut down all KVServer nodes
DO NOT use KVClient in multiple threads! Note that, DO NOT use KVClient in multiple threads!
Parameters Parameters
---------- ----------
...@@ -277,9 +289,9 @@ class KVClient(object): ...@@ -277,9 +289,9 @@ class KVClient(object):
networking type, e.g., 'socket' (default) or 'mpi'. networking type, e.g., 'socket' (default) or 'mpi'.
""" """
def __init__(self, client_id, server_namebook, client_addr, net_type='socket'): def __init__(self, client_id, server_namebook, client_addr, net_type='socket'):
assert client_id >= 0, 'client_id cannot be a nagative number.' assert client_id >= 0, 'client_id (%d) cannot be a nagative number.' % client_id
assert len(server_namebook) > 0, 'server_namebook cannot be empty.' assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
assert len(client_addr.split(':')) == 2, 'Incorrect IP format.' assert len(client_addr.split(':')) == 2, 'Incorrect IP format: %s' % client_addr
# self._data_size is a key-value store where the key is data name # self._data_size is a key-value store where the key is data name
# and value is the size of tensor. It is used to partition data into # and value is the size of tensor. It is used to partition data into
# different KVServer nodes. # different KVServer nodes.
......
...@@ -68,7 +68,7 @@ char* ArrayMeta::Serialize(int64_t* size) { ...@@ -68,7 +68,7 @@ char* ArrayMeta::Serialize(int64_t* size) {
buffer_size += sizeof(data_shape_.size()); buffer_size += sizeof(data_shape_.size());
buffer_size += sizeof(int64_t) * data_shape_.size(); buffer_size += sizeof(int64_t) * data_shape_.size();
} }
// In the future, we should have a better memory management. // In the future, we should have a better memory management as
// allocating a large chunk of memory can be very expensive. // allocating a large chunk of memory can be very expensive.
buffer = new char[buffer_size]; buffer = new char[buffer_size];
char* pointer = buffer; char* pointer = buffer;
...@@ -124,7 +124,7 @@ char* KVStoreMsg::Serialize(int64_t* size) { ...@@ -124,7 +124,7 @@ char* KVStoreMsg::Serialize(int64_t* size) {
buffer_size += sizeof(this->name.size()); buffer_size += sizeof(this->name.size());
buffer_size += this->name.size(); buffer_size += this->name.size();
} }
// In the future, we should have a better memory management. // In the future, we should have a better memory management as
// allocating a large chunk of memory can be very expensive. // allocating a large chunk of memory can be very expensive.
buffer = new char[buffer_size]; buffer = new char[buffer_size];
char* pointer = buffer; char* pointer = buffer;
...@@ -532,9 +532,13 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg") ...@@ -532,9 +532,13 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
if (kv_msg->msg_type != kPullMsg) { if (kv_msg->msg_type != kPullMsg) {
Message recv_data_msg; Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[2], 2); CHECK_GE(meta.data_shape_[2], 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
}
kv_msg->data = CreateNDArrayFromRaw( kv_msg->data = CreateNDArrayFromRaw(
{meta.data_shape_[3], meta.data_shape_[4]}, vec_shape,
DLDataType{kDLFloat, 32, 1}, DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
recv_data_msg.data); recv_data_msg.data);
......
...@@ -30,6 +30,7 @@ def start_client(): ...@@ -30,6 +30,7 @@ def start_client():
client.init_data(name='embed_0', shape=[10, 3], init_type='zero') client.init_data(name='embed_0', shape=[10, 3], init_type='zero')
client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0) client.init_data(name='embed_1', shape=[11, 3], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_2', shape=[11], init_type='zero')
tensor_id = torch.tensor([0, 1, 2]) tensor_id = torch.tensor([0, 1, 2])
tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) tensor_data = torch.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]])
...@@ -38,16 +39,19 @@ def start_client(): ...@@ -38,16 +39,19 @@ def start_client():
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, torch.tensor([2., 2., 2.]))
tensor_id = torch.tensor([6, 7, 8]) tensor_id = torch.tensor([6, 7, 8])
for i in range(5): for i in range(5):
client.push('embed_0', tensor_id, tensor_data) client.push('embed_0', tensor_id, tensor_data)
client.push('embed_1', tensor_id, tensor_data) client.push('embed_1', tensor_id, tensor_data)
client.push('embed_2', tensor_id, torch.tensor([3., 3., 3.]))
# Pull # Pull
tensor_id = torch.tensor([0, 1, 2, 6, 7, 8]) tensor_id = torch.tensor([0, 1, 2, 6, 7, 8])
new_tensor_0 = client.pull('embed_0', tensor_id) new_tensor_0 = client.pull('embed_0', tensor_id)
new_tensor_1 = client.pull('embed_1', tensor_id) new_tensor_1 = client.pull('embed_1', tensor_id)
new_tensor_2 = client.pull('embed_2', tensor_id)
target_tensor = torch.tensor( target_tensor = torch.tensor(
[[ 0., 0., 0.], [[ 0., 0., 0.],
...@@ -60,6 +64,10 @@ def start_client(): ...@@ -60,6 +64,10 @@ def start_client():
assert torch.equal(new_tensor_0, target_tensor) == True assert torch.equal(new_tensor_0, target_tensor) == True
assert torch.equal(new_tensor_1, target_tensor) == True assert torch.equal(new_tensor_1, target_tensor) == True
target_tensor = tensor.tensor([10., 10., 10., 15., 15., 15.])
assert torch.equal(new_tensor_2, target_tensor) == True
client.shut_down() client.shut_down()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment