Unverified commit e4ef8d1a authored by Chao Ma, committed by GitHub

[KVStore] API change of kvstore (#1058)

* API change of kvstore

* add demo for kvstore

* update

* remove duplicated log

* change queue size

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* change name

* update

* fix lint

* update

* update

* update

* update

* change message queue size to a python argument
parent 41f8a162
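The gist of the change: the old flow constructed `KVServer`/`KVClient` by hand from a `config.txt` namebook and initialized tensors server by server; the new flow reads a server-only `ip_config.txt` and boots everything through two helpers, `dgl.contrib.start_server()` and `dgl.contrib.start_client()`, driven by per-tensor partition books. A minimal sketch of the new API, condensed from the demos in this commit (the tensor names, shapes, and values are the demos', not requirements of the API; server and client run as separate processes):

```python
import dgl
import torch as th

# server process k (one per line of ip_config.txt); blocks until shut_down()
dgl.contrib.start_server(
    server_id=0,
    ip_config='ip_config.txt',
    num_client=4,
    ndata={'ndata': th.zeros(2, 3)},   # this server's local shard
    edata={'edata': th.zeros(2, 3)},
    ndata_g2l={'ndata': th.tensor([0, 1, 0, 0, 0, 0, 0, 0])},  # global->local rows
    edata_g2l={'edata': th.tensor([0, 1, 0, 0, 0, 0, 0, 0])})

# client process: the partition book maps each global ID to its server
client = dgl.contrib.start_client(
    ip_config='ip_config.txt',
    ndata_partition_book={'ndata': th.tensor([0, 0, 1, 1, 2, 2, 3, 3])},
    edata_partition_book={'edata': th.tensor([0, 0, 1, 1, 2, 2, 3, 3])})
client.push(name='ndata', id_tensor=th.tensor([0, 1]),
            data_tensor=th.ones(2, 3))
print(client.pull(name='ndata', id_tensor=th.tensor([0, 1, 2, 3, 4, 5, 6, 7])))
client.barrier()
if client.get_id() == 0:
    client.shut_down()
```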
## Usage of DGL distributed KVStore

This is a simple example that shows how to use the DGL distributed KVStore with MXNet locally.
In this example, we start four servers and four clients, and you can run the example with:

```
./run.sh
```
# This is a simple MXNet client demo that shows how to use the DGL distributed kvstore.
# In this demo, four clients push/pull node data and edge data to/from four servers.
import dgl
import time
import mxnet as mx

# Each client owns two consecutive global IDs of the 8-row tensors.
ID = []
ID.append(mx.nd.array([0,1], dtype='int64'))
ID.append(mx.nd.array([2,3], dtype='int64'))
ID.append(mx.nd.array([4,5], dtype='int64'))
ID.append(mx.nd.array([6,7], dtype='int64'))

# The partition book maps each global ID to the server that stores it.
edata_partition_book = {'edata': mx.nd.array([0,0,1,1,2,2,3,3], dtype='int64')}
ndata_partition_book = {'ndata': mx.nd.array([0,0,1,1,2,2,3,3], dtype='int64')}

def start_client():
    # Connect to all servers listed in ip_config.txt; the servers assign the client ID.
    client = dgl.contrib.start_client(ip_config='ip_config.txt',
                                      ndata_partition_book=ndata_partition_book,
                                      edata_partition_book=edata_partition_book)
    # Each client writes its own ID range.
    client.push(name='edata', id_tensor=ID[client.get_id()], data_tensor=mx.nd.array([[1.,1.,1.],[1.,1.,1.]]))
    client.push(name='ndata', id_tensor=ID[client.get_id()], data_tensor=mx.nd.array([[2.,2.,2.],[2.,2.,2.]]))
    client.barrier()
    # Pull the full tensors back, across all servers.
    tensor_edata = client.pull(name='edata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
    tensor_ndata = client.pull(name='ndata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
    print(tensor_edata)
    client.barrier()
    print(tensor_ndata)
    client.barrier()
    # Shut down all the servers from a single client.
    if client.get_id() == 0:
        client.shut_down()

if __name__ == '__main__':
    time.sleep(2)  # wait for the servers to start
    start_client()
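With all four clients pushing before the barrier, the pulled `edata` should come back as an 8x3 tensor of ones and `ndata` as an 8x3 tensor of twos; this follows from the default SET-style push handler introduced in this commit rather than from output recorded in the diff.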
ip_config.txt:

0 127.0.0.1 50050
1 127.0.0.1 50051
2 127.0.0.1 50052
3 127.0.0.1 50053
run.sh:

DGLBACKEND=mxnet python3 server.py --id 0 &
DGLBACKEND=mxnet python3 server.py --id 1 &
DGLBACKEND=mxnet python3 server.py --id 2 &
DGLBACKEND=mxnet python3 server.py --id 3 &
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py
# This is a simple MXNet server demo that shows how to use the DGL distributed kvstore.
# In this demo, each server holds one partition of the node data and edge data.
import dgl
import argparse
import mxnet as mx

# Global-to-local mappings: indexed by global ID, the value is the local row
# on the owning server. Server k owns global IDs 2k and 2k+1.
ndata_g2l = []
edata_g2l = []
ndata_g2l.append({'ndata': mx.nd.array([0,1,0,0,0,0,0,0], dtype='int64')})
ndata_g2l.append({'ndata': mx.nd.array([0,0,0,1,0,0,0,0], dtype='int64')})
ndata_g2l.append({'ndata': mx.nd.array([0,0,0,0,0,1,0,0], dtype='int64')})
ndata_g2l.append({'ndata': mx.nd.array([0,0,0,0,0,0,0,1], dtype='int64')})
edata_g2l.append({'edata': mx.nd.array([0,1,0,0,0,0,0,0], dtype='int64')})
edata_g2l.append({'edata': mx.nd.array([0,0,0,1,0,0,0,0], dtype='int64')})
edata_g2l.append({'edata': mx.nd.array([0,0,0,0,0,1,0,0], dtype='int64')})
edata_g2l.append({'edata': mx.nd.array([0,0,0,0,0,0,0,1], dtype='int64')})

def start_server(args):
    # start_server() blocks and serves push/pull requests until shut_down().
    dgl.contrib.start_server(
        server_id=args.id,
        ip_config='ip_config.txt',
        num_client=4,
        ndata={'ndata': mx.nd.array([[0.,0.,0.],[0.,0.,0.]])},
        edata={'edata': mx.nd.array([[0.,0.,0.],[0.,0.,0.]])},
        ndata_g2l=ndata_g2l[args.id],
        edata_g2l=edata_g2l[args.id])

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='kvstore')
    parser.add_argument("--id", type=int, default=0, help="node ID")
    args = parser.parse_args()
    start_server(args)
## Usage of DGL distributed KVStore

This is a simple example that shows how to use the DGL distributed KVStore with PyTorch locally.
In this example, we start four servers and four clients, and you can run the example with:

```
./run.sh
```
# This is a simple pytorch client demo that shows how to use the DGL distributed kvstore.
# In this demo, four clients push/pull node data and edge data to/from four servers.
import dgl
import time
import torch as th

# Each client owns two consecutive global IDs of the 8-row tensors.
ID = []
ID.append(th.tensor([0,1]))
ID.append(th.tensor([2,3]))
ID.append(th.tensor([4,5]))
ID.append(th.tensor([6,7]))

# The partition book maps each global ID to the server that stores it.
edata_partition_book = {'edata': th.tensor([0,0,1,1,2,2,3,3])}
ndata_partition_book = {'ndata': th.tensor([0,0,1,1,2,2,3,3])}

def start_client():
    time.sleep(3)  # wait for the servers to start
    # Connect to all servers listed in ip_config.txt; the servers assign the client ID.
    client = dgl.contrib.start_client(ip_config='ip_config.txt',
                                      ndata_partition_book=ndata_partition_book,
                                      edata_partition_book=edata_partition_book)
    # Each client writes its own ID range.
    client.push(name='edata', id_tensor=ID[client.get_id()], data_tensor=th.tensor([[1.,1.,1.],[1.,1.,1.]]))
    client.push(name='ndata', id_tensor=ID[client.get_id()], data_tensor=th.tensor([[2.,2.,2.],[2.,2.,2.]]))
    client.barrier()
    # Pull the full tensors back, across all servers.
    tensor_edata = client.pull(name='edata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
    tensor_ndata = client.pull(name='ndata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
    print(tensor_edata)
    client.barrier()
    print(tensor_ndata)
    client.barrier()
    # Shut down all the servers from a single client.
    if client.get_id() == 0:
        client.shut_down()

if __name__ == '__main__':
    start_client()
ip_config.txt:

0 127.0.0.1 40050
1 127.0.0.1 40051
2 127.0.0.1 40052
3 127.0.0.1 40053
run.sh:

python3 server.py --id 0 &
python3 server.py --id 1 &
python3 server.py --id 2 &
python3 server.py --id 3 &
python3 client.py &
python3 client.py &
python3 client.py &
python3 client.py
# This is a simple pytorch server demo that shows how to use the DGL distributed kvstore.
# In this demo, each server holds one partition of the node data and edge data.
import dgl
import argparse
import torch as th

# Global-to-local mappings: indexed by global ID, the value is the local row
# on the owning server. Server k owns global IDs 2k and 2k+1.
ndata_g2l = []
edata_g2l = []
ndata_g2l.append({'ndata': th.tensor([0,1,0,0,0,0,0,0])})
ndata_g2l.append({'ndata': th.tensor([0,0,0,1,0,0,0,0])})
ndata_g2l.append({'ndata': th.tensor([0,0,0,0,0,1,0,0])})
ndata_g2l.append({'ndata': th.tensor([0,0,0,0,0,0,0,1])})
edata_g2l.append({'edata': th.tensor([0,1,0,0,0,0,0,0])})
edata_g2l.append({'edata': th.tensor([0,0,0,1,0,0,0,0])})
edata_g2l.append({'edata': th.tensor([0,0,0,0,0,1,0,0])})
edata_g2l.append({'edata': th.tensor([0,0,0,0,0,0,0,1])})

def start_server(args):
    # start_server() blocks and serves push/pull requests until shut_down().
    dgl.contrib.start_server(
        server_id=args.id,
        ip_config='ip_config.txt',
        num_client=4,
        ndata={'ndata': th.tensor([[0.,0.,0.],[0.,0.,0.]])},
        edata={'edata': th.tensor([[0.,0.,0.],[0.,0.,0.]])},
        ndata_g2l=ndata_g2l[args.id],
        edata_g2l=edata_g2l[args.id])

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='kvstore')
    parser.add_argument("--id", type=int, default=0, help="node ID")
    args = parser.parse_args()
    start_server(args)
from . import sampling
from . import graph_store
from .dis_kvstore import KVClient, KVServer
from .dis_kvstore import read_ip_config
from .dis_kvstore import start_server, start_client
@@ -6,141 +6,346 @@
from ..network import _receiver_wait, _sender_connect
from ..network import _send_kv_msg, _recv_kv_msg
from ..network import KVMsgType, KVStoreMsg
import math
from .._ffi.ndarray import empty_shared_mem
import numpy as np
import dgl.backend as F
import socket
def read_ip_config(filename):
    """Read networking configuration from file.

    Format (one server per line):

    [server_id] [ip] [port]

    0 172.31.40.143 50050
    1 172.31.36.140 50050
    2 172.31.47.147 50050
    3 172.31.30.180 50050

    Parameters
    ----------
    filename : str
        name of target configure file.

    Returns
    -------
    dict
        server namebook, e.g.,

        {0:'172.31.40.143:50050',
         1:'172.31.36.140:50050',
         2:'172.31.47.147:50050',
         3:'172.31.30.180:50050'}
    """
    assert len(filename) > 0, 'filename cannot be empty.'
    server_namebook = {}
    try:
        lines = [line.rstrip('\n') for line in open(filename)]
        for line in lines:
            ID, ip, port = line.split(' ')
            server_namebook[int(ID)] = ip + ':' + port
    except:
        print("Incorrect format of IP configure file; each line should be: [server_id] [ip] [port]")
    return server_namebook
def start_server(server_id, ip_config, num_client, ndata, edata, ndata_g2l=None, edata_g2l=None, msg_queue_size=2*1024*1024*1024):
"""Start a kvserver node.

    This function blocks, entering the service loop via server.start().
Parameters
----------
server_id : int
ID of current server node (start from 0)
ip_config : str
Filename of server IP configure file.
num_client : int
Total number of client nodes
ndata : dict of tensor (mx.ndarray or torch.tensor)
node data
edata : dict of tensor (mx.ndarray or torch.tensor)
edge data
ndata_g2l : dict of tensor (mx.ndarray or torch.tensor)
global2local mapping of node data
edata_g2l : dict of tensor (mx.ndarray or torch.tensor)
global2local mapping of edge data
msg_queue_size : int
Size of message queue
"""
assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
assert len(ip_config) > 0, 'ip_config cannot be empty.'
    assert num_client > 0, 'num_client (%d) must be a positive number.' % num_client
server_namebook = read_ip_config(ip_config)
server = KVServer(
server_id=server_id,
server_addr=server_namebook[server_id],
num_client=num_client,
msg_queue_size=msg_queue_size)
for name, data in ndata.items():
server.init_data(name=name, data_tensor=data)
for name, data in edata.items():
server.init_data(name=name, data_tensor=data)
if ndata_g2l is not None:
for name, data in ndata_g2l.items():
server.set_global2local(name=name, global2local=data)
if edata_g2l is not None:
for name, data in edata_g2l.items():
server.set_global2local(name=name, global2local=data)
print("start server %d on %s" % (server.get_id(), server.get_addr()))
server.start()
def start_client(ip_config, ndata_partition_book, edata_partition_book, close_shared_mem=False, msg_queue_size=2*1024*1024*1024):
"""Start a kvclient node.
Parameters
----------
ip_config : str
Filename of server IP configure file.
ndata_partition_book : dict of tensor (mx.ndarray or torch.tensor)
Data mapping of node ID to server ID
edata_partition_book : dict of tensor (mx.ndarray or torch.tensor)
Data mapping of edge ID to server ID
close_shared_mem : bool
Close local shared-memory tensor access.
msg_queue_size : int
Size of message queue
Returns
-------
KVClient
client handle
"""
assert len(ip_config) > 0, 'ip_config cannot be empty.'
assert len(ndata_partition_book) > 0, 'ndata_partition_book cannot be empty.'
assert len(edata_partition_book) > 0, 'edata_partition_book cannot be empty.'
server_namebook = read_ip_config(ip_config)
client = KVClient(server_namebook=server_namebook, close_shared_mem=close_shared_mem, msg_queue_size=msg_queue_size)
for name, data in ndata_partition_book.items():
client.set_partition_book(name=name, partition_book=data)
for name, data in edata_partition_book.items():
client.set_partition_book(name=name, partition_book=data)
client.connect()
print("Client %d (%s) connected to kvstore ..." % (client.get_id(), client.get_addr()))
return client
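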
class KVServer(object):
    """KVServer is a lightweight key-value store service for DGL distributed training.

    In practice, developers can use KVServer to hold large graph features or graph embeddings
    across machines in a distributed setting. Users can re-write _push_handler and _pull_handler
    to support flexible algorithms.

    Note that DO NOT use KVServer in multiple threads in Python because this behavior is not defined.

    For now, KVServer can only run on CPU; we will support GPU KVServer in the future.

    Parameters
    ----------
    server_id : int
        ID of current kvserver node (start from 0).
    server_addr : str
        IP address and port of current KVServer node, e.g., '127.0.0.1:50051'.
    num_client : int
        Total number of clients connecting to the server.
    msg_queue_size : int
        Size of message queue.
    net_type : str
        networking type, e.g., 'socket' (default) or 'mpi' (not supported yet).
    """
    def __init__(self, server_id, server_addr, num_client, msg_queue_size=2 * 1024 * 1024 * 1024, net_type='socket'):
        assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
        assert len(server_addr.split(':')) == 2, 'Incorrect IP format: %s' % server_addr
        assert num_client >= 0, 'num_client (%d) cannot be a negative number.' % num_client
        assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type
        # check if target data has been initialized
        self._has_data = set()
        # Store the tensor data with data name
        self._data_store = {}
        # Used for barrier() API on KVClient
        self._barrier_count = 0
        # Server ID starts from zero
        self._server_id = server_id
        self._addr = server_addr
        # client_namebook will be received from the client nodes
        self._client_namebook = {}
        self._client_count = num_client
        # Create C communicators for sender and receiver
        self._sender = _create_sender(net_type, msg_queue_size)
        self._receiver = _create_receiver(net_type, msg_queue_size)
def __del__(self):
"""Finalize KVServer
"""
# Finalize C communicator of sender and receiver
_finalize_sender(self._sender)
_finalize_receiver(self._receiver)
def set_global2local(self, name, global2local):
"""Set a data mapping of global ID to local ID.
Parameters
----------
name : str
data name
global2local : list or tensor (mx.ndarray or torch.tensor)
            A data mapping of global ID to local ID. KVStore will use the global ID
            directly if global2local has not been set.
"""
assert len(name) > 0, 'name cannot be empty.'
assert len(global2local) > 0, 'global2local cannot be empty.'
if isinstance(global2local, list):
global2local = F.tensor(global2local)
shared_data = empty_shared_mem(name+'-g2l-'+str(self._server_id), True, global2local.shape, 'int64')
dlpack = shared_data.to_dlpack()
self._data_store[name+'-g2l-'] = F.zerocopy_from_dlpack(dlpack)
self._data_store[name+'-g2l-'][:] = global2local[:]
self._has_data.add(name+'-g2l-')
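    # Example (values borrowed from the demos in this commit): on server 0,
    # set_global2local(name='ndata', global2local=[0,1,0,0,0,0,0,0]) means
    # global IDs 0 and 1 map to local rows 0 and 1; the remaining entries are
    # placeholders for IDs the partition book never routes to this server.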
    def init_data(self, name, data_tensor):
        """Initialize data on KVServer with data name.

        Parameters
        ----------
        name : str
            data name
        data_tensor : tensor (mx.ndarray or torch.tensor)
            data tensor
        """
        assert len(name) > 0, 'name cannot be empty.'
        assert len(data_tensor) > 0, 'data_tensor cannot be empty.'
        shared_data = empty_shared_mem(name+'-data-'+str(self._server_id), True, data_tensor.shape, 'float32')
        dlpack = shared_data.to_dlpack()
        self._data_store[name+'-data-'] = F.zerocopy_from_dlpack(dlpack)
        self._data_store[name+'-data-'][:] = data_tensor[:]
        self._has_data.add(name+'-data-')
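    # Note: init_data() backs the tensor with a named shared-memory segment,
    # e.g. init_data(name='ndata', ...) on server 0 creates 'ndata-data-0',
    # which clients on the same machine can map directly instead of using TCP/IP.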
def get_id(self):
"""Get current server id.
Return
------
int
KVServer ID
"""
return self._server_id
def get_addr(self):
"""Get current server IP address
Return
------
str
IP address
"""
return self._addr
    def start(self):
        """Start the service loop of KVServer.
        """
        # Wait for connections from all the client nodes
        server_ip, server_port = self._addr.split(':')
        _receiver_wait(self._receiver, server_ip, int(server_port), self._client_count)
        # recv client addresses; every server builds the same (sorted) namebook
        addr_list = []
        for i in range(self._client_count):
            msg = _recv_kv_msg(self._receiver)
            assert msg.type == KVMsgType.IP_ID
            addr_list.append(msg.name)
        self._sort_addr(addr_list)
        for ID in range(len(addr_list)):
            self._client_namebook[ID] = addr_list[ID]
        _network_wait()  # wait for the clients to start their receivers
        for ID, addr in self._client_namebook.items():
            client_ip, client_port = addr.split(':')
            _add_receiver_addr(self._sender, client_ip, int(client_port), ID)
        _sender_connect(self._sender)
        if self._server_id == 0:
            # only server 0 assigns IDs to the client nodes
            for client_id, addr in self._client_namebook.items():
                msg = KVStoreMsg(
                    type=KVMsgType.IP_ID,
                    rank=self._server_id,
                    name=str(client_id),
                    id=None,
                    data=None)
                _send_kv_msg(self._sender, msg, client_id)
            # send serialized shared-memory tensor information to the clients
            shared_tensor = ''
            for name in self._has_data:
                shared_tensor += self._serialize_shared_tensor(
                    name,
                    F.shape(self._data_store[name]),
                    F.dtype(self._data_store[name]))
                shared_tensor += '|'
            msg = KVStoreMsg(
                type=KVMsgType.IP_ID,
                rank=self._server_id,
                name=shared_tensor,
                id=None,
                data=None)
            for client_id in range(len(self._client_namebook)):
                _send_kv_msg(self._sender, msg, client_id)
        # Service loop
        while True:
            msg = _recv_kv_msg(self._receiver)
            # PUSH message
            if msg.type == KVMsgType.PUSH:
                if (msg.name+'-g2l-' in self._has_data) == True:
                    local_id = self._data_store[msg.name+'-g2l-'][msg.id]
                else:
                    local_id = msg.id
                self._push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
            # PULL message
            elif msg.type == KVMsgType.PULL:
                if (msg.name+'-g2l-' in self._has_data) == True:
                    local_id = self._data_store[msg.name+'-g2l-'][msg.id]
                else:
                    local_id = msg.id
                res_tensor = self._pull_handler(msg.name+'-data-', local_id, self._data_store)
                back_msg = KVStoreMsg(
                    type=KVMsgType.PULL_BACK,
                    rank=self._server_id,
                    name=msg.name,
                    id=msg.id,
                    data=res_tensor)
                _send_kv_msg(self._sender, back_msg, msg.rank)
            # Barrier message
            elif msg.type == KVMsgType.BARRIER:
                self._barrier_count += 1
                if self._barrier_count == self._client_count:
                    back_msg = KVStoreMsg(
                        type=KVMsgType.BARRIER,
                        rank=self._server_id,
                        name=None,
                        id=None,
                        data=None)
                    for i in range(self._client_count):
                        _send_kv_msg(self._sender, back_msg, i)
                    self._barrier_count = 0
            # FINAL message
            elif msg.type == KVMsgType.FINAL:
                print("Exit KVStore service, server ID: %d" % self._server_id)
                break  # exit loop
            else:
                raise RuntimeError('Unknown type of kvstore message: %d' % msg.type.value)
    def _push_handler(self, name, ID, data, target):
        """Default handler for PUSH message.

        By default, _push_handler performs a SET operation on the target tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the ID list.
        data : tensor (mx.ndarray or torch.tensor)
            a tensor with the same row size as ID
        target : dict of data
            self._data_store
        """
        target[name][ID] = data

    def _pull_handler(self, name, ID, target):
        """Default handler for PULL operation.

        By default, _pull_handler performs an index-select operation on the tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the ID list.
        target : dict of data
            self._data_store

        Return
        ------
        tensor
            a tensor with the same row size as ID.
        """
        return target[name][ID]
    def _serialize_shared_tensor(self, name, shape, dtype):
        """Serialize shared tensor information.

        Parameters
        ----------
        name : str
            tensor name
        shape : tuple of int
            tensor shape
        dtype : str
            data type

        Returns
        -------
        str
            serialized string
        """
        str_data = name
        str_data += '/'
        for s in shape:
            str_data += str(s)
            str_data += '/'
        if 'float32' in str(dtype):
            str_data += 'float32'
        elif 'int64' in str(dtype):
            str_data += 'int64'
        else:
            raise RuntimeError('We can only process int64 and float32 shared-memory tensors now.')
        return str_data
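    # For example, _serialize_shared_tensor('ndata-data-', (8, 3), 'float32')
    # returns 'ndata-data-/8/3/float32'; server 0 joins such strings with '|'
    # and broadcasts them so clients can re-open the segments locally.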
    def _sort_addr(self, addr_list):
        """Sort the client address list in place.

        Parameters
        ----------
        addr_list : list of str
            IP address list
        """
        addr_list.sort()
class KVClient(object):
    """KVClient is used to push/pull tensors to/from KVServer. If a server node and a client node
    are on the same machine, they can communicate via a shared-memory tensor (close_shared_mem=False)
    instead of TCP/IP connections.

    Note that DO NOT use KVClient in multiple threads in Python because this behavior is not defined.

    For now, KVClient can only run on CPU; we will support GPU KVClient in the future.

    Parameters
    ----------
    server_namebook: dict
        IP address namebook of KVServer, where the key is the KVServer's ID
        (start from 0) and the value is the server's IP address and port, e.g.,

          { 0:'168.12.23.45:50051',
            1:'168.12.23.21:50051',
            2:'168.12.46.12:50051' }

    close_shared_mem : bool
        DO NOT use shared-memory access on the local machine.
    msg_queue_size : int
        Size of message queue.
    net_type : str
        networking type, e.g., 'socket' (default) or 'mpi'.
    """
    def __init__(self, server_namebook, close_shared_mem=False, msg_queue_size=2 * 1024 * 1024 * 1024, net_type='socket'):
        assert len(server_namebook) > 0, 'server_namebook cannot be empty.'
        assert net_type == 'socket' or net_type == 'mpi', 'net_type (%s) can only be \'socket\' or \'mpi\'.' % net_type
        if close_shared_mem == True:
            print("The shared-memory tensor has been closed; all data connections will go through the TCP/IP network.")
        # check if target data has an ID mapping from global ID to local ID
        self._has_data = set()
        # This is used to store local data, which can share memory with a local KVServer.
        self._data_store = {}
        # This is used to check if we can access server data locally
        self._local_server_id = set()
        # Server information
        self._server_namebook = server_namebook
        self._server_count = len(server_namebook)
        self._close_shared_mem = close_shared_mem
        # client ID will be assigned by the server after connecting to it
        self._client_id = -1
        # create C communicators for sender and receiver
        self._sender = _create_sender(net_type, msg_queue_size)
        self._receiver = _create_receiver(net_type, msg_queue_size)
def __del__(self):
"""Finalize KVClient
"""
# finalize C communicator of sender and receiver
_finalize_sender(self._sender)
_finalize_receiver(self._receiver)
def set_partition_book(self, name, partition_book):
"""Set partition book for KVClient.
        Using the partition book, the client knows which server holds each piece of data.
Parameters
----------
name : str
data name
partition_book : list or tensor (mx.ndarray or torch.tensor)
A book that maps global ID to target server ID.
"""
assert len(name) > 0, 'name cannot be empty.'
assert len(partition_book) > 0, 'partition_book cannot be empty.'
if isinstance(partition_book, list):
self._data_store[name+'-part-'] = F.tensor(partition_book)
else:
self._data_store[name+'-part-'] = partition_book
self._has_data.add(name+'-part-')
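    # For example, set_partition_book(name='ndata',
    # partition_book=[0, 0, 1, 1, 2, 2, 3, 3]) routes global IDs 0-1 to
    # server 0, IDs 2-3 to server 1, and so on.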
    def connect(self):
        """Connect to all the KVServer nodes.
        """
        # connect to the server nodes
        for ID, addr in self._server_namebook.items():
            server_ip, server_port = addr.split(':')
            _add_receiver_addr(self._sender, server_ip, int(server_port), ID)
        _sender_connect(self._sender)
        # get a local available IP and port
        self._addr = self._get_local_addr()
        client_ip, client_port = self._addr.split(':')
        # find local server nodes
        for ID, addr in self._server_namebook.items():
            server_ip, server_port = addr.split(':')
            if client_ip == server_ip or server_ip == '127.0.0.1':
                self._local_server_id.add(ID)
        # send our address to the server nodes
        msg = KVStoreMsg(
            type=KVMsgType.IP_ID,
            rank=0,
            name=self._addr,
            id=None,
            data=None)
        for server_id in range(self._server_count):
            _send_kv_msg(self._sender, msg, server_id)
        _receiver_wait(self._receiver, client_ip, int(client_port), self._server_count)
        # recv client ID from server 0
        msg = _recv_kv_msg(self._receiver)
        assert msg.rank == 0
        self._client_id = int(msg.name)
        # recv names of shared tensors from server 0
        msg = _recv_kv_msg(self._receiver)
        assert msg.rank == 0
        data_str = msg.name.split('|')
        # open the shared tensors on the local machine
        for data in data_str:
            if data != '' and self._close_shared_mem == False:
                tensor_name, shape, dtype = self._deserialize_shared_tensor(data)
                for server_id in self._local_server_id:
                    shared_data = empty_shared_mem(tensor_name+str(server_id), False, shape, dtype)
                    dlpack = shared_data.to_dlpack()
                    self._data_store[tensor_name] = F.zerocopy_from_dlpack(dlpack)
                    self._has_data.add(tensor_name)
    def push(self, name, id_tensor, data_tensor):
        """Push message to KVServer.

        Note that push() is an async operation that returns immediately after calling.

        Parameters
        ----------
        name : str
            data name
        id_tensor : tensor (mx.ndarray or torch.tensor)
            a vector storing the global data IDs
        data_tensor : tensor (mx.ndarray or torch.tensor)
            a tensor with the same row size as the data IDs
        """
        assert len(name) > 0, 'name cannot be empty.'
        assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
        assert F.shape(id_tensor)[0] == F.shape(data_tensor)[0], 'The data must have the same row size as ID.'
        # partition data (we can move this part of code into C-api if needed)
        server_id = self._data_store[name+'-part-'][id_tensor]
        # sort index by server id
        sorted_id = F.tensor(np.argsort(F.asnumpy(server_id)))
        id_tensor = id_tensor[sorted_id]
        data_tensor = data_tensor[sorted_id]
        server, count = np.unique(F.asnumpy(server_id), return_counts=True)
        # push data to each server in order
        start = 0
        for idx in range(len(server)):
            end = start + count[idx]
            if start == end:  # no data for the target server
                continue
            partial_id = id_tensor[start:end]
            partial_data = data_tensor[start:end]
            if server[idx] in self._local_server_id and self._close_shared_mem == False:
                # local server: write through the shared-memory tensor
                if (name+'-g2l-' in self._has_data) == True:
                    local_id = self._data_store[name+'-g2l-'][partial_id]
                else:
                    local_id = partial_id
                self._push_handler(name+'-data-', local_id, partial_data, self._data_store)
            else:
                # remote server: send a PUSH message over the network
                msg = KVStoreMsg(
                    type=KVMsgType.PUSH,
                    rank=self._client_id,
                    name=name,
                    id=partial_id,
                    data=partial_data)
                _send_kv_msg(self._sender, msg, server[idx])
            start += count[idx]
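    # Worked example of the partitioning above (illustrative values): with
    # partition book [0, 0, 1, 1] and id_tensor [3, 0, 2], the per-ID servers
    # are [1, 0, 1]; after sorting, server 0 receives ID [0] and server 1
    # receives IDs [3, 2], each as a single push.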
    def pull(self, name, id_tensor):
        """Pull message from KVServer.

        Parameters
        ----------
        name : str
            data name
        id_tensor : tensor (mx.ndarray or torch.tensor)
            a vector storing the ID list

        Returns
        -------
        tensor
            a data tensor with the same row size as id_tensor.
        """
        assert len(name) > 0, 'name cannot be empty.'
        assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
        # partition data (we can move this part of code into C-api if needed)
        server_id = self._data_store[name+'-part-'][id_tensor]
        # sort index by server id
        sorted_id = np.argsort(F.asnumpy(server_id))
        # we need to return data in the original order of the IDs
        back_sorted_id = F.tensor(np.argsort(sorted_id))
        id_tensor = id_tensor[F.tensor(sorted_id)]
        server, count = np.unique(F.asnumpy(server_id), return_counts=True)
        # pull data from each server in order
        start = 0
        pull_count = 0
        local_data = {}
        for idx in range(len(server)):
            end = start + count[idx]
            if start == end:  # no data on the target server
                continue
            partial_id = id_tensor[start:end]
            if server[idx] in self._local_server_id and self._close_shared_mem == False:
                # local server: read through the shared-memory tensor
                if (name+'-g2l-' in self._has_data) == True:
                    local_id = self._data_store[name+'-g2l-'][partial_id]
                else:
                    local_id = partial_id
                local_data[server[idx]] = self._pull_handler(name+'-data-', local_id, self._data_store)
            else:
                # remote server: send a PULL request over the network
                msg = KVStoreMsg(
                    type=KVMsgType.PULL,
                    rank=self._client_id,
                    name=name,
                    id=partial_id,
                    data=None)
                _send_kv_msg(self._sender, msg, server[idx])
                pull_count += 1
            start += count[idx]
        msg_list = []
        for server_id, data in local_data.items():
            local_msg = KVStoreMsg(
                type=KVMsgType.PULL_BACK,
                rank=server_id,
                name=name,
                id=None,
                data=data)
            msg_list.append(local_msg)
        # wait for messages from the remote server nodes
        for idx in range(pull_count):
            msg_list.append(_recv_kv_msg(self._receiver))
        # sort messages by server id so the concatenation matches the sorted IDs
        msg_list.sort(key=self._takeId)
        data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
        return data_tensor[back_sorted_id]  # return data in the original index order
    def barrier(self):
        """Barrier for all client nodes.

        This API blocks until every client has invoked it.
        """
        msg = KVStoreMsg(
            type=KVMsgType.BARRIER,
            rank=self._client_id,
            name=None,
            id=None,
            data=None)
        for server_id in range(self._server_count):
            _send_kv_msg(self._sender, msg, server_id)
        for server_id in range(self._server_count):
            back_msg = _recv_kv_msg(self._receiver)
            assert back_msg.type == KVMsgType.BARRIER, 'Recv kv msg error.'
    def shut_down(self):
        """Shut down all KVServer nodes.

        We usually invoke this API from just one client (e.g., client_0).
        """
        for server_id in range(self._server_count):
            msg = KVStoreMsg(
                type=KVMsgType.FINAL,
                rank=self._client_id,
                name=None,
                id=None,
                data=None)
            _send_kv_msg(self._sender, msg, server_id)
    def get_id(self):
        """Get client ID.

        Return
        ------
        int
            KVClient ID
        """
        return self._client_id
def get_addr(self):
"""Get client IP address
Return
------
str
IP address
"""
return self._addr
def _get_local_addr(self):
"""Get local available IP and port
Return
------
str
IP address, e.g., '192.168.8.12:50051'
"""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
# doesn't even have to be reachable
s.connect(('10.255.255.255', 1))
IP = s.getsockname()[0]
except:
IP = '127.0.0.1'
finally:
s.close()
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("",0))
s.listen(1)
port = s.getsockname()[1]
s.close()
return IP + ':' + str(port)
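    # Note: connecting a UDP socket to an unroutable address sends no packets
    # but makes the OS choose the outgoing interface, whose IP getsockname()
    # reports; binding a TCP socket to port 0 lets the OS pick a free port.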
    def _takeId(self, elem):
        """Sort key: order messages by their server rank.
        """
        return elem.rank
    def _push_handler(self, name, ID, data, target):
        """Default handler for local PUSH message.

        By default, _push_handler performs a SET operation on the tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the ID list.
        data : tensor (mx.ndarray or torch.tensor)
            a tensor with the same row size as ID
        target : tensor (mx.ndarray or torch.tensor)
            the target tensor
        """
        target[name][ID] = data
    def _pull_handler(self, name, ID, target):
        """Default handler for local PULL operation.

        By default, _pull_handler performs an index-select operation on the tensor.

        Parameters
        ----------
        name : str
            data name
        ID : tensor (mx.ndarray or torch.tensor)
            a vector storing the IDs that have been re-mapped to local IDs.
        target : tensor (mx.ndarray or torch.tensor)
            the target tensor

        Return
        ------
        tensor
            a tensor with the same row size as ID
        """
        return target[name][ID]
def _deserialize_shared_tensor(self, data):
"""Deserialize shared tensor information sent from server
Parameters
----------
data : str
serialized string
Returns
-------
str
tensor name
tuple of int
tensor shape
str
data type
"""
data_list = data.split('/')
tensor_name = data_list[0]
data_type = data_list[-1]
tensor_shape = []
for i in range(1, len(data_list)-1):
tensor_shape.append(int(data_list[i]))
tensor_shape = tuple(tensor_shape)
return tensor_name, tensor_shape, data_type
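    # Round-trip example: _deserialize_shared_tensor('ndata-data-/8/3/float32')
    # returns ('ndata-data-', (8, 3), 'float32'), undoing the encoding done by
    # KVServer._serialize_shared_tensor on the server side.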
@@ -31,27 +31,31 @@ def _network_wait():
"""
time.sleep(_WAIT_TIME_SEC)
def _create_sender(net_type, msg_queue_size=2000*1024*1024*1024):
    """Create a Sender communicator via C api

    Parameters
    ----------
    net_type : str
        'socket' or 'mpi'
    msg_queue_size : int
        message queue size
    """
    assert net_type in ('socket', 'mpi'), 'Unknown network type.'
    return _CAPI_DGLSenderCreate(net_type, msg_queue_size)
def _create_receiver(net_type, msg_queue_size=2000*1024*1024*1024):
    """Create a Receiver communicator via C api

    Parameters
    ----------
    net_type : str
        'socket' or 'mpi'
    msg_queue_size : int
        message queue size
    """
    assert net_type in ('socket', 'mpi'), 'Unknown network type.'
    return _CAPI_DGLReceiverCreate(net_type, msg_queue_size)
def _finalize_sender(sender):
"""Finalize Sender communicator
@@ -188,6 +192,7 @@ class KVMsgType(Enum):
PULL = 4
PULL_BACK = 5
BARRIER = 6
IP_ID = 7
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data")
"""Message of DGL kvstore
@@ -227,6 +232,13 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank,
msg.name,
tensor_id)
elif msg.type == KVMsgType.IP_ID:
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank,
msg.name)
elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
_CAPI_SenderSendKVMsg(
sender,
@@ -271,6 +283,15 @@ def _recv_kv_msg(receiver):
id=tensor_id,
data=None)
return msg
elif msg_type == KVMsgType.IP_ID:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
msg = KVStoreMsg(
type=msg_type,
rank=rank,
name=name,
id=None,
data=None)
return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg(
type=msg_type,
...
@@ -171,9 +171,10 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) {
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr;
if (type == "socket") {
    sender = new network::SocketSender(msg_queue_size);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
@@ -184,9 +185,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr;
if (type == "socket") {
    receiver = new network::SocketReceiver(msg_queue_size);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
@@ -444,18 +446,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  int args_count = 0;
  CommunicatorHandle chandle = args[args_count++];
  int recv_id = args[args_count++];
  KVStoreMsg kv_msg;
  kv_msg.msg_type = args[args_count++];
  kv_msg.rank = args[args_count++];
  network::Sender* sender = static_cast<network::Sender*>(chandle);
  // FINAL and BARRIER messages carry no name/id/data payload
  if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
    std::string name = args[args_count++];
    kv_msg.name = name;
    // IP_ID messages carry only the name string
    if (kv_msg.msg_type != kIPIDMsg) {
      kv_msg.id = args[args_count++];
    }
    // PULL messages carry no data tensor
    if (kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kIPIDMsg) {
      kv_msg.data = args[args_count++];
    }
}
int64_t kv_size = 0;
@@ -466,7 +471,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
send_kv_msg.size = kv_size;
send_kv_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
  if (kv_msg.msg_type != kFinalMsg &&
      kv_msg.msg_type != kBarrierMsg &&
      kv_msg.msg_type != kIPIDMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg.msg_type);
meta.AddArray(kv_msg.id);
@@ -510,7 +518,9 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg);
  if (kv_msg->msg_type == kFinalMsg ||
      kv_msg->msg_type == kBarrierMsg ||
      kv_msg->msg_type == kIPIDMsg) {
*rv = kv_msg;
return;
}
...