Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5cf48fc6
Unverified
Commit
5cf48fc6
authored
Sep 28, 2021
by
Jingcheng Yu
Committed by
GitHub
Sep 27, 2021
Browse files
[Feature] Implement one thread multiple socket (#3200)
Co-authored-by:
JingchengYu94
<
jingchengyu94@gmail.com
>
parent
179d6aab
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
592 additions
and
132 deletions
+592
-132
CMakeLists.txt
CMakeLists.txt
+9
-0
python/dgl/distributed/rpc.py
python/dgl/distributed/rpc.py
+5
-2
src/graph/network.cc
src/graph/network.cc
+2
-2
src/rpc/network/communicator.h
src/rpc/network/communicator.h
+16
-2
src/rpc/network/msg_queue.cc
src/rpc/network/msg_queue.cc
+1
-0
src/rpc/network/msg_queue.h
src/rpc/network/msg_queue.h
+4
-0
src/rpc/network/socket_communicator.cc
src/rpc/network/socket_communicator.cc
+191
-95
src/rpc/network/socket_communicator.h
src/rpc/network/socket_communicator.h
+46
-19
src/rpc/network/socket_pool.cc
src/rpc/network/socket_pool.cc
+110
-0
src/rpc/network/socket_pool.h
src/rpc/network/socket_pool.h
+97
-0
src/rpc/network/tcp_socket.cc
src/rpc/network/tcp_socket.cc
+3
-3
src/rpc/network/tcp_socket.h
src/rpc/network/tcp_socket.h
+3
-3
src/rpc/rpc.cc
src/rpc/rpc.cc
+6
-2
src/runtime/semaphore_wrapper.cc
src/runtime/semaphore_wrapper.cc
+47
-0
src/runtime/semaphore_wrapper.h
src/runtime/semaphore_wrapper.h
+47
-0
tests/cpp/socket_communicator_test.cc
tests/cpp/socket_communicator_test.cc
+5
-4
No files found.
CMakeLists.txt
View file @
5cf48fc6
...
...
@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option
(
LIBCXX_ENABLE_PARALLEL_ALGORITHMS
"Enable the parallel algorithms library. This requires the PSTL to be available."
OFF
)
dgl_option
(
USE_S3
"Build with S3 support"
OFF
)
dgl_option
(
USE_HDFS
"Build with HDFS support"
OFF
)
# Set env HADOOP_HDFS_HOME if needed
dgl_option
(
USE_EPOLL
"Build with epoll for socket communicator"
OFF
)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if
(
NOT MSVC
)
...
...
@@ -120,6 +121,14 @@ if(USE_AVX)
endif
(
USE_LIBXSMM
)
endif
(
USE_AVX
)
if
(
USE_EPOLL
)
check_include_file
(
"sys/epoll.h"
USE_EPOLL
)
if
(
USE_EPOLL
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-DUSE_EPOLL"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_EPOLL"
)
endif
()
endif
()
# Build with fp16 to support mixed precision training.
if
(
USE_FP16
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-DUSE_FP16"
)
...
...
python/dgl/distributed/rpc.py
View file @
5cf48fc6
"""RPC components. They are typically functions or utilities used by both
server and clients."""
import
os
import
abc
import
pickle
import
random
...
...
@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket'.
"""
_CAPI_DGLRPCCreateSender
(
int
(
max_queue_size
),
net_type
)
max_thread_count
=
int
(
os
.
getenv
(
'DGL_SOCKET_MAX_THREAD_COUNT'
,
'0'
))
_CAPI_DGLRPCCreateSender
(
int
(
max_queue_size
),
net_type
,
max_thread_count
)
def
create_receiver
(
max_queue_size
,
net_type
):
"""Create rpc receiver of this process.
...
...
@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket'.
"""
_CAPI_DGLRPCCreateReceiver
(
int
(
max_queue_size
),
net_type
)
max_thread_count
=
int
(
os
.
getenv
(
'DGL_SOCKET_MAX_THREAD_COUNT'
,
'0'
))
_CAPI_DGLRPCCreateReceiver
(
int
(
max_queue_size
),
net_type
,
max_thread_count
)
def
finalize_sender
():
"""Finalize rpc sender of this process.
...
...
src/graph/network.cc
View file @
5cf48fc6
...
...
@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
int64_t
msg_queue_size
=
args
[
1
];
network
::
Sender
*
sender
=
nullptr
;
if
(
type
==
"socket"
)
{
sender
=
new
network
::
SocketSender
(
msg_queue_size
);
sender
=
new
network
::
SocketSender
(
msg_queue_size
,
0
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
}
...
...
@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
int64_t
msg_queue_size
=
args
[
1
];
network
::
Receiver
*
receiver
=
nullptr
;
if
(
type
==
"socket"
)
{
receiver
=
new
network
::
SocketReceiver
(
msg_queue_size
);
receiver
=
new
network
::
SocketReceiver
(
msg_queue_size
,
0
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
}
...
...
src/rpc/network/communicator.h
View file @
5cf48fc6
...
...
@@ -28,11 +28,14 @@ class Sender {
/*!
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
*/
explicit
Sender
(
int64_t
queue_size
=
0
)
{
explicit
Sender
(
int64_t
queue_size
=
0
,
int
max_thread_count
=
0
)
{
CHECK_GE
(
queue_size
,
0
);
CHECK_GE
(
max_thread_count
,
0
);
queue_size_
=
queue_size
;
max_thread_count_
=
max_thread_count
;
}
virtual
~
Sender
()
{}
...
...
@@ -86,6 +89,10 @@ class Sender {
* \brief Size of message queue
*/
int64_t
queue_size_
;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int
max_thread_count_
;
};
/*!
...
...
@@ -101,13 +108,16 @@ class Receiver {
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
*/
explicit
Receiver
(
int64_t
queue_size
=
0
)
{
explicit
Receiver
(
int64_t
queue_size
=
0
,
int
max_thread_count
=
0
)
{
if
(
queue_size
<
0
)
{
LOG
(
FATAL
)
<<
"queue_size cannot be a negative number."
;
}
CHECK_GE
(
max_thread_count
,
0
);
queue_size_
=
queue_size
;
max_thread_count_
=
max_thread_count
;
}
virtual
~
Receiver
()
{}
...
...
@@ -165,6 +175,10 @@ class Receiver {
* \brief Size of message queue
*/
int64_t
queue_size_
;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int
max_thread_count_
;
};
}
// namespace network
...
...
src/rpc/network/msg_queue.cc
View file @
5cf48fc6
...
...
@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
queue_
.
pop
();
msg
->
data
=
old_msg
.
data
;
msg
->
size
=
old_msg
.
size
;
msg
->
receiver_id
=
old_msg
.
receiver_id
;
msg
->
deallocator
=
old_msg
.
deallocator
;
free_size_
+=
old_msg
.
size
;
cond_not_full_
.
notify_one
();
...
...
src/rpc/network/msg_queue.h
View file @
5cf48fc6
...
...
@@ -56,6 +56,10 @@ struct Message {
* \brief message size in bytes
*/
int64_t
size
;
/*!
* \brief message receiver id
*/
int
receiver_id
=
-
1
;
/*!
* \brief user-defined deallocator, which can be nullptr
*/
...
...
src/rpc/network/socket_communicator.cc
View file @
5cf48fc6
...
...
@@ -12,6 +12,7 @@
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "socket_pool.h"
#ifdef _WIN32
#include <windows.h>
...
...
@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
address
.
ip
=
ip_and_port
[
0
];
address
.
port
=
std
::
stoi
(
ip_and_port
[
1
]);
receiver_addrs_
[
recv_id
]
=
address
;
msg_queue_
[
recv_id
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
}
bool
SocketSender
::
Connect
()
{
// Create N sockets for Receiver
int
receiver_count
=
static_cast
<
int
>
(
receiver_addrs_
.
size
());
if
(
max_thread_count_
==
0
||
max_thread_count_
>
receiver_count
)
{
max_thread_count_
=
receiver_count
;
}
sockets_
.
resize
(
max_thread_count_
);
for
(
const
auto
&
r
:
receiver_addrs_
)
{
int
ID
=
r
.
first
;
sockets_
[
ID
]
=
std
::
make_shared
<
TCPSocket
>
();
TCPSocket
*
client_socket
=
sockets_
[
ID
].
get
();
int
receiver_id
=
r
.
first
;
int
thread_id
=
receiver_id
%
max_thread_count_
;
sockets_
[
thread_id
][
receiver_id
]
=
std
::
make_shared
<
TCPSocket
>
();
TCPSocket
*
client_socket
=
sockets_
[
thread_id
][
receiver_id
].
get
();
bool
bo
=
false
;
int
try_count
=
0
;
const
char
*
ip
=
r
.
second
.
ip
.
c_str
();
...
...
@@ -83,12 +89,17 @@ bool SocketSender::Connect() {
if
(
bo
==
false
)
{
return
bo
;
}
}
for
(
int
thread_id
=
0
;
thread_id
<
max_thread_count_
;
++
thread_id
)
{
msg_queue_
.
push_back
(
std
::
make_shared
<
MessageQueue
>
(
queue_size_
));
// Create a new thread for this socket connection
threads_
[
ID
]
=
std
::
make_shared
<
std
::
thread
>
(
threads_
.
push_back
(
std
::
make_shared
<
std
::
thread
>
(
SendLoop
,
client_socket
,
msg_queue_
[
ID
].
get
(
));
sockets_
[
thread_id
]
,
msg_queue_
[
thread_id
]
));
}
return
true
;
}
...
...
@@ -96,53 +107,48 @@ STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL
(
msg
.
data
);
CHECK_GT
(
msg
.
size
,
0
);
CHECK_GE
(
recv_id
,
0
);
msg
.
receiver_id
=
recv_id
;
// Add data message to message queue
STATUS
code
=
msg_queue_
[
recv_id
]
->
Add
(
msg
);
STATUS
code
=
msg_queue_
[
recv_id
%
max_thread_count_
]
->
Add
(
msg
);
return
code
;
}
void
SocketSender
::
Finalize
()
{
// Send a signal to tell the msg_queue to finish its job
for
(
auto
&
mq
:
msg_queue_
)
{
for
(
int
i
=
0
;
i
<
max_thread_count_
;
++
i
)
{
// wait until queue is empty
while
(
mq
.
second
->
Empty
()
==
false
)
{
auto
&
mq
=
msg_queue_
[
i
];
while
(
mq
->
Empty
()
==
false
)
{
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep
(
1000
);
#endif // _WIN32
}
int
ID
=
mq
.
first
;
mq
.
second
->
SignalFinished
(
ID
);
// All queues have only one producer, which is main thread, so
// the producerID argument here should be zero.
mq
->
SignalFinished
(
0
);
}
// Block main thread until all socket-threads finish their jobs
for
(
auto
&
thread
:
threads_
)
{
thread
.
second
->
join
();
thread
->
join
();
}
// Clear all sockets
for
(
auto
&
socket
:
sockets_
)
{
for
(
auto
&
group_sockets_
:
sockets_
)
{
for
(
auto
&
socket
:
group_sockets_
)
{
socket
.
second
->
Close
();
}
}
}
void
SocketSender
::
SendLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
)
{
CHECK_NOTNULL
(
socket
);
CHECK_NOTNULL
(
queue
);
bool
exit
=
false
;
while
(
!
exit
)
{
Message
msg
;
STATUS
code
=
queue
->
Remove
(
&
msg
);
if
(
code
==
QUEUE_CLOSE
)
{
msg
.
size
=
0
;
// send an end-signal to receiver
exit
=
true
;
}
void
SendCore
(
Message
msg
,
TCPSocket
*
socket
)
{
// First send the size
// If exit == true, we will send zero size to reciever
int64_t
sent_bytes
=
0
;
while
(
static_cast
<
size_t
>
(
sent_bytes
)
<
sizeof
(
int64_t
))
{
int64_t
max_len
=
sizeof
(
int64_t
)
-
sent_bytes
;
int64_t
tmp
=
socket
->
Send
(
reinterpret_cast
<
char
*>
(
&
msg
.
size
)
+
sent_bytes
,
reinterpret_cast
<
char
*>
(
&
msg
.
size
)
+
sent_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
sent_bytes
+=
tmp
;
...
...
@@ -159,6 +165,22 @@ void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
if
(
msg
.
deallocator
!=
nullptr
)
{
msg
.
deallocator
(
&
msg
);
}
}
void
SocketSender
::
SendLoop
(
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
shared_ptr
<
MessageQueue
>
queue
)
{
for
(;;)
{
Message
msg
;
STATUS
code
=
queue
->
Remove
(
&
msg
);
if
(
code
==
QUEUE_CLOSE
)
{
msg
.
size
=
0
;
// send an end-signal to receiver
for
(
auto
&
socket
:
sockets
)
{
SendCore
(
msg
,
socket
.
second
.
get
());
}
break
;
}
SendCore
(
msg
,
sockets
[
msg
.
receiver_id
].
get
());
}
}
...
...
@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
int
port
=
stoi
(
ip_and_port
[
1
]);
// Initialize message queue for each connection
num_sender_
=
num_sender
;
for
(
int
i
=
0
;
i
<
num_sender_
;
++
i
)
{
msg_queue_
[
i
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
#ifdef USE_EPOLL
if
(
max_thread_count_
==
0
||
max_thread_count_
>
num_sender_
)
{
max_thread_count_
=
num_sender_
;
}
mq_iter_
=
msg_queue_
.
begin
();
#else
max_thread_count_
=
num_sender_
;
#endif
// Initialize socket and socket-thread
server_socket_
=
new
TCPSocket
();
// Bind socket
if
(
server_socket_
->
Bind
(
ip
.
c_str
(),
port
)
==
false
)
{
LOG
(
FATAL
)
<<
"Cannot bind to "
<<
ip
<<
":"
<<
port
;
}
// Listen
if
(
server_socket_
->
Listen
(
kMaxConnection
)
==
false
)
{
LOG
(
FATAL
)
<<
"Cannot listen on "
<<
ip
<<
":"
<<
port
;
...
...
@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
// Accept all sender sockets
std
::
string
accept_ip
;
int
accept_port
;
sockets_
.
resize
(
max_thread_count_
);
for
(
int
i
=
0
;
i
<
num_sender_
;
++
i
)
{
sockets_
[
i
]
=
std
::
make_shared
<
TCPSocket
>
();
if
(
server_socket_
->
Accept
(
sockets_
[
i
].
get
(),
&
accept_ip
,
&
accept_port
)
==
false
)
{
int
thread_id
=
i
%
max_thread_count_
;
auto
socket
=
std
::
make_shared
<
TCPSocket
>
();
sockets_
[
thread_id
][
i
]
=
socket
;
msg_queue_
[
i
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
if
(
server_socket_
->
Accept
(
socket
.
get
(),
&
accept_ip
,
&
accept_port
)
==
false
)
{
LOG
(
WARNING
)
<<
"Error on accept socket."
;
return
false
;
}
}
mq_iter_
=
msg_queue_
.
begin
();
for
(
int
thread_id
=
0
;
thread_id
<
max_thread_count_
;
++
thread_id
)
{
// create new thread for each socket
threads_
[
i
]
=
std
::
make_shared
<
std
::
thread
>
(
threads_
.
push_back
(
std
::
make_shared
<
std
::
thread
>
(
RecvLoop
,
sockets_
[
i
].
get
(),
msg_queue_
[
i
].
get
());
sockets_
[
thread_id
],
msg_queue_
,
&
queue_sem_
));
}
return
true
;
}
STATUS
SocketReceiver
::
Recv
(
Message
*
msg
,
int
*
send_id
)
{
// loop until get a message
// queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
queue_sem_
.
Wait
();
for
(;;)
{
for
(;
mq_iter_
!=
msg_queue_
.
end
();
++
mq_iter_
)
{
// We use non-block remove here
STATUS
code
=
mq_iter_
->
second
->
Remove
(
msg
,
false
);
if
(
code
==
QUEUE_EMPTY
)
{
continue
;
// jump to the next queue
...
...
@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
STATUS
SocketReceiver
::
RecvFrom
(
Message
*
msg
,
int
send_id
)
{
// Get message from specified message queue
queue_sem_
.
Wait
();
STATUS
code
=
msg_queue_
[
send_id
]
->
Remove
(
msg
);
return
code
;
}
...
...
@@ -255,47 +294,93 @@ void SocketReceiver::Finalize() {
usleep
(
1000
);
#endif // _WIN32
}
int
ID
=
mq
.
first
;
mq
.
second
->
SignalFinished
(
ID
);
mq
.
second
->
SignalFinished
(
mq
.
first
);
}
// Block main thread until all socket-threads finish their jobs
for
(
auto
&
thread
:
threads_
)
{
thread
.
second
->
join
();
thread
->
join
();
}
// Clear all sockets
for
(
auto
&
socket
:
sockets_
)
{
for
(
auto
&
group_sockets
:
sockets_
)
{
for
(
auto
&
socket
:
group_sockets
)
{
socket
.
second
->
Close
();
}
}
server_socket_
->
Close
();
delete
server_socket_
;
}
void
SocketReceiver
::
RecvLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
)
{
CHECK_NOTNULL
(
socket
);
CHECK_NOTNULL
(
queue
);
for
(;;)
{
// If main thread had finished its job
if
(
queue
->
EmptyAndNoMoreAdd
())
{
return
;
// exit loop thread
}
// First recv the size
int64_t
RecvDataSize
(
TCPSocket
*
socket
)
{
int64_t
received_bytes
=
0
;
int64_t
data_size
=
0
;
while
(
static_cast
<
size_t
>
(
received_bytes
)
<
sizeof
(
int64_t
))
{
int64_t
max_len
=
sizeof
(
int64_t
)
-
received_bytes
;
int64_t
tmp
=
socket
->
Receive
(
reinterpret_cast
<
char
*>
(
&
data_size
)
+
received_bytes
,
reinterpret_cast
<
char
*>
(
&
data_size
)
+
received_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
if
(
tmp
==
-
1
)
{
if
(
received_bytes
>
0
)
{
// We want to finish reading full data_size
continue
;
}
return
-
1
;
}
received_bytes
+=
tmp
;
}
if
(
data_size
<
0
)
{
LOG
(
FATAL
)
<<
"Recv data error (data_size: "
<<
data_size
<<
")"
;
}
else
if
(
data_size
==
0
)
{
// This is an end-signal sent by client
return
data_size
;
}
void
RecvData
(
TCPSocket
*
socket
,
char
*
buffer
,
const
int64_t
&
data_size
,
int64_t
*
received_bytes
)
{
while
(
*
received_bytes
<
data_size
)
{
int64_t
max_len
=
data_size
-
*
received_bytes
;
int64_t
tmp
=
socket
->
Receive
(
buffer
+
*
received_bytes
,
max_len
);
if
(
tmp
==
-
1
)
{
// Socket not ready, no more data to read
return
;
}
else
{
char
*
buffer
=
nullptr
;
}
*
received_bytes
+=
tmp
;
}
}
void
SocketReceiver
::
RecvLoop
(
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
queues
,
runtime
::
Semaphore
*
queue_sem
)
{
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
RecvContext
>>
recv_contexts
;
SocketPool
socket_pool
;
for
(
auto
&
socket
:
sockets
)
{
auto
&
sender_id
=
socket
.
first
;
socket_pool
.
AddSocket
(
socket
.
second
,
sender_id
);
recv_contexts
[
sender_id
]
=
std
::
unique_ptr
<
RecvContext
>
(
new
RecvContext
());
}
// Main loop to receive messages
for
(;;)
{
int
sender_id
;
// Get active socket using epoll
std
::
shared_ptr
<
TCPSocket
>
socket
=
socket_pool
.
GetActiveSocket
(
&
sender_id
);
if
(
queues
[
sender_id
]
->
EmptyAndNoMoreAdd
())
{
// This sender has already stopped
if
(
socket_pool
.
RemoveSocket
(
socket
)
==
0
)
{
return
;
}
continue
;
}
// Nonblocking socket might be interrupted at any point. So we need to
// store the partially received data
std
::
unique_ptr
<
RecvContext
>
&
ctx
=
recv_contexts
[
sender_id
];
int64_t
&
data_size
=
ctx
->
data_size
;
int64_t
&
received_bytes
=
ctx
->
received_bytes
;
char
*&
buffer
=
ctx
->
buffer
;
if
(
data_size
==
-
1
)
{
// This is a new message, so receive the data size first
data_size
=
RecvDataSize
(
socket
.
get
());
if
(
data_size
>
0
)
{
try
{
buffer
=
new
char
[
data_size
];
}
catch
(
const
std
::
bad_alloc
&
)
{
...
...
@@ -303,17 +388,28 @@ void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
<<
"(message size: "
<<
data_size
<<
")"
;
}
received_bytes
=
0
;
while
(
received_bytes
<
data_size
)
{
int64_t
max_len
=
data_size
-
received_bytes
;
int64_t
tmp
=
socket
->
Receive
(
buffer
+
received_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
received_bytes
+=
tmp
;
}
else
if
(
data_size
==
0
)
{
// Received stop signal
if
(
socket_pool
.
RemoveSocket
(
socket
)
==
0
)
{
return
;
}
}
}
RecvData
(
socket
.
get
(),
buffer
,
data_size
,
&
received_bytes
);
if
(
received_bytes
>=
data_size
)
{
// Full data received, create Message and push to queue
Message
msg
;
msg
.
data
=
buffer
;
msg
.
size
=
data_size
;
msg
.
deallocator
=
DefaultMessageDeleter
;
queue
->
Add
(
msg
);
queues
[
sender_id
]
->
Add
(
msg
);
// Reset recv context
data_size
=
-
1
;
// Signal queue semaphore
queue_sem
->
Post
();
}
}
}
...
...
src/rpc/network/socket_communicator.h
View file @
5cf48fc6
...
...
@@ -12,6 +12,7 @@
#include <unordered_map>
#include <memory>
#include "../../runtime/semaphore_wrapper.h"
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
...
...
@@ -42,8 +43,10 @@ class SocketSender : public Sender {
/*!
* \brief Sender constructor
* \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit
*/
explicit
SocketSender
(
int64_t
queue_size
)
:
Sender
(
queue_size
)
{}
SocketSender
(
int64_t
queue_size
,
int
max_thread_count
)
:
Sender
(
queue_size
,
max_thread_count
)
{}
/*!
* \brief Add receiver's address and ID to the sender's namebook
...
...
@@ -93,7 +96,8 @@ class SocketSender : public Sender {
/*!
* \brief socket for each connection of receiver
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets_
;
std
::
vector
<
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
TCPSocket
>>>
sockets_
;
/*!
* \brief receivers' address
...
...
@@ -101,24 +105,27 @@ class SocketSender : public Sender {
std
::
unordered_map
<
int
/* receiver ID */
,
IPAddr
>
receiver_addrs_
;
/*!
* \brief message queue for each
socket connection
* \brief message queue for each
thread
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
vector
<
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
/*!
* \brief Independent thread
for each socket connection
* \brief Independent thread
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
/*!
* \brief Send-loop for each
socket in per-
thread
* \param socket TCPSocket for current
connection
* \param queue message_queue for current
connection
* \brief Send-loop for each thread
* \param socket
s
TCPSocket
s
for current
thread
* \param queue message_queue for current
thread
*
* Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static
void
SendLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
);
static
void
SendLoop
(
std
::
unordered_map
<
int
/* Receiver (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
shared_ptr
<
MessageQueue
>
queue
);
};
/*!
...
...
@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver {
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
*/
explicit
SocketReceiver
(
int64_t
queue_size
)
:
Receiver
(
queue_size
)
{}
SocketReceiver
(
int64_t
queue_size
,
int
max_thread_count
)
:
Receiver
(
queue_size
,
max_thread_count
)
{}
/*!
* \brief Wait for all the Senders to connect
...
...
@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver {
inline
std
::
string
Type
()
const
{
return
std
::
string
(
"socket"
);
}
private:
struct
RecvContext
{
int64_t
data_size
=
-
1
;
int64_t
received_bytes
=
0
;
char
*
buffer
=
nullptr
;
};
/*!
* \brief number of sender
*/
...
...
@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver {
/*!
* \brief socket for each client connections
*/
std
::
unordered_map
<
int
/* Sender (virutal) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets_
;
std
::
vector
<
std
::
unordered_map
<
int
/* Sender (virutal) ID */
,
std
::
shared_ptr
<
TCPSocket
>>>
sockets_
;
/*!
* \brief Message queue for each socket connection
*/
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
MessageQueue
>>::
iterator
mq_iter_
;
/*!
* \brief Independent thead
for each socket connection
* \brief Independent thead
*/
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
/*!
* \brief Recv-loop for each socket in per-thread
* \param socket client socket
* \param queue message queue
* \brief queue_sem_ semphore to indicate number of messages in multiple
* message queues to prevent busy wait of Recv
*/
runtime
::
Semaphore
queue_sem_
;
/*!
* \brief Recv-loop for each thread
* \param sockets client sockets of current thread
* \param queue message queues of current thread
*
* Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static
void
RecvLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
);
static
void
RecvLoop
(
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
queues
,
runtime
::
Semaphore
*
queue_sem
);
};
}
// namespace network
...
...
src/rpc/network/socket_pool.cc
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.cc
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#include "socket_pool.h"
#include <dmlc/logging.h>
#include "tcp_socket.h"
#ifdef USE_EPOLL
#include <sys/epoll.h>
#endif
namespace
dgl
{
namespace
network
{
SocketPool
::
SocketPool
()
{
#ifdef USE_EPOLL
epfd_
=
epoll_create1
(
0
);
if
(
epfd_
<
0
)
{
LOG
(
FATAL
)
<<
"SocketPool cannot create epfd"
;
}
#endif
}
void
SocketPool
::
AddSocket
(
std
::
shared_ptr
<
TCPSocket
>
socket
,
int
socket_id
,
int
events
)
{
int
fd
=
socket
->
Socket
();
tcp_sockets_
[
fd
]
=
socket
;
socket_ids_
[
fd
]
=
socket_id
;
#ifdef USE_EPOLL
epoll_event
e
;
e
.
data
.
fd
=
fd
;
if
(
events
==
READ
)
{
e
.
events
=
EPOLLIN
;
}
else
if
(
events
==
WRITE
)
{
e
.
events
=
EPOLLOUT
;
}
else
if
(
events
==
READ
+
WRITE
)
{
e
.
events
=
EPOLLIN
|
EPOLLOUT
;
}
if
(
epoll_ctl
(
epfd_
,
EPOLL_CTL_ADD
,
fd
,
&
e
)
<
0
)
{
LOG
(
FATAL
)
<<
"SocketPool cannot add socket"
;
}
socket
->
SetNonBlocking
(
true
);
#else
if
(
tcp_sockets_
.
size
()
>
1
)
{
LOG
(
FATAL
)
<<
"SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building"
;
}
#endif
}
size_t
SocketPool
::
RemoveSocket
(
std
::
shared_ptr
<
TCPSocket
>
socket
)
{
int
fd
=
socket
->
Socket
();
socket_ids_
.
erase
(
fd
);
tcp_sockets_
.
erase
(
fd
);
#ifdef USE_EPOLL
epoll_ctl
(
epfd_
,
EPOLL_CTL_DEL
,
fd
,
NULL
);
#endif
return
socket_ids_
.
size
();
}
SocketPool
::~
SocketPool
()
{
#ifdef USE_EPOLL
for
(
auto
&
id
:
socket_ids_
)
{
int
fd
=
id
.
first
;
epoll_ctl
(
epfd_
,
EPOLL_CTL_DEL
,
fd
,
NULL
);
}
#endif
}
std
::
shared_ptr
<
TCPSocket
>
SocketPool
::
GetActiveSocket
(
int
*
socket_id
)
{
if
(
socket_ids_
.
empty
())
{
return
nullptr
;
}
for
(;;)
{
while
(
pending_fds_
.
empty
())
{
Wait
();
}
int
fd
=
pending_fds_
.
front
();
pending_fds_
.
pop
();
// Check if this socket is not removed
if
(
socket_ids_
.
find
(
fd
)
!=
socket_ids_
.
end
())
{
*
socket_id
=
socket_ids_
[
fd
];
return
tcp_sockets_
[
fd
];
}
}
return
nullptr
;
}
void
SocketPool
::
Wait
()
{
#ifdef USE_EPOLL
static
const
int
MAX_EVENTS
=
10
;
epoll_event
events
[
MAX_EVENTS
];
int
nfd
=
epoll_wait
(
epfd_
,
events
,
MAX_EVENTS
,
-
1
/*Timeout*/
);
for
(
int
i
=
0
;
i
<
nfd
;
++
i
)
{
pending_fds_
.
push
(
events
[
i
].
data
.
fd
);
}
#else
pending_fds_
.
push
(
tcp_sockets_
.
begin
()
->
second
->
Socket
());
#endif
}
}
// namespace network
}
// namespace dgl
src/rpc/network/socket_pool.h
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.h
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory>
namespace
dgl
{
namespace
network
{
class
TCPSocket
;
/*!
* \brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification
* mechanism in Linux operating system.
*/
class
SocketPool
{
public:
/*!
* \brief socket mode read/receive
*/
static
const
int
READ
=
1
;
/*!
* \brief socket mode write/send
*/
static
const
int
WRITE
=
2
;
/*!
* \brief SocketPool constructor
*/
SocketPool
();
/*!
* \brief Add a socket to SocketPool
* \param socket tcp socket to add
* \param socket_id receiver/sender id of the socket
* \param events READ, WRITE or READ + WRITE
*/
void
AddSocket
(
std
::
shared_ptr
<
TCPSocket
>
socket
,
int
socket_id
,
int
events
=
READ
);
/*!
* \brief Remove socket from SocketPool
* \param socket tcp socket to remove
* \return number of remaing sockets in the pool
*/
size_t
RemoveSocket
(
std
::
shared_ptr
<
TCPSocket
>
socket
);
/*!
* \brief SocketPool destructor
*/
~
SocketPool
();
/*!
* \brief Get current active socket. This is a blocking method
* \param socket_id output parameter of the socket_id of active socket
* \return active TCPSocket
*/
std
::
shared_ptr
<
TCPSocket
>
GetActiveSocket
(
int
*
socket_id
);
private:
/*!
* \brief Wait for event notification
*/
void
Wait
();
/*!
* \brief map from fd to TCPSocket
*/
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
TCPSocket
>>
tcp_sockets_
;
/*!
* \brief map from fd to socket_id
*/
std
::
unordered_map
<
int
,
int
>
socket_ids_
;
/*!
* \brief fd for epoll base
*/
int
epfd_
;
/*!
* \brief queue for current active fds
*/
std
::
queue
<
int
>
pending_fds_
;
};
}
// namespace network
}
// namespace dgl
#endif // DGL_RPC_NETWORK_SOCKET_POOL_H_
src/rpc/network/tcp_socket.cc
View file @
5cf48fc6
...
...
@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
}
#ifdef _WIN32
bool
TCPSocket
::
SetBlocking
(
bool
flag
)
{
bool
TCPSocket
::
Set
Non
Blocking
(
bool
flag
)
{
int
result
;
u_long
argp
=
flag
?
1
:
0
;
...
...
@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) {
return
true
;
}
#else // !_WIN32
bool
TCPSocket
::
SetBlocking
(
bool
flag
)
{
bool
TCPSocket
::
Set
Non
Blocking
(
bool
flag
)
{
int
opts
;
if
((
opts
=
fcntl
(
socket_
,
F_GETFL
))
<
0
)
{
...
...
@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
do
{
// retry if EINTR failure appears
number_recv
=
recv
(
socket_
,
buffer
,
size_buffer
,
0
);
}
while
(
number_recv
==
-
1
&&
errno
==
EINTR
);
if
(
number_recv
==
-
1
)
{
if
(
number_recv
==
-
1
&&
errno
!=
EAGAIN
&&
errno
!=
EWOULDBLOCK
)
{
LOG
(
ERROR
)
<<
"recv error: "
<<
strerror
(
errno
);
}
...
...
src/rpc/network/tcp_socket.h
View file @
5cf48fc6
...
...
@@ -70,12 +70,12 @@ class TCPSocket {
int
*
port_client
);
/*!
* \brief SetBlocking() is needed refering to this example of epoll:
* \brief Set
Non
Blocking() is needed refering to this example of epoll:
* http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
* \param flag
flag
for blocking
* \param flag
true for nonblocking, false
for blocking
* \return true for success and false for failure
*/
bool
SetBlocking
(
bool
flag
);
bool
Set
Non
Blocking
(
bool
flag
);
/*!
* \brief Set timeout for socket
...
...
src/rpc/rpc.cc
View file @
5cf48fc6
...
...
@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
int64_t
msg_queue_size
=
args
[
0
];
std
::
string
type
=
args
[
1
];
int
max_thread_count
=
args
[
2
];
if
(
type
.
compare
(
"socket"
)
==
0
)
{
RPCContext
::
ThreadLocal
()
->
sender
=
std
::
make_shared
<
network
::
SocketSender
>
(
msg_queue_size
);
RPCContext
::
ThreadLocal
()
->
sender
=
std
::
make_shared
<
network
::
SocketSender
>
(
msg_queue_size
,
max_thread_count
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc receiver: "
<<
type
;
}
...
...
@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
int64_t
msg_queue_size
=
args
[
0
];
std
::
string
type
=
args
[
1
];
int
max_thread_count
=
args
[
2
];
if
(
type
.
compare
(
"socket"
)
==
0
)
{
RPCContext
::
ThreadLocal
()
->
receiver
=
std
::
make_shared
<
network
::
SocketReceiver
>
(
msg_queue_size
);
RPCContext
::
ThreadLocal
()
->
receiver
=
std
::
make_shared
<
network
::
SocketReceiver
>
(
msg_queue_size
,
max_thread_count
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc sender: "
<<
type
;
}
...
...
src/runtime/semaphore_wrapper.cc
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.cc
* \brief A simple corss platform semaphore wrapper
*/
#include "semaphore_wrapper.h"
#include <dmlc/logging.h>
namespace
dgl
{
namespace
runtime
{
#ifdef _WIN32
Semaphore
::
Semaphore
()
{
sem_
=
CreateSemaphore
(
nullptr
,
0
,
INT_MAX
,
nullptr
);
if
(
!
sem_
)
{
LOG
(
FATAL
)
<<
"Cannot create semaphore"
;
}
}
void
Semaphore
::
Wait
()
{
WaitForSingleObject
(
sem_
,
INFINITE
);
}
void
Semaphore
::
Post
()
{
ReleaseSemaphore
(
sem_
,
1
,
nullptr
);
}
#else
Semaphore
::
Semaphore
()
{
sem_init
(
&
sem_
,
0
,
0
);
}
void
Semaphore
::
Wait
()
{
sem_wait
(
&
sem_
);
}
void
Semaphore
::
Post
()
{
sem_post
(
&
sem_
);
}
#endif
}
// namespace runtime
}
// namespace dgl
src/runtime/semaphore_wrapper.h
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.h
* \brief A simple corss platform semaphore wrapper
*/
#ifndef DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#define DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#ifdef _WIN32
#include <windows.h>
#else
#include <semaphore.h>
#endif
namespace
dgl
{
namespace
runtime
{
/*!
* \brief A simple crossplatform Semaphore wrapper
*/
class
Semaphore
{
public:
/*!
* \brief Semaphore constructor
*/
Semaphore
();
/*!
* \brief blocking wait, decrease semaphore by 1
*/
void
Wait
();
/*!
* \brief increase semaphore by 1
*/
void
Post
();
private:
#ifdef _WIN32
HANDLE
sem_
;
#else
sem_t
sem_
;
#endif
};
}
// namespace runtime
}
// namespace dgl
#endif // DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
tests/cpp/socket_communicator_test.cc
View file @
5cf48fc6
...
...
@@ -25,6 +25,7 @@ using dgl::network::Message;
using
dgl
::
network
::
DefaultMessageDeleter
;
const
int64_t
kQueueSize
=
500
*
1024
;
const
int
kThreadNum
=
2
;
#ifndef WIN32
...
...
@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
}
void
start_client
()
{
SocketSender
sender
(
kQueueSize
);
SocketSender
sender
(
kQueueSize
,
kThreadNum
);
for
(
int
i
=
0
;
i
<
kNumReceiver
;
++
i
)
{
sender
.
AddReceiver
(
ip_addr
[
i
],
i
);
}
...
...
@@ -89,7 +90,7 @@ void start_client() {
void
start_server
(
int
id
)
{
sleep
(
5
);
SocketReceiver
receiver
(
kQueueSize
);
SocketReceiver
receiver
(
kQueueSize
,
kThreadNum
);
receiver
.
Wait
(
ip_addr
[
id
],
kNumSender
);
for
(
int
i
=
0
;
i
<
kNumMessage
;
++
i
)
{
for
(
int
n
=
0
;
n
<
kNumSender
;
++
n
)
{
...
...
@@ -168,7 +169,7 @@ static void start_client() {
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
istreambuf_iterator
<
char
>
());
t
.
close
();
SocketSender
sender
(
kQueueSize
);
SocketSender
sender
(
kQueueSize
,
kThreadNum
);
sender
.
AddReceiver
(
ip_addr
.
c_str
(),
0
);
sender
.
Connect
();
char
*
str_data
=
new
char
[
9
];
...
...
@@ -185,7 +186,7 @@ static bool start_server() {
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
istreambuf_iterator
<
char
>
());
t
.
close
();
SocketReceiver
receiver
(
kQueueSize
);
SocketReceiver
receiver
(
kQueueSize
,
kThreadNum
);
receiver
.
Wait
(
ip_addr
.
c_str
(),
1
);
Message
msg
;
EXPECT_EQ
(
receiver
.
RecvFrom
(
&
msg
,
0
),
REMOVE_SUCCESS
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment