Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
5cf48fc6
Unverified
Commit
5cf48fc6
authored
Sep 28, 2021
by
Jingcheng Yu
Committed by
GitHub
Sep 27, 2021
Browse files
[Feature] Implement one thread multiple socket (#3200)
Co-authored-by:
JingchengYu94
<
jingchengyu94@gmail.com
>
parent
179d6aab
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
592 additions
and
132 deletions
+592
-132
CMakeLists.txt
CMakeLists.txt
+9
-0
python/dgl/distributed/rpc.py
python/dgl/distributed/rpc.py
+5
-2
src/graph/network.cc
src/graph/network.cc
+2
-2
src/rpc/network/communicator.h
src/rpc/network/communicator.h
+16
-2
src/rpc/network/msg_queue.cc
src/rpc/network/msg_queue.cc
+1
-0
src/rpc/network/msg_queue.h
src/rpc/network/msg_queue.h
+4
-0
src/rpc/network/socket_communicator.cc
src/rpc/network/socket_communicator.cc
+191
-95
src/rpc/network/socket_communicator.h
src/rpc/network/socket_communicator.h
+46
-19
src/rpc/network/socket_pool.cc
src/rpc/network/socket_pool.cc
+110
-0
src/rpc/network/socket_pool.h
src/rpc/network/socket_pool.h
+97
-0
src/rpc/network/tcp_socket.cc
src/rpc/network/tcp_socket.cc
+3
-3
src/rpc/network/tcp_socket.h
src/rpc/network/tcp_socket.h
+3
-3
src/rpc/rpc.cc
src/rpc/rpc.cc
+6
-2
src/runtime/semaphore_wrapper.cc
src/runtime/semaphore_wrapper.cc
+47
-0
src/runtime/semaphore_wrapper.h
src/runtime/semaphore_wrapper.h
+47
-0
tests/cpp/socket_communicator_test.cc
tests/cpp/socket_communicator_test.cc
+5
-4
No files found.
CMakeLists.txt
View file @
5cf48fc6
...
@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
...
@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option
(
LIBCXX_ENABLE_PARALLEL_ALGORITHMS
"Enable the parallel algorithms library. This requires the PSTL to be available."
OFF
)
dgl_option
(
LIBCXX_ENABLE_PARALLEL_ALGORITHMS
"Enable the parallel algorithms library. This requires the PSTL to be available."
OFF
)
dgl_option
(
USE_S3
"Build with S3 support"
OFF
)
dgl_option
(
USE_S3
"Build with S3 support"
OFF
)
dgl_option
(
USE_HDFS
"Build with HDFS support"
OFF
)
# Set env HADOOP_HDFS_HOME if needed
dgl_option
(
USE_HDFS
"Build with HDFS support"
OFF
)
# Set env HADOOP_HDFS_HOME if needed
dgl_option
(
USE_EPOLL
"Build with epoll for socket communicator"
OFF
)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if
(
NOT MSVC
)
if
(
NOT MSVC
)
...
@@ -120,6 +121,14 @@ if(USE_AVX)
...
@@ -120,6 +121,14 @@ if(USE_AVX)
endif
(
USE_LIBXSMM
)
endif
(
USE_LIBXSMM
)
endif
(
USE_AVX
)
endif
(
USE_AVX
)
if
(
USE_EPOLL
)
check_include_file
(
"sys/epoll.h"
USE_EPOLL
)
if
(
USE_EPOLL
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-DUSE_EPOLL"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_EPOLL"
)
endif
()
endif
()
# Build with fp16 to support mixed precision training.
# Build with fp16 to support mixed precision training.
if
(
USE_FP16
)
if
(
USE_FP16
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-DUSE_FP16"
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-DUSE_FP16"
)
...
...
python/dgl/distributed/rpc.py
View file @
5cf48fc6
"""RPC components. They are typically functions or utilities used by both
"""RPC components. They are typically functions or utilities used by both
server and clients."""
server and clients."""
import
os
import
abc
import
abc
import
pickle
import
pickle
import
random
import
random
...
@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type):
...
@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type):
net_type : str
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: 'socket'.
"""
"""
_CAPI_DGLRPCCreateSender
(
int
(
max_queue_size
),
net_type
)
max_thread_count
=
int
(
os
.
getenv
(
'DGL_SOCKET_MAX_THREAD_COUNT'
,
'0'
))
_CAPI_DGLRPCCreateSender
(
int
(
max_queue_size
),
net_type
,
max_thread_count
)
def
create_receiver
(
max_queue_size
,
net_type
):
def
create_receiver
(
max_queue_size
,
net_type
):
"""Create rpc receiver of this process.
"""Create rpc receiver of this process.
...
@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type):
...
@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type):
net_type : str
net_type : str
Networking type. Current options are: 'socket'.
Networking type. Current options are: 'socket'.
"""
"""
_CAPI_DGLRPCCreateReceiver
(
int
(
max_queue_size
),
net_type
)
max_thread_count
=
int
(
os
.
getenv
(
'DGL_SOCKET_MAX_THREAD_COUNT'
,
'0'
))
_CAPI_DGLRPCCreateReceiver
(
int
(
max_queue_size
),
net_type
,
max_thread_count
)
def
finalize_sender
():
def
finalize_sender
():
"""Finalize rpc sender of this process.
"""Finalize rpc sender of this process.
...
...
src/graph/network.cc
View file @
5cf48fc6
...
@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
...
@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
int64_t
msg_queue_size
=
args
[
1
];
int64_t
msg_queue_size
=
args
[
1
];
network
::
Sender
*
sender
=
nullptr
;
network
::
Sender
*
sender
=
nullptr
;
if
(
type
==
"socket"
)
{
if
(
type
==
"socket"
)
{
sender
=
new
network
::
SocketSender
(
msg_queue_size
);
sender
=
new
network
::
SocketSender
(
msg_queue_size
,
0
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
}
}
...
@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
...
@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
int64_t
msg_queue_size
=
args
[
1
];
int64_t
msg_queue_size
=
args
[
1
];
network
::
Receiver
*
receiver
=
nullptr
;
network
::
Receiver
*
receiver
=
nullptr
;
if
(
type
==
"socket"
)
{
if
(
type
==
"socket"
)
{
receiver
=
new
network
::
SocketReceiver
(
msg_queue_size
);
receiver
=
new
network
::
SocketReceiver
(
msg_queue_size
,
0
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
LOG
(
FATAL
)
<<
"Unknown communicator type: "
<<
type
;
}
}
...
...
src/rpc/network/communicator.h
View file @
5cf48fc6
...
@@ -28,11 +28,14 @@ class Sender {
...
@@ -28,11 +28,14 @@ class Sender {
/*!
/*!
* \brief Sender constructor
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
* Note that, the queue_size parameter is optional.
*/
*/
explicit
Sender
(
int64_t
queue_size
=
0
)
{
explicit
Sender
(
int64_t
queue_size
=
0
,
int
max_thread_count
=
0
)
{
CHECK_GE
(
queue_size
,
0
);
CHECK_GE
(
queue_size
,
0
);
CHECK_GE
(
max_thread_count
,
0
);
queue_size_
=
queue_size
;
queue_size_
=
queue_size
;
max_thread_count_
=
max_thread_count
;
}
}
virtual
~
Sender
()
{}
virtual
~
Sender
()
{}
...
@@ -86,6 +89,10 @@ class Sender {
...
@@ -86,6 +89,10 @@ class Sender {
* \brief Size of message queue
* \brief Size of message queue
*/
*/
int64_t
queue_size_
;
int64_t
queue_size_
;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int
max_thread_count_
;
};
};
/*!
/*!
...
@@ -101,13 +108,16 @@ class Receiver {
...
@@ -101,13 +108,16 @@ class Receiver {
/*!
/*!
* \brief Receiver constructor
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
* Note that, the queue_size parameter is optional.
*/
*/
explicit
Receiver
(
int64_t
queue_size
=
0
)
{
explicit
Receiver
(
int64_t
queue_size
=
0
,
int
max_thread_count
=
0
)
{
if
(
queue_size
<
0
)
{
if
(
queue_size
<
0
)
{
LOG
(
FATAL
)
<<
"queue_size cannot be a negative number."
;
LOG
(
FATAL
)
<<
"queue_size cannot be a negative number."
;
}
}
CHECK_GE
(
max_thread_count
,
0
);
queue_size_
=
queue_size
;
queue_size_
=
queue_size
;
max_thread_count_
=
max_thread_count
;
}
}
virtual
~
Receiver
()
{}
virtual
~
Receiver
()
{}
...
@@ -165,6 +175,10 @@ class Receiver {
...
@@ -165,6 +175,10 @@ class Receiver {
* \brief Size of message queue
* \brief Size of message queue
*/
*/
int64_t
queue_size_
;
int64_t
queue_size_
;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int
max_thread_count_
;
};
};
}
// namespace network
}
// namespace network
...
...
src/rpc/network/msg_queue.cc
View file @
5cf48fc6
...
@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
...
@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
queue_
.
pop
();
queue_
.
pop
();
msg
->
data
=
old_msg
.
data
;
msg
->
data
=
old_msg
.
data
;
msg
->
size
=
old_msg
.
size
;
msg
->
size
=
old_msg
.
size
;
msg
->
receiver_id
=
old_msg
.
receiver_id
;
msg
->
deallocator
=
old_msg
.
deallocator
;
msg
->
deallocator
=
old_msg
.
deallocator
;
free_size_
+=
old_msg
.
size
;
free_size_
+=
old_msg
.
size
;
cond_not_full_
.
notify_one
();
cond_not_full_
.
notify_one
();
...
...
src/rpc/network/msg_queue.h
View file @
5cf48fc6
...
@@ -56,6 +56,10 @@ struct Message {
...
@@ -56,6 +56,10 @@ struct Message {
* \brief message size in bytes
* \brief message size in bytes
*/
*/
int64_t
size
;
int64_t
size
;
/*!
* \brief message receiver id
*/
int
receiver_id
=
-
1
;
/*!
/*!
* \brief user-defined deallocator, which can be nullptr
* \brief user-defined deallocator, which can be nullptr
*/
*/
...
...
src/rpc/network/socket_communicator.cc
View file @
5cf48fc6
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "socket_communicator.h"
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "../../c_api_common.h"
#include "socket_pool.h"
#ifdef _WIN32
#ifdef _WIN32
#include <windows.h>
#include <windows.h>
...
@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
...
@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
address
.
ip
=
ip_and_port
[
0
];
address
.
ip
=
ip_and_port
[
0
];
address
.
port
=
std
::
stoi
(
ip_and_port
[
1
]);
address
.
port
=
std
::
stoi
(
ip_and_port
[
1
]);
receiver_addrs_
[
recv_id
]
=
address
;
receiver_addrs_
[
recv_id
]
=
address
;
msg_queue_
[
recv_id
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
}
}
bool
SocketSender
::
Connect
()
{
bool
SocketSender
::
Connect
()
{
// Create N sockets for Receiver
// Create N sockets for Receiver
int
receiver_count
=
static_cast
<
int
>
(
receiver_addrs_
.
size
());
if
(
max_thread_count_
==
0
||
max_thread_count_
>
receiver_count
)
{
max_thread_count_
=
receiver_count
;
}
sockets_
.
resize
(
max_thread_count_
);
for
(
const
auto
&
r
:
receiver_addrs_
)
{
for
(
const
auto
&
r
:
receiver_addrs_
)
{
int
ID
=
r
.
first
;
int
receiver_id
=
r
.
first
;
sockets_
[
ID
]
=
std
::
make_shared
<
TCPSocket
>
();
int
thread_id
=
receiver_id
%
max_thread_count_
;
TCPSocket
*
client_socket
=
sockets_
[
ID
].
get
();
sockets_
[
thread_id
][
receiver_id
]
=
std
::
make_shared
<
TCPSocket
>
();
TCPSocket
*
client_socket
=
sockets_
[
thread_id
][
receiver_id
].
get
();
bool
bo
=
false
;
bool
bo
=
false
;
int
try_count
=
0
;
int
try_count
=
0
;
const
char
*
ip
=
r
.
second
.
ip
.
c_str
();
const
char
*
ip
=
r
.
second
.
ip
.
c_str
();
...
@@ -83,12 +89,17 @@ bool SocketSender::Connect() {
...
@@ -83,12 +89,17 @@ bool SocketSender::Connect() {
if
(
bo
==
false
)
{
if
(
bo
==
false
)
{
return
bo
;
return
bo
;
}
}
}
for
(
int
thread_id
=
0
;
thread_id
<
max_thread_count_
;
++
thread_id
)
{
msg_queue_
.
push_back
(
std
::
make_shared
<
MessageQueue
>
(
queue_size_
));
// Create a new thread for this socket connection
// Create a new thread for this socket connection
threads_
[
ID
]
=
std
::
make_shared
<
std
::
thread
>
(
threads_
.
push_back
(
std
::
make_shared
<
std
::
thread
>
(
SendLoop
,
SendLoop
,
client_socket
,
sockets_
[
thread_id
]
,
msg_queue_
[
ID
].
get
(
));
msg_queue_
[
thread_id
]
));
}
}
return
true
;
return
true
;
}
}
...
@@ -96,69 +107,80 @@ STATUS SocketSender::Send(Message msg, int recv_id) {
...
@@ -96,69 +107,80 @@ STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL
(
msg
.
data
);
CHECK_NOTNULL
(
msg
.
data
);
CHECK_GT
(
msg
.
size
,
0
);
CHECK_GT
(
msg
.
size
,
0
);
CHECK_GE
(
recv_id
,
0
);
CHECK_GE
(
recv_id
,
0
);
msg
.
receiver_id
=
recv_id
;
// Add data message to message queue
// Add data message to message queue
STATUS
code
=
msg_queue_
[
recv_id
]
->
Add
(
msg
);
STATUS
code
=
msg_queue_
[
recv_id
%
max_thread_count_
]
->
Add
(
msg
);
return
code
;
return
code
;
}
}
void
SocketSender
::
Finalize
()
{
void
SocketSender
::
Finalize
()
{
// Send a signal to tell the msg_queue to finish its job
// Send a signal to tell the msg_queue to finish its job
for
(
auto
&
mq
:
msg_queue_
)
{
for
(
int
i
=
0
;
i
<
max_thread_count_
;
++
i
)
{
// wait until queue is empty
// wait until queue is empty
while
(
mq
.
second
->
Empty
()
==
false
)
{
auto
&
mq
=
msg_queue_
[
i
];
while
(
mq
->
Empty
()
==
false
)
{
#ifdef _WIN32
#ifdef _WIN32
// just loop
// just loop
#else // !_WIN32
#else // !_WIN32
usleep
(
1000
);
usleep
(
1000
);
#endif // _WIN32
#endif // _WIN32
}
}
int
ID
=
mq
.
first
;
// All queues have only one producer, which is main thread, so
mq
.
second
->
SignalFinished
(
ID
);
// the producerID argument here should be zero.
mq
->
SignalFinished
(
0
);
}
}
// Block main thread until all socket-threads finish their jobs
// Block main thread until all socket-threads finish their jobs
for
(
auto
&
thread
:
threads_
)
{
for
(
auto
&
thread
:
threads_
)
{
thread
.
second
->
join
();
thread
->
join
();
}
}
// Clear all sockets
// Clear all sockets
for
(
auto
&
socket
:
sockets_
)
{
for
(
auto
&
group_sockets_
:
sockets_
)
{
socket
.
second
->
Close
();
for
(
auto
&
socket
:
group_sockets_
)
{
socket
.
second
->
Close
();
}
}
}
/*!
 * \brief Send a single message over one socket
 * \param msg message to send. A message whose size is 0 carries no payload;
 *            SendLoop uses it as the end-signal for the receiver.
 * \param socket TCPSocket to write to, must not be null
 *
 * The message size (sizeof(int64_t) bytes) is written first, followed by
 * msg.size bytes of payload. Once everything has been written, the message's
 * deallocator is invoked if it is set.
 */
void SendCore(Message msg, TCPSocket* socket) {
  // SendLoop looks the socket up by receiver id in an unordered_map, so an
  // unknown id yields a null pointer. Fail loudly instead of dereferencing it.
  CHECK_NOTNULL(socket);
  // First send the size. A size of zero tells the receiver to stop.
  int64_t sent_bytes = 0;
  while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
    int64_t max_len = sizeof(int64_t) - sent_bytes;
    int64_t tmp = socket->Send(
        reinterpret_cast<char*>(&msg.size) + sent_bytes,
        max_len);
    CHECK_NE(tmp, -1);
    sent_bytes += tmp;
  }
  // Then send the data
  sent_bytes = 0;
  while (sent_bytes < msg.size) {
    int64_t max_len = msg.size - sent_bytes;
    int64_t tmp = socket->Send(msg.data + sent_bytes, max_len);
    CHECK_NE(tmp, -1);
    sent_bytes += tmp;
  }
  // delete msg
  if (msg.deallocator != nullptr) {
    msg.deallocator(&msg);
  }
}
}
}
void
SocketSender
::
SendLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
)
{
void
SocketSender
::
SendLoop
(
CHECK_NOTNULL
(
socket
);
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
CHECK_NOTNULL
(
queue
);
std
::
shared_ptr
<
MessageQueue
>
queue
)
{
bool
exit
=
false
;
for
(;;)
{
while
(
!
exit
)
{
Message
msg
;
Message
msg
;
STATUS
code
=
queue
->
Remove
(
&
msg
);
STATUS
code
=
queue
->
Remove
(
&
msg
);
if
(
code
==
QUEUE_CLOSE
)
{
if
(
code
==
QUEUE_CLOSE
)
{
msg
.
size
=
0
;
// send an end-signal to receiver
msg
.
size
=
0
;
// send an end-signal to receiver
exit
=
true
;
for
(
auto
&
socket
:
sockets
)
{
}
SendCore
(
msg
,
socket
.
second
.
get
());
// First send the size
}
// If exit == true, we will send zero size to reciever
break
;
int64_t
sent_bytes
=
0
;
while
(
static_cast
<
size_t
>
(
sent_bytes
)
<
sizeof
(
int64_t
))
{
int64_t
max_len
=
sizeof
(
int64_t
)
-
sent_bytes
;
int64_t
tmp
=
socket
->
Send
(
reinterpret_cast
<
char
*>
(
&
msg
.
size
)
+
sent_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
sent_bytes
+=
tmp
;
}
// Then send the data
sent_bytes
=
0
;
while
(
sent_bytes
<
msg
.
size
)
{
int64_t
max_len
=
msg
.
size
-
sent_bytes
;
int64_t
tmp
=
socket
->
Send
(
msg
.
data
+
sent_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
sent_bytes
+=
tmp
;
}
// delete msg
if
(
msg
.
deallocator
!=
nullptr
)
{
msg
.
deallocator
(
&
msg
);
}
}
SendCore
(
msg
,
sockets
[
msg
.
receiver_id
].
get
());
}
}
}
}
...
@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
...
@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
int
port
=
stoi
(
ip_and_port
[
1
]);
int
port
=
stoi
(
ip_and_port
[
1
]);
// Initialize message queue for each connection
// Initialize message queue for each connection
num_sender_
=
num_sender
;
num_sender_
=
num_sender
;
for
(
int
i
=
0
;
i
<
num_sender_
;
++
i
)
{
#ifdef USE_EPOLL
msg_queue_
[
i
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
if
(
max_thread_count_
==
0
||
max_thread_count_
>
num_sender_
)
{
max_thread_count_
=
num_sender_
;
}
}
mq_iter_
=
msg_queue_
.
begin
();
#else
max_thread_count_
=
num_sender_
;
#endif
// Initialize socket and socket-thread
// Initialize socket and socket-thread
server_socket_
=
new
TCPSocket
();
server_socket_
=
new
TCPSocket
();
// Bind socket
// Bind socket
if
(
server_socket_
->
Bind
(
ip
.
c_str
(),
port
)
==
false
)
{
if
(
server_socket_
->
Bind
(
ip
.
c_str
(),
port
)
==
false
)
{
LOG
(
FATAL
)
<<
"Cannot bind to "
<<
ip
<<
":"
<<
port
;
LOG
(
FATAL
)
<<
"Cannot bind to "
<<
ip
<<
":"
<<
port
;
}
}
// Listen
// Listen
if
(
server_socket_
->
Listen
(
kMaxConnection
)
==
false
)
{
if
(
server_socket_
->
Listen
(
kMaxConnection
)
==
false
)
{
LOG
(
FATAL
)
<<
"Cannot listen on "
<<
ip
<<
":"
<<
port
;
LOG
(
FATAL
)
<<
"Cannot listen on "
<<
ip
<<
":"
<<
port
;
...
@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
...
@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
// Accept all sender sockets
// Accept all sender sockets
std
::
string
accept_ip
;
std
::
string
accept_ip
;
int
accept_port
;
int
accept_port
;
sockets_
.
resize
(
max_thread_count_
);
for
(
int
i
=
0
;
i
<
num_sender_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_sender_
;
++
i
)
{
sockets_
[
i
]
=
std
::
make_shared
<
TCPSocket
>
();
int
thread_id
=
i
%
max_thread_count_
;
if
(
server_socket_
->
Accept
(
sockets_
[
i
].
get
(),
&
accept_ip
,
&
accept_port
)
==
false
)
{
auto
socket
=
std
::
make_shared
<
TCPSocket
>
();
sockets_
[
thread_id
][
i
]
=
socket
;
msg_queue_
[
i
]
=
std
::
make_shared
<
MessageQueue
>
(
queue_size_
);
if
(
server_socket_
->
Accept
(
socket
.
get
(),
&
accept_ip
,
&
accept_port
)
==
false
)
{
LOG
(
WARNING
)
<<
"Error on accept socket."
;
LOG
(
WARNING
)
<<
"Error on accept socket."
;
return
false
;
return
false
;
}
}
}
mq_iter_
=
msg_queue_
.
begin
();
for
(
int
thread_id
=
0
;
thread_id
<
max_thread_count_
;
++
thread_id
)
{
// create new thread for each socket
// create new thread for each socket
threads_
[
i
]
=
std
::
make_shared
<
std
::
thread
>
(
threads_
.
push_back
(
std
::
make_shared
<
std
::
thread
>
(
RecvLoop
,
RecvLoop
,
sockets_
[
i
].
get
(),
sockets_
[
thread_id
],
msg_queue_
[
i
].
get
());
msg_queue_
,
&
queue_sem_
));
}
}
return
true
;
return
true
;
}
}
STATUS
SocketReceiver
::
Recv
(
Message
*
msg
,
int
*
send_id
)
{
STATUS
SocketReceiver
::
Recv
(
Message
*
msg
,
int
*
send_id
)
{
// loop until get a message
// queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
queue_sem_
.
Wait
();
for
(;;)
{
for
(;;)
{
for
(;
mq_iter_
!=
msg_queue_
.
end
();
++
mq_iter_
)
{
for
(;
mq_iter_
!=
msg_queue_
.
end
();
++
mq_iter_
)
{
// We use non-block remove here
STATUS
code
=
mq_iter_
->
second
->
Remove
(
msg
,
false
);
STATUS
code
=
mq_iter_
->
second
->
Remove
(
msg
,
false
);
if
(
code
==
QUEUE_EMPTY
)
{
if
(
code
==
QUEUE_EMPTY
)
{
continue
;
// jump to the next queue
continue
;
// jump to the next queue
...
@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
...
@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
STATUS
SocketReceiver
::
RecvFrom
(
Message
*
msg
,
int
send_id
)
{
STATUS
SocketReceiver
::
RecvFrom
(
Message
*
msg
,
int
send_id
)
{
// Get message from specified message queue
// Get message from specified message queue
queue_sem_
.
Wait
();
STATUS
code
=
msg_queue_
[
send_id
]
->
Remove
(
msg
);
STATUS
code
=
msg_queue_
[
send_id
]
->
Remove
(
msg
);
return
code
;
return
code
;
}
}
...
@@ -255,65 +294,122 @@ void SocketReceiver::Finalize() {
...
@@ -255,65 +294,122 @@ void SocketReceiver::Finalize() {
usleep
(
1000
);
usleep
(
1000
);
#endif // _WIN32
#endif // _WIN32
}
}
int
ID
=
mq
.
first
;
mq
.
second
->
SignalFinished
(
mq
.
first
);
mq
.
second
->
SignalFinished
(
ID
);
}
}
// Block main thread until all socket-threads finish their jobs
// Block main thread until all socket-threads finish their jobs
for
(
auto
&
thread
:
threads_
)
{
for
(
auto
&
thread
:
threads_
)
{
thread
.
second
->
join
();
thread
->
join
();
}
}
// Clear all sockets
// Clear all sockets
for
(
auto
&
socket
:
sockets_
)
{
for
(
auto
&
group_sockets
:
sockets_
)
{
socket
.
second
->
Close
();
for
(
auto
&
socket
:
group_sockets
)
{
socket
.
second
->
Close
();
}
}
}
server_socket_
->
Close
();
server_socket_
->
Close
();
delete
server_socket_
;
delete
server_socket_
;
}
}
void
SocketReceiver
::
RecvLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
)
{
int64_t
RecvDataSize
(
TCPSocket
*
socket
)
{
CHECK_NOTNULL
(
socket
);
int64_t
received_bytes
=
0
;
CHECK_NOTNULL
(
queue
);
int64_t
data_size
=
0
;
for
(;;)
{
while
(
static_cast
<
size_t
>
(
received_bytes
)
<
sizeof
(
int64_t
))
{
// If main thread had finished its job
int64_t
max_len
=
sizeof
(
int64_t
)
-
received_bytes
;
if
(
queue
->
EmptyAndNoMoreAdd
())
{
int64_t
tmp
=
socket
->
Receive
(
return
;
// exit loop thread
reinterpret_cast
<
char
*>
(
&
data_size
)
+
received_bytes
,
}
max_len
);
// First recv the size
if
(
tmp
==
-
1
)
{
int64_t
received_bytes
=
0
;
if
(
received_bytes
>
0
)
{
int64_t
data_size
=
0
;
// We want to finish reading full data_size
while
(
static_cast
<
size_t
>
(
received_bytes
)
<
sizeof
(
int64_t
))
{
continue
;
int64_t
max_len
=
sizeof
(
int64_t
)
-
received_bytes
;
}
int64_t
tmp
=
socket
->
Receive
(
return
-
1
;
reinterpret_cast
<
char
*>
(
&
data_size
)
+
received_bytes
,
max_len
);
CHECK_NE
(
tmp
,
-
1
);
received_bytes
+=
tmp
;
}
}
if
(
data_size
<
0
)
{
received_bytes
+=
tmp
;
LOG
(
FATAL
)
<<
"Recv data error (data_size: "
<<
data_size
<<
")"
;
}
}
else
if
(
data_size
==
0
)
{
return
data_size
;
// This is an end-signal sent by client
}
void
RecvData
(
TCPSocket
*
socket
,
char
*
buffer
,
const
int64_t
&
data_size
,
int64_t
*
received_bytes
)
{
while
(
*
received_bytes
<
data_size
)
{
int64_t
max_len
=
data_size
-
*
received_bytes
;
int64_t
tmp
=
socket
->
Receive
(
buffer
+
*
received_bytes
,
max_len
);
if
(
tmp
==
-
1
)
{
// Socket not ready, no more data to read
return
;
return
;
}
else
{
}
char
*
buffer
=
nullptr
;
*
received_bytes
+=
tmp
;
try
{
}
buffer
=
new
char
[
data_size
];
}
}
catch
(
const
std
::
bad_alloc
&
)
{
LOG
(
FATAL
)
<<
"Cannot allocate enough memory for message, "
void
SocketReceiver
::
RecvLoop
(
<<
"(message size: "
<<
data_size
<<
")"
;
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
queues
,
runtime
::
Semaphore
*
queue_sem
)
{
std
::
unordered_map
<
int
,
std
::
unique_ptr
<
RecvContext
>>
recv_contexts
;
SocketPool
socket_pool
;
for
(
auto
&
socket
:
sockets
)
{
auto
&
sender_id
=
socket
.
first
;
socket_pool
.
AddSocket
(
socket
.
second
,
sender_id
);
recv_contexts
[
sender_id
]
=
std
::
unique_ptr
<
RecvContext
>
(
new
RecvContext
());
}
// Main loop to receive messages
for
(;;)
{
int
sender_id
;
// Get active socket using epoll
std
::
shared_ptr
<
TCPSocket
>
socket
=
socket_pool
.
GetActiveSocket
(
&
sender_id
);
if
(
queues
[
sender_id
]
->
EmptyAndNoMoreAdd
())
{
// This sender has already stopped
if
(
socket_pool
.
RemoveSocket
(
socket
)
==
0
)
{
return
;
}
}
received_bytes
=
0
;
continue
;
while
(
received_bytes
<
data_size
)
{
}
int64_t
max_len
=
data_size
-
received_bytes
;
int64_t
tmp
=
socket
->
Receive
(
buffer
+
received_bytes
,
max_len
);
// Nonblocking socket might be interrupted at any point. So we need to
CHECK_NE
(
tmp
,
-
1
);
// store the partially received data
received_bytes
+=
tmp
;
std
::
unique_ptr
<
RecvContext
>
&
ctx
=
recv_contexts
[
sender_id
];
int64_t
&
data_size
=
ctx
->
data_size
;
int64_t
&
received_bytes
=
ctx
->
received_bytes
;
char
*&
buffer
=
ctx
->
buffer
;
if
(
data_size
==
-
1
)
{
// This is a new message, so receive the data size first
data_size
=
RecvDataSize
(
socket
.
get
());
if
(
data_size
>
0
)
{
try
{
buffer
=
new
char
[
data_size
];
}
catch
(
const
std
::
bad_alloc
&
)
{
LOG
(
FATAL
)
<<
"Cannot allocate enough memory for message, "
<<
"(message size: "
<<
data_size
<<
")"
;
}
received_bytes
=
0
;
}
else
if
(
data_size
==
0
)
{
// Received stop signal
if
(
socket_pool
.
RemoveSocket
(
socket
)
==
0
)
{
return
;
}
}
}
}
RecvData
(
socket
.
get
(),
buffer
,
data_size
,
&
received_bytes
);
if
(
received_bytes
>=
data_size
)
{
// Full data received, create Message and push to queue
Message
msg
;
Message
msg
;
msg
.
data
=
buffer
;
msg
.
data
=
buffer
;
msg
.
size
=
data_size
;
msg
.
size
=
data_size
;
msg
.
deallocator
=
DefaultMessageDeleter
;
msg
.
deallocator
=
DefaultMessageDeleter
;
queue
->
Add
(
msg
);
queues
[
sender_id
]
->
Add
(
msg
);
// Reset recv context
data_size
=
-
1
;
// Signal queue semaphore
queue_sem
->
Post
();
}
}
}
}
}
}
...
...
src/rpc/network/socket_communicator.h
View file @
5cf48fc6
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <unordered_map>
#include <unordered_map>
#include <memory>
#include <memory>
#include "../../runtime/semaphore_wrapper.h"
#include "communicator.h"
#include "communicator.h"
#include "msg_queue.h"
#include "msg_queue.h"
#include "tcp_socket.h"
#include "tcp_socket.h"
...
@@ -42,8 +43,10 @@ class SocketSender : public Sender {
...
@@ -42,8 +43,10 @@ class SocketSender : public Sender {
/*!
/*!
* \brief Sender constructor
* \brief Sender constructor
* \param queue_size size of message queue
* \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit
*/
*/
explicit
SocketSender
(
int64_t
queue_size
)
:
Sender
(
queue_size
)
{}
SocketSender
(
int64_t
queue_size
,
int
max_thread_count
)
:
Sender
(
queue_size
,
max_thread_count
)
{}
/*!
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \brief Add receiver's address and ID to the sender's namebook
...
@@ -93,7 +96,8 @@ class SocketSender : public Sender {
...
@@ -93,7 +96,8 @@ class SocketSender : public Sender {
/*!
/*!
* \brief socket for each connection of receiver
* \brief socket for each connection of receiver
*/
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets_
;
std
::
vector
<
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
TCPSocket
>>>
sockets_
;
/*!
/*!
* \brief receivers' address
* \brief receivers' address
...
@@ -101,24 +105,27 @@ class SocketSender : public Sender {
...
@@ -101,24 +105,27 @@ class SocketSender : public Sender {
std
::
unordered_map
<
int
/* receiver ID */
,
IPAddr
>
receiver_addrs_
;
std
::
unordered_map
<
int
/* receiver ID */
,
IPAddr
>
receiver_addrs_
;
/*!
/*!
* \brief message queue for each
socket connection
* \brief message queue for each
thread
*/
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
vector
<
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
/*!
/*!
* \brief Independent thread
for each socket connection
* \brief Independent thread
*/
*/
std
::
unordered_map
<
int
/* receiver ID */
,
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
/*!
/*!
* \brief Send-loop for each
socket in per-
thread
* \brief Send-loop for each thread
* \param socket TCPSocket for current
connection
* \param socket
s
TCPSocket
s
for current
thread
* \param queue message_queue for current
connection
* \param queue message_queue for current
thread
*
*
* Note that, the SendLoop will finish its loop-job and exit thread
* Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
* when the main thread invokes Signal() API on the message queue.
*/
*/
static
void
SendLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
);
static
void
SendLoop
(
std
::
unordered_map
<
int
/* Receiver (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
shared_ptr
<
MessageQueue
>
queue
);
};
};
/*!
/*!
...
@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver {
...
@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver {
/*!
/*!
* \brief Receiver constructor
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
*/
*/
explicit
SocketReceiver
(
int64_t
queue_size
)
:
Receiver
(
queue_size
)
{}
SocketReceiver
(
int64_t
queue_size
,
int
max_thread_count
)
:
Receiver
(
queue_size
,
max_thread_count
)
{}
/*!
/*!
* \brief Wait for all the Senders to connect
* \brief Wait for all the Senders to connect
...
@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver {
...
@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver {
inline
std
::
string
Type
()
const
{
return
std
::
string
(
"socket"
);
}
inline
std
::
string
Type
()
const
{
return
std
::
string
(
"socket"
);
}
private:
private:
/*!
 * \brief Per-sender receive state kept by RecvLoop, so that a message which
 * is only partially read from a socket can be continued later
 */
struct
RecvContext
{
/*!
 * \brief size in bytes of the message being received. -1 means no message
 * is in progress, so RecvLoop reads the size first.
 */
int64_t
data_size
=
-
1
;
/*!
 * \brief number of payload bytes already stored in buffer
 */
int64_t
received_bytes
=
0
;
/*!
 * \brief payload buffer, allocated by RecvLoop once data_size is known
 */
char
*
buffer
=
nullptr
;
};
/*!
/*!
* \brief number of sender
* \brief number of sender
*/
*/
...
@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver {
...
@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver {
/*!
/*!
* \brief socket for each client connections
* \brief socket for each client connections
*/
*/
std
::
unordered_map
<
int
/* Sender (virutal) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets_
;
std
::
vector
<
std
::
unordered_map
<
int
/* Sender (virutal) ID */
,
std
::
shared_ptr
<
TCPSocket
>>>
sockets_
;
/*!
/*!
* \brief Message queue for each socket connection
* \brief Message queue for each socket connection
*/
*/
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
msg_queue_
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
MessageQueue
>>::
iterator
mq_iter_
;
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
MessageQueue
>>::
iterator
mq_iter_
;
/*!
/*!
* \brief Independent thread
for each socket connection
* \brief Independent thread
*/
*/
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
std
::
vector
<
std
::
shared_ptr
<
std
::
thread
>>
threads_
;
/*!
/*!
* \brief Recv-loop for each socket in per-thread
* \brief queue_sem_ semaphore to indicate the number of messages in multiple
* \param socket client socket
* message queues to prevent busy wait of Recv
* \param queue message queue
*/
runtime
::
Semaphore
queue_sem_
;
/*!
* \brief Recv-loop for each thread
* \param sockets client sockets of current thread
* \param queue message queues of current thread
*
*
* Note that, the RecvLoop will finish its loop-job and exit thread
* Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
* when the main thread invokes Signal() API on the message queue.
*/
*/
static
void
RecvLoop
(
TCPSocket
*
socket
,
MessageQueue
*
queue
);
static
void
RecvLoop
(
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
TCPSocket
>>
sockets
,
std
::
unordered_map
<
int
/* Sender (virtual) ID */
,
std
::
shared_ptr
<
MessageQueue
>>
queues
,
runtime
::
Semaphore
*
queue_sem
);
};
};
}
// namespace network
}
// namespace network
...
...
src/rpc/network/socket_pool.cc
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.cc
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#include "socket_pool.h"
#include <dmlc/logging.h>
#include <cerrno>
#include "tcp_socket.h"
#ifdef USE_EPOLL
#include <sys/epoll.h>
#include <unistd.h>
#endif
namespace
dgl
{
namespace
network
{
/*!
 * \brief Construct an empty pool.
 *
 * With USE_EPOLL an epoll instance is created and its descriptor stored in
 * epfd_; failure to create it is fatal. Without USE_EPOLL no descriptor
 * exists, so epfd_ is set to -1 instead of being left indeterminate.
 */
SocketPool::SocketPool() : epfd_(-1) {
#ifdef USE_EPOLL
  epfd_ = epoll_create1(0);
  if (epfd_ < 0) {
    LOG(FATAL) << "SocketPool cannot create epfd";
  }
#endif
}
/*!
 * \brief Register a socket with the pool.
 * \param socket tcp socket to add
 * \param socket_id receiver/sender id associated with the socket
 * \param events bitmask built from SocketPool::READ and SocketPool::WRITE
 *
 * The socket is indexed by its file descriptor. With USE_EPOLL it is added to
 * the epoll interest list and switched to non-blocking mode; without
 * USE_EPOLL the pool can hold a single socket only and a second one is fatal.
 */
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
                           int events) {
  const int fd = socket->Socket();
  tcp_sockets_[fd] = socket;
  socket_ids_[fd] = socket_id;
#ifdef USE_EPOLL
  // Value-initialize the struct: `events` starts at 0 and the unused bytes of
  // the `data` union are zeroed, so nothing uninitialized reaches the kernel
  // even if the caller passes an unexpected `events` value.
  epoll_event e{};
  e.data.fd = fd;
  // READ and WRITE are independent bits; READ + WRITE selects both.
  if (events & READ) {
    e.events |= EPOLLIN;
  }
  if (events & WRITE) {
    e.events |= EPOLLOUT;
  }
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &e) < 0) {
    LOG(FATAL) << "SocketPool cannot add socket";
  }
  socket->SetNonBlocking(true);
#else
  if (tcp_sockets_.size() > 1) {
    LOG(FATAL) << "SocketPool supports only one socket if not use epoll. "
                  "Please turn on USE_EPOLL on building";
  }
#endif
}
/*!
 * \brief Unregister a socket from the pool.
 * \param socket tcp socket to remove
 * \return number of sockets still registered after the removal
 *
 * Descriptors of this socket that are already queued in pending_fds_ are not
 * purged here; GetActiveSocket() skips descriptors it no longer knows.
 */
size_t SocketPool::RemoveSocket(std::shared_ptr<TCPSocket> socket) {
  const int descriptor = socket->Socket();
#ifdef USE_EPOLL
  // Result intentionally ignored, as before: removal is best effort.
  epoll_ctl(epfd_, EPOLL_CTL_DEL, descriptor, nullptr);
#endif
  tcp_sockets_.erase(descriptor);
  socket_ids_.erase(descriptor);
  return socket_ids_.size();
}
/*!
 * \brief Tear down the pool.
 *
 * With USE_EPOLL every registered descriptor is removed from the epoll
 * interest list and the epoll descriptor itself is closed, so that a pool
 * does not leak one file descriptor per instance. The TCPSocket objects are
 * released through their shared_ptr entries in tcp_sockets_.
 */
SocketPool::~SocketPool() {
#ifdef USE_EPOLL
  for (const auto& entry : socket_ids_) {
    epoll_ctl(epfd_, EPOLL_CTL_DEL, entry.first, nullptr);
  }
  if (epfd_ >= 0) {
    close(epfd_);
    epfd_ = -1;
  }
#endif
}
/*!
 * \brief Fetch the next socket that has a pending event.
 * \param socket_id output: id that was registered together with the socket
 * \return the active socket, or nullptr when the pool holds no socket
 *
 * Calls Wait() repeatedly until pending_fds_ is non-empty, then pops
 * descriptors from it. A descriptor that is no longer registered (it was
 * removed after being queued) is dropped and the search continues.
 */
std::shared_ptr<TCPSocket> SocketPool::GetActiveSocket(int* socket_id) {
  if (socket_ids_.empty()) {
    return nullptr;
  }
  while (true) {
    while (pending_fds_.empty()) {
      Wait();
    }
    const int candidate = pending_fds_.front();
    pending_fds_.pop();
    const auto known = socket_ids_.find(candidate);
    if (known == socket_ids_.end()) {
      continue;  // stale descriptor: socket was removed meanwhile
    }
    *socket_id = known->second;
    return tcp_sockets_[candidate];
  }
}
/*!
 * \brief Refill pending_fds_ with descriptors that have pending events.
 *
 * With USE_EPOLL this blocks in epoll_wait() without a timeout and queues up
 * to MAX_EVENTS ready descriptors. An interruption by a signal (EINTR) queues
 * nothing and simply returns; any other epoll_wait() failure is fatal,
 * because the caller would otherwise call Wait() again in a tight loop
 * forever. Without USE_EPOLL the descriptor of the first socket in
 * tcp_sockets_ is queued unconditionally.
 */
void SocketPool::Wait() {
#ifdef USE_EPOLL
  static const int MAX_EVENTS = 10;
  epoll_event events[MAX_EVENTS];
  const int nfd = epoll_wait(epfd_, events, MAX_EVENTS, -1 /*Timeout*/);
  if (nfd < 0) {
    if (errno != EINTR) {
      LOG(FATAL) << "SocketPool epoll_wait failed, errno=" << errno;
    }
    return;
  }
  for (int i = 0; i < nfd; ++i) {
    pending_fds_.push(events[i].data.fd);
  }
#else
  pending_fds_.push(tcp_sockets_.begin()->second->Socket());
#endif
}
}
// namespace network
}
// namespace dgl
src/rpc/network/socket_pool.h
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.h
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory>
namespace
dgl
{
namespace
network
{
class TCPSocket;

/*!
 * \brief SocketPool maintains a group of nonblocking sockets, and can provide
 * active sockets.
 * Currently SocketPool is based on epoll, a scalable I/O event notification
 * mechanism in Linux operating system.
 *
 * A pool holds an epoll file descriptor, so it is not copyable: two copies
 * would refer to the same descriptor.
 */
class SocketPool {
 public:
  /*!
   * \brief socket mode read/receive
   */
  static const int READ = 1;
  /*!
   * \brief socket mode write/send
   */
  static const int WRITE = 2;
  /*!
   * \brief SocketPool constructor
   */
  SocketPool();
  /*!
   * \brief Copying is disabled because the pool holds a file descriptor
   */
  SocketPool(const SocketPool&) = delete;
  SocketPool& operator=(const SocketPool&) = delete;
  /*!
   * \brief Add a socket to SocketPool
   * \param socket tcp socket to add
   * \param socket_id receiver/sender id of the socket
   * \param events READ, WRITE or READ + WRITE
   */
  void AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
                 int events = READ);
  /*!
   * \brief Remove socket from SocketPool
   * \param socket tcp socket to remove
   * \return number of remaining sockets in the pool
   */
  size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);
  /*!
   * \brief SocketPool destructor
   */
  ~SocketPool();
  /*!
   * \brief Get current active socket. This is a blocking method
   * \param socket_id output parameter of the socket_id of active socket
   * \return active TCPSocket
   */
  std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);

 private:
  /*!
   * \brief Wait for event notification
   */
  void Wait();
  /*!
   * \brief map from fd to TCPSocket
   */
  std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;
  /*!
   * \brief map from fd to socket_id
   */
  std::unordered_map<int, int> socket_ids_;
  /*!
   * \brief fd for epoll base, -1 while no epoll instance exists
   */
  int epfd_ = -1;
  /*!
   * \brief queue for current active fds
   */
  std::queue<int> pending_fds_;
};
}
// namespace network
}
// namespace dgl
#endif // DGL_RPC_NETWORK_SOCKET_POOL_H_
src/rpc/network/tcp_socket.cc
View file @
5cf48fc6
...
@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
...
@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
}
}
#ifdef _WIN32
#ifdef _WIN32
bool
TCPSocket
::
SetBlocking
(
bool
flag
)
{
bool
TCPSocket
::
Set
Non
Blocking
(
bool
flag
)
{
int
result
;
int
result
;
u_long
argp
=
flag
?
1
:
0
;
u_long
argp
=
flag
?
1
:
0
;
...
@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) {
...
@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) {
return
true
;
return
true
;
}
}
#else // !_WIN32
#else // !_WIN32
bool
TCPSocket
::
SetBlocking
(
bool
flag
)
{
bool
TCPSocket
::
Set
Non
Blocking
(
bool
flag
)
{
int
opts
;
int
opts
;
if
((
opts
=
fcntl
(
socket_
,
F_GETFL
))
<
0
)
{
if
((
opts
=
fcntl
(
socket_
,
F_GETFL
))
<
0
)
{
...
@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
...
@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
do
{
// retry if EINTR failure appears
do
{
// retry if EINTR failure appears
number_recv
=
recv
(
socket_
,
buffer
,
size_buffer
,
0
);
number_recv
=
recv
(
socket_
,
buffer
,
size_buffer
,
0
);
}
while
(
number_recv
==
-
1
&&
errno
==
EINTR
);
}
while
(
number_recv
==
-
1
&&
errno
==
EINTR
);
if
(
number_recv
==
-
1
)
{
if
(
number_recv
==
-
1
&&
errno
!=
EAGAIN
&&
errno
!=
EWOULDBLOCK
)
{
LOG
(
ERROR
)
<<
"recv error: "
<<
strerror
(
errno
);
LOG
(
ERROR
)
<<
"recv error: "
<<
strerror
(
errno
);
}
}
...
...
src/rpc/network/tcp_socket.h
View file @
5cf48fc6
...
@@ -70,12 +70,12 @@ class TCPSocket {
...
@@ -70,12 +70,12 @@ class TCPSocket {
int
*
port_client
);
int
*
port_client
);
/*!
/*!
* \brief SetBlocking() is needed refering to this example of epoll:
* \brief Set
Non
Blocking() is needed referring to this example of epoll:
* http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
* http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
* \param flag
flag
for blocking
* \param flag
true for nonblocking, false
for blocking
* \return true for success and false for failure
* \return true for success and false for failure
*/
*/
bool
SetBlocking
(
bool
flag
);
bool
Set
Non
Blocking
(
bool
flag
);
/*!
/*!
* \brief Set timeout for socket
* \brief Set timeout for socket
...
...
src/rpc/rpc.cc
View file @
5cf48fc6
...
@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
...
@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
int64_t
msg_queue_size
=
args
[
0
];
int64_t
msg_queue_size
=
args
[
0
];
std
::
string
type
=
args
[
1
];
std
::
string
type
=
args
[
1
];
int
max_thread_count
=
args
[
2
];
if
(
type
.
compare
(
"socket"
)
==
0
)
{
if
(
type
.
compare
(
"socket"
)
==
0
)
{
RPCContext
::
ThreadLocal
()
->
sender
=
std
::
make_shared
<
network
::
SocketSender
>
(
msg_queue_size
);
RPCContext
::
ThreadLocal
()
->
sender
=
std
::
make_shared
<
network
::
SocketSender
>
(
msg_queue_size
,
max_thread_count
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc receiver: "
<<
type
;
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc receiver: "
<<
type
;
}
}
...
@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
...
@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
int64_t
msg_queue_size
=
args
[
0
];
int64_t
msg_queue_size
=
args
[
0
];
std
::
string
type
=
args
[
1
];
std
::
string
type
=
args
[
1
];
int
max_thread_count
=
args
[
2
];
if
(
type
.
compare
(
"socket"
)
==
0
)
{
if
(
type
.
compare
(
"socket"
)
==
0
)
{
RPCContext
::
ThreadLocal
()
->
receiver
=
std
::
make_shared
<
network
::
SocketReceiver
>
(
msg_queue_size
);
RPCContext
::
ThreadLocal
()
->
receiver
=
std
::
make_shared
<
network
::
SocketReceiver
>
(
msg_queue_size
,
max_thread_count
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc sender: "
<<
type
;
LOG
(
FATAL
)
<<
"Unknown communicator type for rpc sender: "
<<
type
;
}
}
...
...
src/runtime/semaphore_wrapper.cc
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.cc
* \brief A simple cross platform semaphore wrapper
*/
#include "semaphore_wrapper.h"
#include <dmlc/logging.h>
namespace
dgl
{
namespace
runtime
{
#ifdef _WIN32
/*!
 * \brief Create an unnamed Win32 semaphore with an initial count of 0 and a
 * maximum count of INT_MAX. Failure to create it is fatal.
 */
Semaphore::Semaphore()
    : sem_(CreateSemaphore(nullptr, 0, INT_MAX, nullptr)) {
  if (sem_ == nullptr) {
    LOG(FATAL) << "Cannot create semaphore";
  }
}
void
Semaphore
::
Wait
()
{
WaitForSingleObject
(
sem_
,
INFINITE
);
}
void
Semaphore
::
Post
()
{
ReleaseSemaphore
(
sem_
,
1
,
nullptr
);
}
#else
/*!
 * \brief Initialize an unnamed, process-private POSIX semaphore with an
 * initial value of 0. As in the Win32 branch, failure to create the
 * semaphore is fatal instead of being silently ignored.
 */
Semaphore::Semaphore() {
  if (sem_init(&sem_, /*pshared=*/0, /*value=*/0) != 0) {
    LOG(FATAL) << "Cannot create semaphore";
  }
}
/*!
 * \brief Call sem_wait() on the wrapped semaphore.
 *
 * NOTE(review): the return value of sem_wait() is not checked. POSIX allows
 * it to fail with EINTR when a signal handler interrupts the call, in which
 * case this method returns without having decremented the semaphore; confirm
 * whether callers tolerate that.
 */
void Semaphore::Wait() {
  sem_t* const handle = &sem_;
  sem_wait(handle);
}
/*!
 * \brief Call sem_post() on the wrapped semaphore, raising its value by 1.
 * The return value of sem_post() is not inspected.
 */
void Semaphore::Post() {
  sem_t* const handle = &sem_;
  sem_post(handle);
}
#endif
}
// namespace runtime
}
// namespace dgl
src/runtime/semaphore_wrapper.h
0 → 100644
View file @
5cf48fc6
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.h
* \brief A simple cross platform semaphore wrapper
*/
#ifndef DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#define DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#ifdef _WIN32
#include <windows.h>
#else
#include <semaphore.h>
#endif
namespace
dgl
{
namespace
runtime
{
/*!
* \brief A simple crossplatform Semaphore wrapper
*/
class Semaphore {
 public:
  /*!
   * \brief Semaphore constructor
   */
  Semaphore();
  /*!
   * \brief Semaphore destructor, releases the underlying OS semaphore.
   *
   * No thread may still be waiting on the semaphore when it is destroyed.
   */
  ~Semaphore() {
#ifdef _WIN32
    if (sem_) {
      CloseHandle(sem_);
    }
#else
    sem_destroy(&sem_);
#endif
  }
  /*!
   * \brief Copying is disabled: the wrapped OS semaphore has a single owner
   */
  Semaphore(const Semaphore&) = delete;
  Semaphore& operator=(const Semaphore&) = delete;
  /*!
   * \brief blocking wait, decrease semaphore by 1
   */
  void Wait();
  /*!
   * \brief increase semaphore by 1
   */
  void Post();

 private:
#ifdef _WIN32
  HANDLE sem_;
#else
  sem_t sem_;
#endif
};
}
// namespace runtime
}
// namespace dgl
#endif // DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
tests/cpp/socket_communicator_test.cc
View file @
5cf48fc6
...
@@ -25,6 +25,7 @@ using dgl::network::Message;
...
@@ -25,6 +25,7 @@ using dgl::network::Message;
using
dgl
::
network
::
DefaultMessageDeleter
;
using
dgl
::
network
::
DefaultMessageDeleter
;
const
int64_t
kQueueSize
=
500
*
1024
;
const
int64_t
kQueueSize
=
500
*
1024
;
const
int
kThreadNum
=
2
;
#ifndef WIN32
#ifndef WIN32
...
@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
...
@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
}
}
void
start_client
()
{
void
start_client
()
{
SocketSender
sender
(
kQueueSize
);
SocketSender
sender
(
kQueueSize
,
kThreadNum
);
for
(
int
i
=
0
;
i
<
kNumReceiver
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kNumReceiver
;
++
i
)
{
sender
.
AddReceiver
(
ip_addr
[
i
],
i
);
sender
.
AddReceiver
(
ip_addr
[
i
],
i
);
}
}
...
@@ -89,7 +90,7 @@ void start_client() {
...
@@ -89,7 +90,7 @@ void start_client() {
void
start_server
(
int
id
)
{
void
start_server
(
int
id
)
{
sleep
(
5
);
sleep
(
5
);
SocketReceiver
receiver
(
kQueueSize
);
SocketReceiver
receiver
(
kQueueSize
,
kThreadNum
);
receiver
.
Wait
(
ip_addr
[
id
],
kNumSender
);
receiver
.
Wait
(
ip_addr
[
id
],
kNumSender
);
for
(
int
i
=
0
;
i
<
kNumMessage
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kNumMessage
;
++
i
)
{
for
(
int
n
=
0
;
n
<
kNumSender
;
++
n
)
{
for
(
int
n
=
0
;
n
<
kNumSender
;
++
n
)
{
...
@@ -168,7 +169,7 @@ static void start_client() {
...
@@ -168,7 +169,7 @@ static void start_client() {
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
istreambuf_iterator
<
char
>
());
std
::
istreambuf_iterator
<
char
>
());
t
.
close
();
t
.
close
();
SocketSender
sender
(
kQueueSize
);
SocketSender
sender
(
kQueueSize
,
kThreadNum
);
sender
.
AddReceiver
(
ip_addr
.
c_str
(),
0
);
sender
.
AddReceiver
(
ip_addr
.
c_str
(),
0
);
sender
.
Connect
();
sender
.
Connect
();
char
*
str_data
=
new
char
[
9
];
char
*
str_data
=
new
char
[
9
];
...
@@ -185,7 +186,7 @@ static bool start_server() {
...
@@ -185,7 +186,7 @@ static bool start_server() {
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
string
ip_addr
((
std
::
istreambuf_iterator
<
char
>
(
t
)),
std
::
istreambuf_iterator
<
char
>
());
std
::
istreambuf_iterator
<
char
>
());
t
.
close
();
t
.
close
();
SocketReceiver
receiver
(
kQueueSize
);
SocketReceiver
receiver
(
kQueueSize
,
kThreadNum
);
receiver
.
Wait
(
ip_addr
.
c_str
(),
1
);
receiver
.
Wait
(
ip_addr
.
c_str
(),
1
);
Message
msg
;
Message
msg
;
EXPECT_EQ
(
receiver
.
RecvFrom
(
&
msg
,
0
),
REMOVE_SUCCESS
);
EXPECT_EQ
(
receiver
.
RecvFrom
(
&
msg
,
0
),
REMOVE_SUCCESS
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment