Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8058f1c5
Unverified
Commit
8058f1c5
authored
Apr 05, 2019
by
Chao Ma
Committed by
GitHub
Apr 05, 2019
Browse files
[Fix] Update inner API of distributed sampler (#478)
* update inner API of distributed sampler * update
parent
da3ab84c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
56 additions
and
27 deletions
+56
-27
python/dgl/contrib/sampling/dis_sampler.py
python/dgl/contrib/sampling/dis_sampler.py
+7
-7
python/dgl/network.py
python/dgl/network.py
+4
-6
src/graph/network.cc
src/graph/network.cc
+7
-14
src/graph/network/communicator.h
src/graph/network/communicator.h
+10
-0
src/graph/network/socket_communicator.cc
src/graph/network/socket_communicator.cc
+13
-0
src/graph/network/socket_communicator.h
src/graph/network/socket_communicator.h
+15
-0
No files found.
python/dgl/contrib/sampling/dis_sampler.py
View file @
8058f1c5
# This file contains DGL distributed samplers APIs.
# This file contains DGL distributed samplers APIs.
from
...network
import
_send_subgraph
,
_recv_subgraph
from
...network
import
_send_subgraph
,
_recv_subgraph
from
...network
import
_create_
sampler_
sender
,
_create_
sampler_
receiver
from
...network
import
_create_sender
,
_create_receiver
from
...network
import
_finalize_
sampler_
sender
,
_finalize_
sampler_
receiver
from
...network
import
_finalize_sender
,
_finalize_receiver
from
multiprocessing
import
Pool
from
multiprocessing
import
Pool
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
...
@@ -68,14 +68,14 @@ class SamplerSender(object):
...
@@ -68,14 +68,14 @@ class SamplerSender(object):
def
__init__
(
self
,
ip
,
port
):
def
__init__
(
self
,
ip
,
port
):
self
.
_ip
=
ip
self
.
_ip
=
ip
self
.
_port
=
port
self
.
_port
=
port
self
.
_sender
=
_create_
sampler_
sender
(
ip
,
port
)
self
.
_sender
=
_create_sender
(
ip
,
port
)
def
__del__
(
self
):
def
__del__
(
self
):
"""Finalize Sender
"""Finalize Sender
"""
"""
# _finalize_
sampler_
sender will send a special message
# _finalize_sender will send a special message
# to tell the remote trainer machine that it has finished its job.
# to tell the remote trainer machine that it has finished its job.
_finalize_
sampler_
sender
(
self
.
_sender
)
_finalize_sender
(
self
.
_sender
)
def
send
(
self
,
nodeflow
):
def
send
(
self
,
nodeflow
):
"""Send sampled subgraph (NodeFlow) to remote trainer.
"""Send sampled subgraph (NodeFlow) to remote trainer.
...
@@ -109,7 +109,7 @@ class SamplerReceiver(object):
...
@@ -109,7 +109,7 @@ class SamplerReceiver(object):
self
.
_ip
=
ip
self
.
_ip
=
ip
self
.
_port
=
port
self
.
_port
=
port
self
.
_num_sender
=
num_sender
self
.
_num_sender
=
num_sender
self
.
_receiver
=
_create_
sampler_
receiver
(
ip
,
port
,
num_sender
)
self
.
_receiver
=
_create_receiver
(
ip
,
port
,
num_sender
)
def
__del__
(
self
):
def
__del__
(
self
):
"""Finalize Receiver
"""Finalize Receiver
...
@@ -117,7 +117,7 @@ class SamplerReceiver(object):
...
@@ -117,7 +117,7 @@ class SamplerReceiver(object):
_finalize_sampler_receiver method will clean up the
_finalize_sampler_receiver method will clean up the
back-end threads started by the SamplerReceiver.
back-end threads started by the SamplerReceiver.
"""
"""
_finalize_
sampler_
receiver
(
self
.
_receiver
)
_finalize_receiver
(
self
.
_receiver
)
def
recv
(
self
,
graph
):
def
recv
(
self
,
graph
):
"""Receive a NodeFlow object from remote sampler.
"""Receive a NodeFlow object from remote sampler.
...
...
python/dgl/network.py
View file @
8058f1c5
...
@@ -8,9 +8,7 @@ from . import utils
...
@@ -8,9 +8,7 @@ from . import utils
_init_api
(
"dgl.network"
)
_init_api
(
"dgl.network"
)
############################# Distributed Sampler #############################
def
_create_sender
(
ip_addr
,
port
):
def
_create_sampler_sender
(
ip_addr
,
port
):
"""Create a sender communicator via C socket.
"""Create a sender communicator via C socket.
Parameters
Parameters
...
@@ -22,7 +20,7 @@ def _create_sampler_sender(ip_addr, port):
...
@@ -22,7 +20,7 @@ def _create_sampler_sender(ip_addr, port):
"""
"""
return
_CAPI_DGLSenderCreate
(
ip_addr
,
port
)
return
_CAPI_DGLSenderCreate
(
ip_addr
,
port
)
def
_create_
sampler_
receiver
(
ip_addr
,
port
,
num_sender
):
def
_create_receiver
(
ip_addr
,
port
,
num_sender
):
"""Create a receiver communicator via C socket.
"""Create a receiver communicator via C socket.
Parameters
Parameters
...
@@ -78,7 +76,7 @@ def _recv_subgraph(receiver, graph):
...
@@ -78,7 +76,7 @@ def _recv_subgraph(receiver, graph):
hdl
=
unwrap_to_ptr_list
(
_CAPI_ReceiverRecvSubgraph
(
receiver
))
hdl
=
unwrap_to_ptr_list
(
_CAPI_ReceiverRecvSubgraph
(
receiver
))
return
NodeFlow
(
graph
,
hdl
[
0
])
return
NodeFlow
(
graph
,
hdl
[
0
])
def
_finalize_
sampler_
sender
(
sender
):
def
_finalize_sender
(
sender
):
"""Finalize Sender communicator
"""Finalize Sender communicator
Parameters
Parameters
...
@@ -88,7 +86,7 @@ def _finalize_sampler_sender(sender):
...
@@ -88,7 +86,7 @@ def _finalize_sampler_sender(sender):
"""
"""
_CAPI_DGLFinalizeCommunicator
(
sender
)
_CAPI_DGLFinalizeCommunicator
(
sender
)
def
_finalize_
sampler_
receiver
(
receiver
):
def
_finalize_receiver
(
receiver
):
"""Finalize Receiver communicator
"""Finalize Receiver communicator
Parameters
Parameters
...
...
src/graph/network.cc
View file @
8058f1c5
...
@@ -20,11 +20,6 @@ using dgl::runtime::NDArray;
...
@@ -20,11 +20,6 @@ using dgl::runtime::NDArray;
namespace
dgl
{
namespace
dgl
{
namespace
network
{
namespace
network
{
static
char
*
sender_data_buffer
=
nullptr
;
static
char
*
recv_data_buffer
=
nullptr
;
///////////////////////// Distributed Sampler /////////////////////////
DGL_REGISTER_GLOBAL
(
"network._CAPI_DGLSenderCreate"
)
DGL_REGISTER_GLOBAL
(
"network._CAPI_DGLSenderCreate"
)
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
.
set_body
([]
(
DGLArgs
args
,
DGLRetValue
*
rv
)
{
std
::
string
ip
=
args
[
0
];
std
::
string
ip
=
args
[
0
];
...
@@ -34,7 +29,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
...
@@ -34,7 +29,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
LOG
(
FATAL
)
<<
"Initialize network communicator (sender) error."
;
LOG
(
FATAL
)
<<
"Initialize network communicator (sender) error."
;
}
}
try
{
try
{
sender_data_b
uffer
=
new
char
[
kMaxBufferSize
];
comm
->
SetB
uffer
(
new
char
[
kMaxBufferSize
]
)
;
}
catch
(
const
std
::
bad_alloc
&
)
{
}
catch
(
const
std
::
bad_alloc
&
)
{
LOG
(
FATAL
)
<<
"Not enough memory for sender buffer: "
<<
kMaxBufferSize
;
LOG
(
FATAL
)
<<
"Not enough memory for sender buffer: "
<<
kMaxBufferSize
;
}
}
...
@@ -52,7 +47,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
...
@@ -52,7 +47,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
LOG
(
FATAL
)
<<
"Initialize network communicator (receiver) error."
;
LOG
(
FATAL
)
<<
"Initialize network communicator (receiver) error."
;
}
}
try
{
try
{
recv_data_b
uffer
=
new
char
[
kMaxBufferSize
];
comm
->
SetB
uffer
(
new
char
[
kMaxBufferSize
]
)
;
}
catch
(
const
std
::
bad_alloc
&
)
{
}
catch
(
const
std
::
bad_alloc
&
)
{
LOG
(
FATAL
)
<<
"Not enough memory for receiver buffer: "
<<
kMaxBufferSize
;
LOG
(
FATAL
)
<<
"Not enough memory for receiver buffer: "
<<
kMaxBufferSize
;
}
}
...
@@ -73,7 +68,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
...
@@ -73,7 +68,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
auto
csr
=
ptr
->
GetInCSR
();
auto
csr
=
ptr
->
GetInCSR
();
// Serialize nodeflow to data buffer
// Serialize nodeflow to data buffer
int64_t
data_size
=
network
::
SerializeSampledSubgraph
(
int64_t
data_size
=
network
::
SerializeSampledSubgraph
(
sender_data_b
uffer
,
comm
->
GetB
uffer
()
,
csr
,
csr
,
node_mapping
,
node_mapping
,
edge_mapping
,
edge_mapping
,
...
@@ -81,7 +76,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
...
@@ -81,7 +76,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
flow_offsets
);
flow_offsets
);
CHECK_GT
(
data_size
,
0
);
CHECK_GT
(
data_size
,
0
);
// Send msg via network
// Send msg via network
int64_t
size
=
comm
->
Send
(
sender_data_b
uffer
,
data_size
);
int64_t
size
=
comm
->
Send
(
comm
->
GetB
uffer
()
,
data_size
);
if
(
size
<=
0
)
{
if
(
size
<=
0
)
{
LOG
(
ERROR
)
<<
"Send message error (size: "
<<
size
<<
")"
;
LOG
(
ERROR
)
<<
"Send message error (size: "
<<
size
<<
")"
;
}
}
...
@@ -92,15 +87,15 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
...
@@ -92,15 +87,15 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle
chandle
=
args
[
0
];
CommunicatorHandle
chandle
=
args
[
0
];
network
::
Communicator
*
comm
=
static_cast
<
network
::
Communicator
*>
(
chandle
);
network
::
Communicator
*
comm
=
static_cast
<
network
::
Communicator
*>
(
chandle
);
// Recv data from network
// Recv data from network
int64_t
size
=
comm
->
Receive
(
recv_data_b
uffer
,
kMaxBufferSize
);
int64_t
size
=
comm
->
Receive
(
comm
->
GetB
uffer
()
,
kMaxBufferSize
);
if
(
size
<=
0
)
{
if
(
size
<=
0
)
{
LOG
(
ERROR
)
<<
"Receive error: (size: "
<<
size
<<
")"
;
LOG
(
ERROR
)
<<
"Receive error: (size: "
<<
size
<<
")"
;
}
}
NodeFlow
*
nf
=
new
NodeFlow
();
NodeFlow
*
nf
=
new
NodeFlow
();
ImmutableGraph
::
CSR
::
Ptr
csr
;
ImmutableGraph
::
CSR
::
Ptr
csr
;
// Deserialize nodeflow from recv_data_buffer
// Deserialize nodeflow from recv_data_buffer
network
::
DeserializeSampledSubgraph
(
recv_data_b
uffer
,
network
::
DeserializeSampledSubgraph
(
comm
->
GetB
uffer
()
,
&
csr
,
&
(
csr
)
,
&
(
nf
->
node_mapping
),
&
(
nf
->
node_mapping
),
&
(
nf
->
edge_mapping
),
&
(
nf
->
edge_mapping
),
&
(
nf
->
layer_offsets
),
&
(
nf
->
layer_offsets
),
...
@@ -116,8 +111,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator")
...
@@ -116,8 +111,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator")
CommunicatorHandle
chandle
=
args
[
0
];
CommunicatorHandle
chandle
=
args
[
0
];
network
::
Communicator
*
comm
=
static_cast
<
network
::
Communicator
*>
(
chandle
);
network
::
Communicator
*
comm
=
static_cast
<
network
::
Communicator
*>
(
chandle
);
comm
->
Finalize
();
comm
->
Finalize
();
delete
[]
sender_data_buffer
;
delete
[]
recv_data_buffer
;
});
});
}
// namespace network
}
// namespace network
...
...
src/graph/network/communicator.h
View file @
8058f1c5
...
@@ -67,6 +67,16 @@ class Communicator {
...
@@ -67,6 +67,16 @@ class Communicator {
* \brief Finalize the Communicator class
* \brief Finalize the Communicator class
*/
*/
virtual
void
Finalize
()
=
0
;
virtual
void
Finalize
()
=
0
;
/*!
* \brief Set pointer of memory buffer allocated for Communicator
*/
virtual
void
SetBuffer
(
char
*
buffer
)
=
0
;
/*!
* \brief Get pointer of memory buffer allocated for Communicator
*/
virtual
char
*
GetBuffer
()
=
0
;
};
};
}
// namespace network
}
// namespace network
...
...
src/graph/network/socket_communicator.cc
View file @
8058f1c5
...
@@ -162,6 +162,9 @@ void SocketCommunicator::FinalizeSender() {
...
@@ -162,6 +162,9 @@ void SocketCommunicator::FinalizeSender() {
delete
socket_
[
0
];
delete
socket_
[
0
];
socket_
[
0
]
=
nullptr
;
socket_
[
0
]
=
nullptr
;
}
}
if
(
buffer_
!=
nullptr
)
{
delete
[]
buffer_
;
}
}
}
void
SocketCommunicator
::
FinalizeReceiver
()
{
void
SocketCommunicator
::
FinalizeReceiver
()
{
...
@@ -209,5 +212,15 @@ int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) {
...
@@ -209,5 +212,15 @@ int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) {
return
queue_
->
Remove
(
dest
,
max_size
);
return
queue_
->
Remove
(
dest
,
max_size
);
}
}
void
SocketCommunicator
::
SetBuffer
(
char
*
buffer
)
{
// Set memory buffer allocated for current Communicator
buffer_
=
buffer
;
}
char
*
SocketCommunicator
::
GetBuffer
()
{
// Get memory buffer allocated for current Communicator
return
buffer_
;
}
}
// namespace network
}
// namespace network
}
// namespace dgl
}
// namespace dgl
src/graph/network/socket_communicator.h
View file @
8058f1c5
...
@@ -67,6 +67,16 @@ class SocketCommunicator : public Communicator {
...
@@ -67,6 +67,16 @@ class SocketCommunicator : public Communicator {
*/
*/
void
Finalize
();
void
Finalize
();
/*!
* \brief Set pointer of memory buffer allocated for Communicator
*/
void
SetBuffer
(
char
*
buffer
);
/*!
* \brief Get pointer of memory buffer allocated for Communicator
*/
char
*
GetBuffer
();
private:
private:
/*!
/*!
* \brief Is a sender or reciever node?
* \brief Is a sender or reciever node?
...
@@ -98,6 +108,11 @@ class SocketCommunicator : public Communicator {
...
@@ -98,6 +108,11 @@ class SocketCommunicator : public Communicator {
*/
*/
MessageQueue
*
queue_
;
MessageQueue
*
queue_
;
/*!
* \brief Memory buffer for communicator
*/
char
*
buffer_
=
nullptr
;
/*!
/*!
* \brief Initalize sender node
* \brief Initalize sender node
* \param ip receiver ip address
* \param ip receiver ip address
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment