"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "cc92332096a522df689a244efeb1f2156789afe4"
Unverified Commit fe7d5e9b authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Sampler] Change Distributed Sampler API (#499)

* Change Distributed Sampler API

* fix lint

* fix lint

* update demo

* update

* update

* update

* update demo

* update demo
parent 3f464591
...@@ -32,8 +32,7 @@ class MySamplerPool(SamplerPool): ...@@ -32,8 +32,7 @@ class MySamplerPool(SamplerPool):
# create GCN model # create GCN model
g = DGLGraph(data.graph, readonly=True) g = DGLGraph(data.graph, readonly=True)
for epoch in range(args.n_epochs): while True:
# Here we only send nodeflow for training
idx = 0 idx = 0
for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size, for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
args.num_neighbors, args.num_neighbors,
...@@ -44,6 +43,7 @@ class MySamplerPool(SamplerPool): ...@@ -44,6 +43,7 @@ class MySamplerPool(SamplerPool):
print("send train nodeflow: %d" %(idx)) print("send train nodeflow: %d" %(idx))
sender.send(nf, 0) sender.send(nf, 0)
idx += 1 idx += 1
sender.signal(0)
def main(args): def main(args):
pool = MySamplerPool() pool = MySamplerPool()
......
...@@ -122,9 +122,6 @@ def main(args): ...@@ -122,9 +122,6 @@ def main(args):
if args.self_loop and not args.dataset.startswith('reddit'): if args.self_loop and not args.dataset.startswith('reddit'):
data.graph.add_edges_from([(i,i) for i in range(len(data.graph))]) data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])
# Create sampler receiver
receiver = dgl.contrib.sampling.SamplerReceiver(addr=args.ip, num_sender=args.num_sender)
train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx) train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx) test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)
...@@ -163,6 +160,9 @@ def main(args): ...@@ -163,6 +160,9 @@ def main(args):
norm = mx.nd.expand_dims(1./degs, 1) norm = mx.nd.expand_dims(1./degs, 1)
g.ndata['norm'] = norm g.ndata['norm'] = norm
# Create sampler receiver
sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr=args.ip, num_sender=args.num_sender)
model = GCNSampling(in_feats, model = GCNSampling(in_feats,
args.n_hidden, args.n_hidden,
n_classes, n_classes,
...@@ -191,11 +191,11 @@ def main(args): ...@@ -191,11 +191,11 @@ def main(args):
# initialize graph # initialize graph
dur = [] dur = []
total_count = 153
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
for subg_count in range(total_count): idx = 0
print(subg_count) for nf in sampler:
nf = receiver.recv(g) print("epoch: %d, subgraph: %d" %(epoch, idx))
idx += 1
nf.copy_from_parent() nf.copy_from_parent()
# forward # forward
with mx.autograd.record(): with mx.autograd.record():
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
from ...network import _send_nodeflow, _recv_nodeflow from ...network import _send_nodeflow, _recv_nodeflow
from ...network import _create_sender, _create_receiver from ...network import _create_sender, _create_receiver
from ...network import _finalize_sender, _finalize_receiver from ...network import _finalize_sender, _finalize_receiver
from ...network import _add_receiver_addr, _sender_connect, _receiver_wait from ...network import _add_receiver_addr, _sender_connect
from ...network import _receiver_wait, _send_end_signal
from multiprocessing import Pool from multiprocessing import Pool
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
...@@ -103,6 +104,17 @@ class SamplerSender(object): ...@@ -103,6 +104,17 @@ class SamplerSender(object):
""" """
_send_nodeflow(self._sender, nodeflow, recv_id) _send_nodeflow(self._sender, nodeflow, recv_id)
def signal(self, recv_id):
    """Send an end signal to a SamplerReceiver.

    When sampling of each epoch is finished, users can invoke this API
    to tell the SamplerReceiver that this sender has finished its job.

    Parameters
    ----------
    recv_id : int
        receiver ID
    """
    _send_end_signal(self._sender, recv_id)
class SamplerReceiver(object): class SamplerReceiver(object):
"""SamplerReceiver for DGL distributed training. """SamplerReceiver for DGL distributed training.
...@@ -114,14 +126,18 @@ class SamplerReceiver(object): ...@@ -114,14 +126,18 @@ class SamplerReceiver(object):
Parameters Parameters
---------- ----------
graph : DGLGraph
The parent graph
addr : str addr : str
address of SamplerReceiver, e.g., '127.0.0.1:50051' address of SamplerReceiver, e.g., '127.0.0.1:50051'
num_sender : int num_sender : int
total number of SamplerSender total number of SamplerSender
""" """
def __init__(self, addr, num_sender): def __init__(self, graph, addr, num_sender):
self._graph = graph
self._addr = addr self._addr = addr
self._num_sender = num_sender self._num_sender = num_sender
self._tmp_count = 0
self._receiver = _create_receiver() self._receiver = _create_receiver()
vec = self._addr.split(':') vec = self._addr.split(':')
_receiver_wait(self._receiver, vec[0], int(vec[1]), self._num_sender); _receiver_wait(self._receiver, vec[0], int(vec[1]), self._num_sender);
...@@ -131,17 +147,20 @@ class SamplerReceiver(object): ...@@ -131,17 +147,20 @@ class SamplerReceiver(object):
""" """
_finalize_receiver(self._receiver) _finalize_receiver(self._receiver)
def recv(self, graph): def __iter__(self):
"""Receive a NodeFlow object from remote sampler. """Iterator
"""
Parameters return self
----------
graph : DGLGraph
The parent graph
Returns def __next__(self):
------- """Return sampled NodeFlow object
NodeFlow
received NodeFlow object
""" """
return _recv_nodeflow(self._receiver, graph) while True:
res = _recv_nodeflow(self._receiver, self._graph)
if isinstance(res, int):
self._tmp_count += 1
if self._tmp_count == self._num_sender:
self._tmp_count = 0
raise StopIteration
else:
return res
...@@ -8,6 +8,9 @@ from . import utils ...@@ -8,6 +8,9 @@ from . import utils
_init_api("dgl.network") _init_api("dgl.network")
# Control codes. The values match CONTROL_NODEFLOW / CONTROL_END_SIGNAL
# declared on the C++ side; _recv_nodeflow compares an int result of the
# receive call against _CONTROL_END_SIGNAL.
_CONTROL_NODEFLOW = 0
_CONTROL_END_SIGNAL = 1
def _create_sender(): def _create_sender():
"""Create a Sender communicator via C api """Create a Sender communicator via C api
""" """
...@@ -74,6 +77,18 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -74,6 +77,18 @@ def _send_nodeflow(sender, nodeflow, recv_id):
layers_offsets, layers_offsets,
flows_offsets) flows_offsets)
def _send_end_signal(sender, recv_id):
    """Send an epoch-end signal to remote Receiver.

    Thin wrapper that forwards both arguments to the C API call
    ``_CAPI_SenderSendEndSignal``.

    Parameters
    ----------
    sender : ctypes.c_void_p
        C sender handle
    recv_id : int
        Receiver ID
    """
    _CAPI_SenderSendEndSignal(sender, recv_id)
def _create_receiver(): def _create_receiver():
"""Create a Receiver communicator via C api """Create a Receiver communicator via C api
""" """
...@@ -115,6 +130,12 @@ def _recv_nodeflow(receiver, graph): ...@@ -115,6 +130,12 @@ def _recv_nodeflow(receiver, graph):
NodeFlow NodeFlow
Sampled NodeFlow object Sampled NodeFlow object
""" """
# hdl is a list of ptr res = _CAPI_ReceiverRecvSubgraph(receiver)
hdl = unwrap_to_ptr_list(_CAPI_ReceiverRecvSubgraph(receiver)) if isinstance(res, int):
return NodeFlow(graph, hdl[0]) if res == _CONTROL_END_SIGNAL:
return _CONTROL_END_SIGNAL
else:
raise RuntimeError('Got unexpected control code {}'.format(res))
else:
hdl = unwrap_to_ptr_list(res)
return NodeFlow(graph, hdl[0])
...@@ -23,6 +23,27 @@ namespace network { ...@@ -23,6 +23,27 @@ namespace network {
static char* SEND_BUFFER = nullptr; static char* SEND_BUFFER = nullptr;
static char* RECV_BUFFER = nullptr; static char* RECV_BUFFER = nullptr;
// Wrapper for Send api.
// Forwards `size` bytes of `data` to receiver `recv_id` through `sender`
// and reports a non-positive return value from Send() via LOG(FATAL).
static void SendData(network::Sender* sender,
                     const char* data,
                     int64_t size,
                     int recv_id) {
  const int64_t sent_bytes = sender->Send(data, size, recv_id);
  if (sent_bytes > 0) {
    // Positive result: nothing to report.
    return;
  }
  LOG(FATAL) << "Send error (size: " << sent_bytes << ")";
}
// Wrapper for Recv api.
// Asks `receiver` to fill `dest` with at most `max_size` bytes and reports
// a non-positive return value from Recv() via LOG(FATAL).
static void RecvData(network::Receiver* receiver,
                     char* dest,
                     int64_t max_size) {
  const int64_t received_bytes = receiver->Recv(dest, max_size);
  if (received_bytes > 0) {
    // Positive result: nothing to report.
    return;
  }
  LOG(FATAL) << "Receive error (size: " << received_bytes << ")";
}
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
try { try {
...@@ -74,20 +95,30 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -74,20 +95,30 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle); ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle);
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR(); auto csr = ptr->GetInCSR();
// Write control message
*SEND_BUFFER = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer // Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph( int64_t data_size = network::SerializeSampledSubgraph(
SEND_BUFFER, SEND_BUFFER+sizeof(CONTROL_NODEFLOW),
csr, csr,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
layer_offsets, layer_offsets,
flow_offsets); flow_offsets);
CHECK_GT(data_size, 0); CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network // Send msg via network
int64_t size = sender->Send(SEND_BUFFER, data_size, recv_id); SendData(sender, SEND_BUFFER, data_size, recv_id);
if (size <= 0) { });
LOG(FATAL) << "Send message error (size: " << size << ")";
} DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
network::Sender* sender = static_cast<network::Sender*>(chandle);
*SEND_BUFFER = CONTROL_END_SIGNAL;
// Send msg via network
SendData(sender, SEND_BUFFER, sizeof(CONTROL_END_SIGNAL), recv_id);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
...@@ -125,23 +156,27 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") ...@@ -125,23 +156,27 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network // Recv data from network
int64_t size = receiver->Recv(RECV_BUFFER, kMaxBufferSize); RecvData(receiver, RECV_BUFFER, kMaxBufferSize);
if (size <= 0) { int control = *RECV_BUFFER;
LOG(FATAL) << "Receive error: (size: " << size << ")"; if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER+sizeof(CONTROL_NODEFLOW),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr, false));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
} else {
LOG(FATAL) << "Unknow control number: " << control;
} }
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER,
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
&(nf->flow_offsets));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr, false));
std::vector<NodeFlow*> subgs(1);
subgs[0] = nf;
*rv = WrapVectorReturn(subgs);
}); });
} // namespace network } // namespace network
......
...@@ -23,6 +23,10 @@ const int64_t kQueueSize = 1024 * 1024 * 1024; ...@@ -23,6 +23,10 @@ const int64_t kQueueSize = 1024 * 1024 * 1024;
// Maximal try count of connection // Maximal try count of connection
const int kMaxTryCount = 500; const int kMaxTryCount = 500;
// Control number: written at the start of each message buffer by the sender
// and read back from the start of the receive buffer to decide how to
// handle the message.
// CONTROL_NODEFLOW: a serialized NodeFlow follows the control number.
const int CONTROL_NODEFLOW = 0;
// CONTROL_END_SIGNAL: the message is an epoch-end signal with no payload.
const int CONTROL_END_SIGNAL = 1;
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
......
...@@ -13,20 +13,20 @@ def generate_rand_graph(n): ...@@ -13,20 +13,20 @@ def generate_rand_graph(n):
def start_trainer(): def start_trainer():
g = generate_rand_graph(100) g = generate_rand_graph(100)
recv = dgl.contrib.sampling.SamplerReceiver(addr='127.0.0.1:50051', num_sender=1) sampler = dgl.contrib.sampling.SamplerReceiver(graph=g, addr='127.0.0.1:50051', num_sender=1)
subg = recv.recv(g) for subg in sampler:
seed_ids = subg.layer_parent_nid(-1) seed_ids = subg.layer_parent_nid(-1)
assert len(seed_ids) == 1 assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all') src, dst, eid = g.in_edges(seed_ids, form='all')
assert subg.number_of_nodes() == len(src) + 1 assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() == len(src) assert subg.number_of_edges() == len(src)
assert seed_ids == subg.layer_parent_nid(-1) assert seed_ids == subg.layer_parent_nid(-1)
child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all') child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
assert F.array_equal(child_src, subg.layer_nid(0)) assert F.array_equal(child_src, subg.layer_nid(0))
src1 = subg.map_to_parent_nid(child_src) src1 = subg.map_to_parent_nid(child_src)
assert F.array_equal(src1, src) assert F.array_equal(src1, src)
def start_sampler(): def start_sampler():
g = generate_rand_graph(100) g = generate_rand_graph(100)
...@@ -35,12 +35,12 @@ def start_sampler(): ...@@ -35,12 +35,12 @@ def start_sampler():
for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler( for i, subg in enumerate(dgl.contrib.sampling.NeighborSampler(
g, 1, 100, neighbor_type='in', num_workers=4)): g, 1, 100, neighbor_type='in', num_workers=4)):
sender.send(subg, 0) sender.send(subg, 0)
break sender.signal(0)
if __name__ == '__main__': if __name__ == '__main__':
pid = os.fork() pid = os.fork()
if pid == 0: if pid == 0:
start_trainer() start_trainer()
else: else:
time.sleep(1) time.sleep(2) # wait trainer start
start_sampler() start_sampler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment