"src/array/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "83115794c29ef1db47f7e7e2e4fde54c0d7f0a4a"
Unverified Commit e4ef8d1a authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVStore] API change of kvstore (#1058)

* API change of kvstore

* add demo for kvstore

* update

* remove duplicated log

* change queue size

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* change name

* update

* fix lint

* update

* update

* update

* update

* change message queue size to a python argument
parent 41f8a162
# Usage of DGL distributed KVStore ## Usage of DGL distributed KVStore
This is a simple example that shows how to use the DGL distributed KVStore with MXNet locally. In this example, we start 4 servers and 4 clients. First, start the servers by running:
./run_server.sh
When you see the following messages:
start server 1 on 127.0.0.1:50051
start server 2 on 127.0.0.1:50052
start server 0 on 127.0.0.1:50050
start server 3 on 127.0.0.1:50053
you can start the clients by running:
./run_client.sh
This is a simple example that shows how to use the DGL distributed KVStore with MXNet locally.
In this example, we start two servers and four clients, and you can run the example by:
```
./run.sh
```
\ No newline at end of file
# This is a simple MXNet client demo shows how to use DGL distributed kvstore. # This is a simple MXNet server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl import dgl
import mxnet as mx
import time
import argparse import argparse
import mxnet as mx
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') ID = []
ID.append(mx.nd.array([0,1], dtype='int64'))
ID.append(mx.nd.array([2,3], dtype='int64'))
ID.append(mx.nd.array([4,5], dtype='int64'))
ID.append(mx.nd.array([6,7], dtype='int64'))
def start_client(args): edata_partition_book = {'edata':mx.nd.array([0,0,1,1,2,2,3,3], dtype='int64')}
# Initialize client and connect to server ndata_partition_book = {'ndata':mx.nd.array([0,0,1,1,2,2,3,3], dtype='int64')}
client = dgl.contrib.KVClient(
client_id=args.id,
server_namebook=server_namebook,
client_addr=client_namebook[args.id])
client.connect() def start_client():
client = dgl.contrib.start_client(ip_config='ip_config.txt',
ndata_partition_book=ndata_partition_book,
edata_partition_book=edata_partition_book)
# Initialize data on server client.push(name='edata', id_tensor=ID[client.get_id()], data_tensor=mx.nd.array([[1.,1.,1.],[1.,1.,1.]]))
client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero') client.push(name='ndata', id_tensor=ID[client.get_id()], data_tensor=mx.nd.array([[2.,2.,2.],[2.,2.,2.]]))
client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
data_0 = mx.nd.array([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) client.barrier()
data_1 = mx.nd.array([0., 1., 2.])
for i in range(5): tensor_edata = client.pull(name='edata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
client.push(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_0) tensor_ndata = client.pull(name='ndata', id_tensor=mx.nd.array([0,1,2,3,4,5,6,7], dtype='int64'))
client.push(name='embed_0', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=mx.nd.array([1, 3, 5], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 2, 4], dtype='int64'), data_tensor=data_1)
client.barrier() print(tensor_edata)
if client.get_id() == 0: client.barrier()
client.pull(name='embed_0', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='embed_1', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) print(tensor_ndata)
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4, 5], dtype='int64'))
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
client.pull(name='server_embed', server_id=0, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64')) client.barrier()
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='server_embed', server_id=1, id_tensor=mx.nd.array([0, 1, 2, 3, 4], dtype='int64'))
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("server_embed:")
print(mx.nd.concat(msg_0.data, msg_1.data, dim=0))
# Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
client.shut_down() client.shut_down()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='kvstore')
parser.add_argument("--id", type=int, default=0, help="node ID") start_client()
args = parser.parse_args() \ No newline at end of file
time.sleep(2) # wait server start
start_client(args)
server 127.0.0.1:50051 0
server 127.0.0.1:50052 1
client 127.0.0.1:50053 0
client 127.0.0.1:50054 1
client 127.0.0.1:50055 2
client 127.0.0.1:50056 3
\ No newline at end of file
0 127.0.0.1 50050
1 127.0.0.1 50051
2 127.0.0.1 50052
3 127.0.0.1 50053
\ No newline at end of file
DGLBACKEND=mxnet python3 ./server.py --id 0 &
DGLBACKEND=mxnet python3 ./server.py --id 1 &
DGLBACKEND=mxnet python3 ./client.py --id 0 &
DGLBACKEND=mxnet python3 ./client.py --id 1 &
DGLBACKEND=mxnet python3 ./client.py --id 2 &
DGLBACKEND=mxnet python3 ./client.py --id 3
\ No newline at end of file
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py &
DGLBACKEND=mxnet python3 client.py
\ No newline at end of file
DGLBACKEND=mxnet python3 server.py --id 0 &
DGLBACKEND=mxnet python3 server.py --id 1 &
DGLBACKEND=mxnet python3 server.py --id 2 &
DGLBACKEND=mxnet python3 server.py --id 3
\ No newline at end of file
# This is a simple MXNet server demo shows how to use DGL distributed kvstore. # This is a simple MXNet server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl import dgl
import torch
import argparse import argparse
import mxnet as mx import mxnet as mx
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') ndata_g2l = []
edata_g2l = []
def start_server(args): ndata_g2l.append({'ndata':mx.nd.array([0,1,0,0,0,0,0,0], dtype='int64')})
server = dgl.contrib.KVServer( ndata_g2l.append({'ndata':mx.nd.array([0,0,0,1,0,0,0,0], dtype='int64')})
server_id=args.id, ndata_g2l.append({'ndata':mx.nd.array([0,0,0,0,0,1,0,0], dtype='int64')})
client_namebook=client_namebook, ndata_g2l.append({'ndata':mx.nd.array([0,0,0,0,0,0,0,1], dtype='int64')})
server_addr=server_namebook[args.id])
edata_g2l.append({'edata':mx.nd.array([0,1,0,0,0,0,0,0], dtype='int64')})
edata_g2l.append({'edata':mx.nd.array([0,0,0,1,0,0,0,0], dtype='int64')})
edata_g2l.append({'edata':mx.nd.array([0,0,0,0,0,1,0,0], dtype='int64')})
edata_g2l.append({'edata':mx.nd.array([0,0,0,0,0,0,0,1], dtype='int64')})
server.init_data(name='server_embed', data_tensor=mx.nd.array([0., 0., 0., 0., 0.])) def start_server(args):
dgl.contrib.start_server(
server_id=args.id,
ip_config='ip_config.txt',
num_client=4,
ndata={'ndata':mx.nd.array([[0.,0.,0.],[0.,0.,0.]])},
edata={'edata':mx.nd.array([[0.,0.,0.],[0.,0.,0.]])},
ndata_g2l=ndata_g2l[args.id],
edata_g2l=edata_g2l[args.id])
server.start()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='kvstore') parser = argparse.ArgumentParser(description='kvstore')
parser.add_argument("--id", type=int, default=0, help="node ID") parser.add_argument("--id", type=int, default=0, help="node ID")
args = parser.parse_args() args = parser.parse_args()
start_server(args) start_server(args)
\ No newline at end of file
# Usage of DGL distributed KVStore ## Usage of DGL distributed KVStore
This is a simple example that shows how to use the DGL distributed KVStore with PyTorch locally. This is a simple example that shows how to use the DGL distributed KVStore with PyTorch locally. In this example, we start 4 servers and 4 clients. First, start the servers by running:
In this example, we start two servers and four clients, and you can run the example by:
``` ./run_server.sh
./run.sh
``` And when you see the message
start server 1 on 127.0.0.1:40051
start server 2 on 127.0.0.1:40052
start server 0 on 127.0.0.1:40050
start server 3 on 127.0.0.1:40053
you can start the clients by running:
./run_client.sh
\ No newline at end of file
# This is a simple pytorch client demo shows how to use DGL distributed kvstore. # This is a simple MXNet server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl import dgl
import time
import argparse import argparse
import torch as th import torch as th
import time
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') ID = []
ID.append(th.tensor([0,1]))
def start_client(args): ID.append(th.tensor([2,3]))
# Initialize client and connect to server ID.append(th.tensor([4,5]))
client = dgl.contrib.KVClient( ID.append(th.tensor([6,7]))
client_id=args.id,
server_namebook=server_namebook,
client_addr=client_namebook[args.id])
client.connect() edata_partition_book = {'edata':th.tensor([0,0,1,1,2,2,3,3])}
ndata_partition_book = {'ndata':th.tensor([0,0,1,1,2,2,3,3])}
# Initialize data on server def start_client():
client.init_data(name='embed_0', server_id=0, shape=[5, 3], init_type='zero') time.sleep(3)
client.init_data(name='embed_0', server_id=1, shape=[6, 3], init_type='zero')
client.init_data(name='embed_1', server_id=0, shape=[5], init_type='uniform', low=0.0, high=0.0)
client.init_data(name='embed_1', server_id=1, shape=[6], init_type='uniform', low=0.0, high=0.0)
data_0 = th.tensor([[0., 0., 0., ], [1., 1., 1.], [2., 2., 2.]]) client = dgl.contrib.start_client(ip_config='ip_config.txt',
data_1 = th.tensor([0., 1., 2.]) ndata_partition_book=ndata_partition_book,
edata_partition_book=edata_partition_book)
for i in range(5): client.push(name='edata', id_tensor=ID[client.get_id()], data_tensor=th.tensor([[1.,1.,1.],[1.,1.,1.]]))
client.push(name='embed_0', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_0) client.push(name='ndata', id_tensor=ID[client.get_id()], data_tensor=th.tensor([[2.,2.,2.],[2.,2.,2.]]))
client.push(name='embed_0', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_0)
client.push(name='embed_1', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='embed_1', server_id=1, id_tensor=th.tensor([1, 3, 5]), data_tensor=data_1)
client.push(name='server_embed', server_id=0, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.push(name='server_embed', server_id=1, id_tensor=th.tensor([0, 2, 4]), data_tensor=data_1)
client.barrier() client.barrier()
if client.get_id() == 0: tensor_edata = client.pull(name='edata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
client.pull(name='embed_0', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) tensor_ndata = client.pull(name='ndata', id_tensor=th.tensor([0,1,2,3,4,5,6,7]))
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_0', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_0:")
print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='embed_1', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) print(tensor_edata)
msg_0 = client.pull_wait()
assert msg_0.rank == 0
client.pull(name='embed_1', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4, 5]))
msg_1 = client.pull_wait()
assert msg_1.rank == 1
print("embed_1:")
print(th.cat([msg_0.data, msg_1.data]))
client.pull(name='server_embed', server_id=0, id_tensor=th.tensor([0, 1, 2, 3, 4])) client.barrier()
msg_0 = client.pull_wait()
assert msg_0.rank == 0 print(tensor_ndata)
client.pull(name='server_embed', server_id=1, id_tensor=th.tensor([0, 1, 2, 3, 4]))
msg_1 = client.pull_wait() client.barrier()
assert msg_1.rank == 1
print("server_embed:")
print(th.cat([msg_0.data, msg_1.data]))
# Shut-down all the servers
if client.get_id() == 0: if client.get_id() == 0:
client.shut_down() client.shut_down()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='kvstore')
parser.add_argument("--id", type=int, default=0, help="node ID") start_client()
args = parser.parse_args() \ No newline at end of file
time.sleep(2) # wait server start
start_client(args)
server 127.0.0.1:50051 0
server 127.0.0.1:50052 1
client 127.0.0.1:50053 0
client 127.0.0.1:50054 1
client 127.0.0.1:50055 2
client 127.0.0.1:50056 3
\ No newline at end of file
0 127.0.0.1 40050
1 127.0.0.1 40051
2 127.0.0.1 40052
3 127.0.0.1 40053
\ No newline at end of file
python3 ./server.py --id 0 &
python3 ./server.py --id 1 &
python3 ./client.py --id 0 &
python3 ./client.py --id 1 &
python3 ./client.py --id 2 &
python3 ./client.py --id 3
\ No newline at end of file
python3 client.py &
python3 client.py &
python3 client.py &
python3 client.py
\ No newline at end of file
python3 server.py --id 0 &
python3 server.py --id 1 &
python3 server.py --id 2 &
python3 server.py --id 3
\ No newline at end of file
# This is a simple pytorch server demo shows how to use DGL distributed kvstore. # This is a simple MXNet server demo shows how to use DGL distributed kvstore.
# In this demo, we initialize two embeddings on server and push/pull data to/from it.
import dgl import dgl
import torch
import argparse import argparse
import torch as th import torch as th
server_namebook, client_namebook = dgl.contrib.ReadNetworkConfigure('config.txt') ndata_g2l = []
edata_g2l = []
def start_server(args): ndata_g2l.append({'ndata':th.tensor([0,1,0,0,0,0,0,0])})
server = dgl.contrib.KVServer( ndata_g2l.append({'ndata':th.tensor([0,0,0,1,0,0,0,0])})
server_id=args.id, ndata_g2l.append({'ndata':th.tensor([0,0,0,0,0,1,0,0])})
client_namebook=client_namebook, ndata_g2l.append({'ndata':th.tensor([0,0,0,0,0,0,0,1])})
server_addr=server_namebook[args.id])
edata_g2l.append({'edata':th.tensor([0,1,0,0,0,0,0,0])})
edata_g2l.append({'edata':th.tensor([0,0,0,1,0,0,0,0])})
edata_g2l.append({'edata':th.tensor([0,0,0,0,0,1,0,0])})
edata_g2l.append({'edata':th.tensor([0,0,0,0,0,0,0,1])})
server.init_data(name='server_embed', data_tensor=th.tensor([0., 0., 0., 0., 0.])) def start_server(args):
dgl.contrib.start_server(
server_id=args.id,
ip_config='ip_config.txt',
num_client=4,
ndata={'ndata':th.tensor([[0.,0.,0.],[0.,0.,0.]])},
edata={'edata':th.tensor([[0.,0.,0.],[0.,0.,0.]])},
ndata_g2l=ndata_g2l[args.id],
edata_g2l=edata_g2l[args.id])
server.start()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='kvstore') parser = argparse.ArgumentParser(description='kvstore')
parser.add_argument("--id", type=int, default=0, help="node ID") parser.add_argument("--id", type=int, default=0, help="node ID")
args = parser.parse_args() args = parser.parse_args()
start_server(args) start_server(args)
\ No newline at end of file
from . import sampling from . import sampling
from . import graph_store from . import graph_store
from .dis_kvstore import KVClient, KVServer from .dis_kvstore import KVClient, KVServer
from .dis_kvstore import ReadNetworkConfigure from .dis_kvstore import read_ip_config
from .dis_kvstore import start_server, start_client
\ No newline at end of file
This diff is collapsed.
...@@ -31,27 +31,31 @@ def _network_wait(): ...@@ -31,27 +31,31 @@ def _network_wait():
""" """
time.sleep(_WAIT_TIME_SEC) time.sleep(_WAIT_TIME_SEC)
def _create_sender(net_type): def _create_sender(net_type, msg_queue_size=2000*1024*1024*1024):
"""Create a Sender communicator via C api """Create a Sender communicator via C api
Parameters Parameters
---------- ----------
net_type : str net_type : str
'socket' or 'mpi' 'socket' or 'mpi'
msg_queue_size : int
message queue size
""" """
assert net_type in ('socket', 'mpi'), 'Unknown network type.' assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLSenderCreate(net_type) return _CAPI_DGLSenderCreate(net_type, msg_queue_size)
def _create_receiver(net_type): def _create_receiver(net_type, msg_queue_size=2000*1024*1024*1024):
"""Create a Receiver communicator via C api """Create a Receiver communicator via C api
Parameters Parameters
---------- ----------
net_type : str net_type : str
'socket' or 'mpi' 'socket' or 'mpi'
msg_queue_size : int
message queue size
""" """
assert net_type in ('socket', 'mpi'), 'Unknown network type.' assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLReceiverCreate(net_type) return _CAPI_DGLReceiverCreate(net_type, msg_queue_size)
def _finalize_sender(sender): def _finalize_sender(sender):
"""Finalize Sender communicator """Finalize Sender communicator
...@@ -188,6 +192,7 @@ class KVMsgType(Enum): ...@@ -188,6 +192,7 @@ class KVMsgType(Enum):
PULL = 4 PULL = 4
PULL_BACK = 5 PULL_BACK = 5
BARRIER = 6 BARRIER = 6
IP_ID = 7
KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data") KVStoreMsg = namedtuple("KVStoreMsg", "type rank name id data")
"""Message of DGL kvstore """Message of DGL kvstore
...@@ -227,6 +232,13 @@ def _send_kv_msg(sender, msg, recv_id): ...@@ -227,6 +232,13 @@ def _send_kv_msg(sender, msg, recv_id):
msg.rank, msg.rank,
msg.name, msg.name,
tensor_id) tensor_id)
elif msg.type == KVMsgType.IP_ID:
_CAPI_SenderSendKVMsg(
sender,
int(recv_id),
msg.type.value,
msg.rank,
msg.name)
elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER): elif msg.type in (KVMsgType.FINAL, KVMsgType.BARRIER):
_CAPI_SenderSendKVMsg( _CAPI_SenderSendKVMsg(
sender, sender,
...@@ -271,6 +283,15 @@ def _recv_kv_msg(receiver): ...@@ -271,6 +283,15 @@ def _recv_kv_msg(receiver):
id=tensor_id, id=tensor_id,
data=None) data=None)
return msg return msg
elif msg_type == KVMsgType.IP_ID:
name = _CAPI_ReceiverGetKVMsgName(msg_ptr)
msg = KVStoreMsg(
type=msg_type,
rank=rank,
name=name,
id=None,
data=None)
return msg
elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER): elif msg_type in (KVMsgType.FINAL, KVMsgType.BARRIER):
msg = KVStoreMsg( msg = KVStoreMsg(
type=msg_type, type=msg_type,
......
...@@ -171,9 +171,10 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) { ...@@ -171,9 +171,10 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) {
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0]; std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr; network::Sender* sender = nullptr;
if (type == "socket") { if (type == "socket") {
sender = new network::SocketSender(kQueueSize); sender = new network::SocketSender(msg_queue_size);
} else { } else {
LOG(FATAL) << "Unknown communicator type: " << type; LOG(FATAL) << "Unknown communicator type: " << type;
} }
...@@ -184,9 +185,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") ...@@ -184,9 +185,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0]; std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr; network::Receiver* receiver = nullptr;
if (type == "socket") { if (type == "socket") {
receiver = new network::SocketReceiver(kQueueSize); receiver = new network::SocketReceiver(msg_queue_size);
} else { } else {
LOG(FATAL) << "Unknown communicator type: " << type; LOG(FATAL) << "Unknown communicator type: " << type;
} }
...@@ -444,18 +446,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -444,18 +446,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; int args_count = 0;
int recv_id = args[1]; CommunicatorHandle chandle = args[args_count++];
int recv_id = args[args_count++];
KVStoreMsg kv_msg; KVStoreMsg kv_msg;
kv_msg.msg_type = args[2]; kv_msg.msg_type = args[args_count++];
kv_msg.rank = args[3]; kv_msg.rank = args[args_count++];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) { if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
std::string name = args[4]; std::string name = args[args_count++];
kv_msg.name = name; kv_msg.name = name;
kv_msg.id = args[5]; if (kv_msg.msg_type != kIPIDMsg) {
if (kv_msg.msg_type != kPullMsg) { kv_msg.id = args[args_count++];
kv_msg.data = args[6]; }
if (kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kIPIDMsg) {
kv_msg.data = args[args_count++];
} }
} }
int64_t kv_size = 0; int64_t kv_size = 0;
...@@ -466,7 +471,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") ...@@ -466,7 +471,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
send_kv_msg.size = kv_size; send_kv_msg.size = kv_size;
send_kv_msg.deallocator = DefaultMessageDeleter; send_kv_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
if (kv_msg.msg_type != kFinalMsg &&
kv_msg.msg_type != kBarrierMsg &&
kv_msg.msg_type != kIPIDMsg) {
// Send ArrayMeta // Send ArrayMeta
ArrayMeta meta(kv_msg.msg_type); ArrayMeta meta(kv_msg.msg_type);
meta.AddArray(kv_msg.id); meta.AddArray(kv_msg.id);
...@@ -510,7 +518,9 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg") ...@@ -510,7 +518,9 @@ DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size); kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg); recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg || kv_msg->msg_type == kBarrierMsg) { if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) {
*rv = kv_msg; *rv = kv_msg;
return; return;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment