OpenDAS / Paddle · Commits · f0ef3442

Commit f0ef3442, authored Apr 26, 2023 by yuguo960516yuguo

    2.3.2-dtk-22.10.1

parent ad08b8ce · Pipeline #227 failed with stages in 0 seconds · Changes: 274 · Pipelines: 1

Showing 20 changed files with 1669 additions and 0 deletions (+1669 −0)
paddle/fluid/distributed/fleet_executor/interceptor.cc             +122 −0
paddle/fluid/distributed/fleet_executor/interceptor.h              +162 −0
paddle/fluid/distributed/fleet_executor/interceptor_message.proto  +43 −0
paddle/fluid/distributed/fleet_executor/message_bus.cc             +256 −0
paddle/fluid/distributed/fleet_executor/message_bus.h              +95 −0
paddle/fluid/distributed/fleet_executor/message_service.cc         +51 −0
paddle/fluid/distributed/fleet_executor/message_service.h          +41 −0
paddle/fluid/distributed/fleet_executor/runtime_graph.cc           +33 −0
paddle/fluid/distributed/fleet_executor/runtime_graph.h            +56 −0
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc        +66 −0
paddle/fluid/distributed/fleet_executor/sink_interceptor.h         +41 −0
paddle/fluid/distributed/fleet_executor/source_interceptor.cc      +58 −0
paddle/fluid/distributed/fleet_executor/source_interceptor.h       +41 −0
paddle/fluid/distributed/fleet_executor/task_loop.cc               +85 −0
paddle/fluid/distributed/fleet_executor/task_loop.h                +84 −0
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc        +60 −0
paddle/fluid/distributed/fleet_executor/task_loop_thread.h         +48 −0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc   +77 −0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h    +51 −0
paddle/fluid/distributed/fleet_executor/task_node.cc               +199 −0
paddle/fluid/distributed/fleet_executor/interceptor.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {

Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
    : interceptor_id_(interceptor_id), node_(node) {}

Interceptor::~Interceptor() {
  // FIXME(wangxi): throw in stop function
  // std::lock_guard<std::mutex> lock(mutex_);
  // PADDLE_ENFORCE_EQ(messages_.empty(), true,
  //                   platform::errors::PreconditionNotMet(
  //                       "Interceptor must destruct with messages empty"));
}

void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }

void Interceptor::Handle(const InterceptorMessage& msg) {
  PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet(
                                       "Message handle is not registered."));
  handle_(msg);
}

void Interceptor::LoopOnce() {
  std::deque<InterceptorMessage> tmp_messages;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    messages_.swap(tmp_messages);
  }
  PADDLE_ENFORCE_EQ(tmp_messages.empty(), false,
                    platform::errors::PreconditionNotMet(
                        "tmp_messages must not be empty in task loop"));

  for (auto& msg : tmp_messages) {
    const MessageType message_type = msg.message_type();
    VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
            << " from interceptor " << msg.src_id()
            << " with message: " << message_type << ".";
    Handle(msg);
  }
}

void Interceptor::StopCarrier() {
  PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
                                        "Carrier is not registered."));
  carrier_->WakeUp();
}

void Interceptor::EnqueueRemoteInterceptorMessage(
    const InterceptorMessage& message) {
  // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
  VLOG(3) << "Enqueue message: " << message.message_type() << " into "
          << interceptor_id_ << "'s remote mailbox.";
  bool empty = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    empty = messages_.empty();
    messages_.emplace_back(message);
  }
  if (empty) {
    loop_->QueueInLoop([this]() { LoopOnce(); });
  }
}

bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
  PADDLE_ENFORCE_NOT_NULL(carrier_, platform::errors::PreconditionNotMet(
                                        "Carrier is not registered."));
  msg.set_src_id(interceptor_id_);
  msg.set_dst_id(dst_id);
  return carrier_->Send(msg);
}

static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
  static InterceptorFactory::CreateInterceptorMap interceptorMap;
  return interceptorMap;
}

std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
                                                        int64_t id,
                                                        TaskNode* node) {
  auto& interceptor_map = GetInterceptorMap();
  auto iter = interceptor_map.find(type);
  PADDLE_ENFORCE_NE(
      iter, interceptor_map.end(),
      platform::errors::NotFound("interceptor %s is not registered", type));
  return iter->second(id, node);
}

void InterceptorFactory::Register(
    const std::string& type, InterceptorFactory::CreateInterceptorFunc func) {
  auto& interceptor_map = GetInterceptorMap();
  interceptor_map.emplace(type, func);
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class Scope;
class GarbageCollector;
}  // namespace framework

namespace distributed {

class TaskNode;
class Carrier;
class TaskLoop;

constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;

class Interceptor {
 public:
  using MsgHandle = std::function<void(const InterceptorMessage&)>;

 public:
  Interceptor() = delete;

  Interceptor(int64_t interceptor_id, TaskNode* node);

  virtual ~Interceptor();

  // register the interceptor's message handle
  void RegisterMsgHandle(MsgHandle handle);

  void Handle(const InterceptorMessage& msg);

  // return the interceptor id
  int64_t GetInterceptorId() const { return interceptor_id_; }

  // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
  void EnqueueRemoteInterceptorMessage(
      const InterceptorMessage& interceptor_message);

  bool Send(int64_t dst_id, InterceptorMessage& msg);  // NOLINT

  void SetPlace(const platform::Place& place) { place_ = place; }

  void SetRootScope(framework::Scope* scope) { root_scope_ = scope; }
  void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; }
  void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
    microbatch_scopes_ = scopes;
  }
  void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
    gc_ = gc;
  }

  void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }

  void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }

  TaskNode* GetTaskNode() const { return node_; }

  DISABLE_COPY_AND_ASSIGN(Interceptor);

 protected:
  // interceptor id, handed down from the above layer
  int64_t interceptor_id_;

  // the node to be handled by this interceptor
  TaskNode* node_;

  // for stop
  bool stop_{false};
  void StopCarrier();

  // for runtime
  platform::Place place_;
  framework::Scope* root_scope_{nullptr};
  framework::Scope* minibatch_scope_{nullptr};
  std::vector<framework::Scope*> microbatch_scopes_{};
  std::shared_ptr<framework::GarbageCollector> gc_{nullptr};

  Carrier* carrier_;
  TaskLoop* loop_;

 private:
  void LoopOnce();

  // the interceptor handle which processes messages
  MsgHandle handle_{nullptr};

  std::mutex mutex_;
  std::deque<InterceptorMessage> messages_;

  int64_t already_run_times_{0};
  int64_t used_slot_nums_{0};
};

class InterceptorFactory {
 public:
  using CreateInterceptorFunc = std::unique_ptr<Interceptor> (*)(int64_t,
                                                                 TaskNode*);
  using CreateInterceptorMap =
      std::unordered_map<std::string, CreateInterceptorFunc>;

  static void Register(const std::string& type, CreateInterceptorFunc func);

  static std::unique_ptr<Interceptor> Create(const std::string& type,
                                             int64_t id, TaskNode* node);
};

template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
  return std::make_unique<InterceptorClass>(id, node);
}

#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class)          \
  class __RegisterInterceptor_##interceptor_type {                         \
   public:                                                                 \
    __RegisterInterceptor_##interceptor_type() {                           \
      InterceptorFactory::Register(#interceptor_type,                      \
                                   CreatorInterceptor<interceptor_class>); \
    }                                                                      \
    void Touch() {}                                                        \
  };                                                                       \
  __RegisterInterceptor_##interceptor_type g_register_##interceptor_type;  \
  int TouchRegisterInterceptor_##interceptor_type() {                      \
    g_register_##interceptor_type.Touch();                                 \
    return 0;                                                              \
  }

#define USE_INTERCEPTOR(interceptor_type)                   \
  extern int TouchRegisterInterceptor_##interceptor_type(); \
  UNUSED static int use_interceptor_##interceptor_type =    \
      TouchRegisterInterceptor_##interceptor_type();

}  // namespace distributed
}  // namespace paddle
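Taken together, the factory and the two macros form a simple plugin pattern: a subclass registers a creator function under a string key at static-initialization time, and carriers later instantiate it by name. A minimal sketch of how a custom interceptor would plug in, following the pattern used by SinkInterceptor and SourceInterceptor later in this commit (EchoInterceptor is a hypothetical class, not part of the diff):

// Hypothetical example for illustration: a custom interceptor that only
// logs incoming messages, registered under the type name "Echo".
class EchoInterceptor : public Interceptor {
 public:
  EchoInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) {
      VLOG(3) << "Echo " << GetInterceptorId() << " got a message from "
              << msg.src_id();
    });
  }
};
REGISTER_INTERCEPTOR(Echo, EchoInterceptor);

// Elsewhere (e.g. in a carrier), instances are created by type name:
//   std::unique_ptr<Interceptor> it = InterceptorFactory::Create("Echo", 7, node);
// and USE_INTERCEPTOR(Echo) forces the registration object to be linked in.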
paddle/fluid/distributed/fleet_executor/interceptor_message.proto (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;

option cc_generic_services = true;
option cc_enable_arenas = true;

enum MessageType {
  STOP = 1;             // STOP an Interceptor
  DATA_IS_READY = 2;    // upstream data is ready
  DATA_IS_USELESS = 3;  // downstream has used the data
  ERR = 4;              // current Interceptor encounters error
  RESET = 5;            // reset the status
  START = 6;
}

message InterceptorMessage {
  optional sint64 src_id = 1 [ default = 0 ];
  optional sint64 dst_id = 2 [ default = 0 ];
  optional MessageType message_type = 3 [ default = RESET ];
  optional bool ctrl_message = 4 [ default = false ];
  optional int64 scope_idx = 5 [ default = 0 ];
}

message InterceptorResponse { optional bool rst = 1 [ default = false ]; }

service MessageService {
  rpc ReceiveInterceptorMessage(InterceptorMessage)
      returns (InterceptorResponse);
  rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
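For orientation, these proto2 fields map onto the generated C++ setters and getters that the interceptor code in this commit relies on (set_src_id, set_message_type, and so on). A minimal sketch, assuming the standard protoc C++ output for this file (BuildReadyMessage is a hypothetical helper):

void BuildReadyMessage() {  // hypothetical helper for illustration
  InterceptorMessage msg;
  msg.set_src_id(SOURCE_ID);  // SOURCE_ID (-1) comes from interceptor.h
  msg.set_dst_id(3);          // hypothetical destination interceptor id
  msg.set_message_type(DATA_IS_READY);
  msg.set_scope_idx(0);
  // Unset optional fields report their declared defaults, so a fresh
  // message has message_type() == RESET and ctrl_message() == false.
}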
paddle/fluid/distributed/fleet_executor/message_bus.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include <chrono>
#include <memory>
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace distributed {

void MessageBus::Init(
    int64_t rank,
    const std::unordered_map<int64_t, std::string>& rank_to_addr,
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(is_init_, false,
                    platform::errors::AlreadyExists(
                        "MessageBus is already initialized."));
  rank_ = rank;
  is_init_ = true;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

  if (addr_ != "") {
    const auto& addr = GetAddr(rank_);
    PADDLE_ENFORCE_EQ(
        addr, addr_,
        platform::errors::Fatal("The current rank's addr is %s, while the "
                                "message bus's addr is %s, which are "
                                "different. Init error.",
                                addr, addr_));
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make brpc compatible with collective communication,
  // the handler holding the ip address must be released.
  if (addr_ != "") {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

  ListenPort();
}

bool MessageBus::IsInit() const { return is_init_; }

MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  server_.Stop(1000);
  server_.Join();
#endif
}

const std::string& MessageBus::GetAddr(int64_t rank) const {
  PADDLE_ENFORCE_NE(
      rank_to_addr_.find(rank), rank_to_addr_.end(),
      platform::errors::NotFound("Cannot find addr for rank id %lld.", rank));
  return rank_to_addr_.at(rank);
}

bool MessageBus::Send(int64_t dst_rank,
                      const InterceptorMessage& interceptor_message) {
  PADDLE_ENFORCE_EQ(
      IsInit(), true,
      platform::errors::PreconditionNotMet(
          "Cannot use the message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  int retry_time = 0;
  // message bus will retry sending for 10 times
  while (retry_time < 10) {
    ++retry_time;
    if (SendInterRank(dst_rank, interceptor_message)) {
      VLOG(3) << "Message bus sends inter rank successfully with "
              << retry_time << " times retries.";
      return true;
    }
    VLOG(3) << "Message bus send failed, retry after 1 second.";
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  VLOG(3) << "Message bus inter-rank send failed after 10 retries.";
  return false;
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Fleet executor does not support sending message between different "
      "ranks when Paddle is compiled with npu or "
      "isn't compiled with distributed for now."));
#endif
  return true;
}

void MessageBus::IncreaseBarrierCount() {
  VLOG(3) << "IncreaseBarrierCount";
  {
    std::unique_lock<std::mutex> lock(mutex_);
    ++count_;
    cv_.notify_one();
  }
  VLOG(3) << "End IncreaseBarrierCount";
}

void MessageBus::Barrier() {
  // gather to root
  if (rank_ != 0) {
    InterceptorMessage ctrl_msg;
    ctrl_msg.set_ctrl_message(true);
    ctrl_msg.set_src_id(rank_);
    ctrl_msg.set_dst_id(0);
    VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
    while (!Send(0, ctrl_msg)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  } else {
    VLOG(3) << "Barrier 0 waits for the other ranks to be ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] {
      return count_ == static_cast<int>(rank_to_addr_.size() - 1);
    });
    count_ = 0;
  }

  // scatter from root
  if (rank_ == 0) {
    for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
      InterceptorMessage ctrl_msg;
      ctrl_msg.set_ctrl_message(true);
      ctrl_msg.set_src_id(0);
      ctrl_msg.set_dst_id(i);
      VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
      while (!Send(i, ctrl_msg)) {
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
      }
    }
  } else {
    VLOG(3) << "Barrier " << rank_ << " waits for rank 0 to be ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ == 1; });
    count_ = 0;
  }
}

bool MessageBus::DispatchMsgToCarrier(
    const InterceptorMessage& interceptor_message) {
  const std::string& carrier_id = *GlobalVal<std::string>::Get();
  return GlobalMap<std::string, Carrier>::Get(carrier_id)
      ->EnqueueInterceptorMessage(interceptor_message);
}

void MessageBus::ListenPort() {
  if (addr_ == "") {
    LOG(INFO) << "No need to listen to a port since training on single card.";
    return;
  }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // this function keeps listening to the port and handles the messages
  PADDLE_ENFORCE_EQ(
      server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
      0,
      platform::errors::Unavailable("Message bus: init brpc service error."));

  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
  int retry_times = 0;
  int interval = 100;
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retrying to start brpc for " << retry_times
              << " times. And will retry after " << interval / 1000
              << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
    interval += 500;
  }
  LOG(INFO) << "Message bus's listen port thread starts successfully.";
#else
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
#endif
}

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
bool MessageBus::SendInterRank(int64_t dst_rank,
                               const InterceptorMessage& interceptor_message) {
  const auto& dst_addr = GetAddr(dst_rank);
  VLOG(3) << "Message bus sending to addr: " << dst_addr;
  const char* dst_addr_for_brpc = dst_addr.c_str();
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.connect_timeout_ms = 1000;
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
      channel.Init(dst_addr_for_brpc, &options), 0,
      platform::errors::Unavailable("Message bus: init brpc channel error."));
  MessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
  if (interceptor_message.ctrl_message()) {
    stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
  } else {
    stub.ReceiveInterceptorMessage(&ctrl, &interceptor_message, &response,
                                   NULL);
  }
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
      VLOG(4) << "Message bus: InterceptorMessageService error.";
      return false;
    }
  } else {
    VLOG(4) << "Message bus: brpc send failed with error text: "
            << ctrl.ErrorText();
    return false;
  }
}
#endif

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/message_bus.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {

class Carrier;

// A singleton MessageBus
class MessageBus final {
 public:
  MessageBus() = default;
  ~MessageBus();

  void Init(int64_t rank,
            const std::unordered_map<int64_t, std::string>& rank_to_addr,
            const std::string& addr);

  bool IsInit() const;

  // called by Interceptor, send InterceptorMessage to dst
  bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);

  void IncreaseBarrierCount();
  void Barrier();
  bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message);

 private:
  DISABLE_COPY_AND_ASSIGN(MessageBus);

  // keeps listening to the port and handles the messages
  void ListenPort();

  const std::string& GetAddr(int64_t rank) const;

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // send the message inter rank (dst is on a different rank from src)
  bool SendInterRank(int64_t dst_rank,
                     const InterceptorMessage& interceptor_message);
#endif

  bool is_init_{false};

  int64_t rank_;

  // handed down by the above layer, saves the mapping from rank id to addr
  std::unordered_map<int64_t, std::string> rank_to_addr_;

  // the ip address to be listened on
  std::string addr_;

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  MessageServiceImpl message_service_;
  // brpc server
  brpc::Server server_;
#endif

  // for barrier
  std::mutex mutex_;
  std::condition_variable cv_;
  int count_{0};
};

}  // namespace distributed
}  // namespace paddle
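As a rough usage sketch of this API: a rank initializes the bus with the full rank-to-address map, synchronizes with Barrier() (gather-to-root then scatter-from-root, as implemented in the .cc above), and sends interceptor messages by destination rank. The ranks and addresses below are hypothetical, and in the real code the bus is accessed through GlobalVal<MessageBus>, as message_service.cc shows next:

void SetUpBusForRank0() {  // hypothetical helper for illustration
  std::unordered_map<int64_t, std::string> rank_to_addr = {
      {0, "127.0.0.1:6170"}, {1, "127.0.0.1:6171"}};  // hypothetical addrs
  MessageBus bus;
  bus.Init(/*rank=*/0, rank_to_addr, /*addr=*/"127.0.0.1:6170");

  bus.Barrier();  // blocks until all ranks have checked in with rank 0

  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_READY);
  bus.Send(/*dst_rank=*/1, msg);  // retried up to 10 times internally
}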
paddle/fluid/distributed/fleet_executor/message_service.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {

void MessageServiceImpl::ReceiveInterceptorMessage(
    google::protobuf::RpcController* control_base,
    const InterceptorMessage* request, InterceptorResponse* response,
    google::protobuf::Closure* done) {
  brpc::ClosureGuard done_guard(done);
  VLOG(3) << "Message Service receives a message from interceptor "
          << request->src_id() << " to interceptor " << request->dst_id()
          << ", with the message: " << request->message_type();
  bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
  response->set_rst(flag);
}

void MessageServiceImpl::IncreaseBarrierCount(
    google::protobuf::RpcController* control_base,
    const InterceptorMessage* request, InterceptorResponse* response,
    google::protobuf::Closure* done) {
  brpc::ClosureGuard done_guard(done);
  VLOG(3) << "Barrier Service receives a message from rank "
          << request->src_id() << " to rank " << request->dst_id();
  GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
  response->set_rst(true);
}

}  // namespace distributed
}  // namespace paddle
#endif
paddle/fluid/distributed/fleet_executor/message_service.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#pragma once
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
namespace paddle {
namespace distributed {

class MessageServiceImpl : public MessageService {
 public:
  MessageServiceImpl() {}
  virtual ~MessageServiceImpl() {}

  virtual void ReceiveInterceptorMessage(
      google::protobuf::RpcController* control_base,
      const InterceptorMessage* request, InterceptorResponse* response,
      google::protobuf::Closure* done);

  virtual void IncreaseBarrierCount(
      google::protobuf::RpcController* control_base,
      const InterceptorMessage* request, InterceptorResponse* response,
      google::protobuf::Closure* done);
};

}  // namespace distributed
}  // namespace paddle
#endif
paddle/fluid/distributed/fleet_executor/runtime_graph.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {

std::string RuntimeGraph::DebugString() const {
  std::ostringstream os;
  os << "\nRuntime Graph Debug:\n";
  for (const auto& pair : interceptor_id_to_node_) {
    os << pair.second->DebugString();
    os << "\n";
  }
  return os.str();
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/runtime_graph.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {

class TaskNode;

class RuntimeGraph final {
 public:
  RuntimeGraph() = default;
  ~RuntimeGraph() = default;

  const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node()
      const {
    return interceptor_id_to_node_;
  }
  const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
    return interceptor_id_to_rank_;
  }

  void SetInterceptorIdToRank(
      const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
    interceptor_id_to_rank_ = interceptor_id_to_rank;
  }
  void SetInterceptorIdToNode(
      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
    interceptor_id_to_node_ = interceptor_id_to_node;
  }

  std::string DebugString() const;

 private:
  DISABLE_COPY_AND_ASSIGN(RuntimeGraph);

  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
  std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {

SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // prepare the upstream running status
  for (const auto& up : node->upstream()) {
    upstream_step_.emplace(up.first, 0);
  }
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}

void SinkInterceptor::StopCarrierIfComplete() {
  bool flag = true;
  for (const auto& up : upstream_step_) {
    flag = flag && (up.second == max_run_times_);
  }
  if (flag) {
    VLOG(3) << "Sink Interceptor is stopping carrier";
    StopCarrier();
    for (const auto& up : upstream_step_) {
      upstream_step_.at(up.first) = 0;
    }
  }
}

void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) {
  int64_t micro_step = upstream_step_.at(upstream_id);
  int64_t scope_idx = micro_step % max_run_times_;
  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_USELESS);
  msg.set_scope_idx(scope_idx);
  Send(upstream_id, msg);
  upstream_step_.at(upstream_id) = micro_step + 1;
  if (micro_step == max_run_times_ - 1) {
    StopCarrierIfComplete();
  }
}

void SinkInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
    ReplyCompletedToUpStream(msg.src_id());
  }
}

REGISTER_INTERCEPTOR(Sink, SinkInterceptor);

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/sink_interceptor.h (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {

/*
 * Sink interceptor
 * There is only one sink in the runtime graph.
 * It takes charge of:
 *   1. recording the number of micro-steps
 *   2. checking whether to notify the carrier that the current step finished
 */
class SinkInterceptor : public Interceptor {
 public:
  SinkInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void ReplyCompletedToUpStream(int64_t up_id);
  void Run(const InterceptorMessage& msg);
  void StopCarrierIfComplete();

  int64_t max_run_times_;
  // upstream_id->cur_step
  std::map<int64_t, int64_t> upstream_step_;
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/source_interceptor.cc (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {

SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // prepare the downstream running status
  for (const auto& down : node->downstream()) {
    downstream_step_.emplace(down.first, 0);
  }
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}

void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) {
  int64_t micro_step = downstream_step_.at(downstream_id);
  if (micro_step >= max_run_times_) {
    return;
  }
  int64_t scope_idx = micro_step % max_run_times_;
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_IS_READY);
  ready_msg.set_scope_idx(scope_idx);
  Send(downstream_id, ready_msg);
  downstream_step_.at(downstream_id) = micro_step + 1;
}

void SourceInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == START) {
    // start running in a new step, reset the previous running status
    for (const auto& down : downstream_step_) {
      downstream_step_.at(down.first) = 0;
      SendDataReadyToDownStream(down.first);
    }
  } else if (msg.message_type() == DATA_IS_USELESS) {
    SendDataReadyToDownStream(msg.src_id());
  }
}

REGISTER_INTERCEPTOR(Source, SourceInterceptor);

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/source_interceptor.h (new file, mode 100644)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {

/*
 * Source interceptor
 * There is only one source in the runtime graph.
 * It takes charge of:
 *   1. receiving the `start` message from the carrier
 *   2. sending num_of_steps `data_is_ready` messages downstream
 */
class SourceInterceptor : public Interceptor {
 public:
  SourceInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void SendDataReadyToDownStream(int64_t down_id);
  void Run(const InterceptorMessage& msg);

  int64_t max_run_times_;
  // downstream_id->cur_step
  std::map<int64_t, int64_t> downstream_step_;
};

}  // namespace distributed
}  // namespace paddle
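Reading the source and sink interceptors together with TaskNode (whose implementation closes this page), the intended topology is a chain source → worker task(s) → sink, wired through the upstream/downstream maps that both constructors iterate. A minimal sketch with hypothetical ids and buffer sizes, using the TaskNode constructors and AddUpstreamTask/AddDownstreamTask from this commit:

void WireMinimalGraph() {  // hypothetical helper for illustration
  const int64_t kMicroBatches = 4;  // becomes max_run_times in each node
  TaskNode source(/*rank=*/0, /*task_id=*/SOURCE_ID, kMicroBatches);
  TaskNode worker(/*rank=*/0, /*task_id=*/0, kMicroBatches);
  TaskNode sink(/*rank=*/0, /*task_id=*/SINK_ID, kMicroBatches);

  source.AddDownstreamTask(/*task_id=*/0, /*buff_size=*/2);
  worker.AddUpstreamTask(SOURCE_ID, /*buff_size=*/2);
  worker.AddDownstreamTask(SINK_ID, /*buff_size=*/2);
  sink.AddUpstreamTask(/*task_id=*/0, /*buff_size=*/2);
  // A START message to the source then drives kMicroBatches rounds of
  // DATA_IS_READY / DATA_IS_USELESS handshakes down and up the chain.
}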
paddle/fluid/distributed/fleet_executor/task_loop.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {

thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;

TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }

TaskLoop::TaskLoop()
    : looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
  PADDLE_ENFORCE_EQ(
      thread_local_loop_, nullptr,
      platform::errors::AlreadyExists("Another TaskLoop is already init."));
  thread_local_loop_ = this;
}

TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }

void TaskLoop::Loop() {
  PADDLE_ENFORCE_EQ(looping_, false,
                    platform::errors::PreconditionNotMet(
                        "Loop can only execute in one loop thread"));
  AssertInLoopThread();
  looping_ = true;
  quit_ = false;
  while (!quit_) {
    auto tasks = tasks_.PopAll();
    for (auto& task : tasks) {
      task();
    }
  }
  looping_ = false;
}

void TaskLoop::Quit() {
  quit_ = true;
  if (!IsInLoopThread()) WakeUp();
}

void TaskLoop::RunInLoop(Functor cb) {
  if (IsInLoopThread()) {
    cb();
  } else {
    QueueInLoop(cb);
  }
}

void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }

void TaskLoop::WakeUp() {
  Functor task([] {});
  QueueInLoop(task);
}

void TaskLoop::AbortNotInLoopThread() {
  PADDLE_THROW(platform::errors::PreconditionNotMet(
      "This TaskLoop was created in thread %d, but current thread is %d",
      thread_id_, std::this_thread::get_id()));
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {

class TaskLoop {
 public:
  static TaskLoop* GetTaskLoopOfCurrentThread();

  using Functor = std::function<void()>;

  TaskLoop();
  ~TaskLoop();

  void Loop();
  void Quit();

  void RunInLoop(Functor cb);
  void QueueInLoop(Functor cb);

  template <class F, class... Args>
  auto Enqueue(F&& f, Args&&... args)
      -> std::future<typename std::result_of<F(Args...)>::type> {
    using return_type = typename std::result_of<F(Args...)>::type;
    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> task_future = task->get_future();
    tasks_.Push([task]() { (*task)(); });
    return task_future;
  }

  void WakeUp();

  bool IsInLoopThread() const {
    return thread_id_ == std::this_thread::get_id();
  }

  void AssertInLoopThread() {
    if (!IsInLoopThread()) {
      AbortNotInLoopThread();
    }
  }

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoop);

  void AbortNotInLoopThread();

  static thread_local TaskLoop* thread_local_loop_;

  bool looping_;
  std::atomic<bool> quit_;
  std::thread::id thread_id_;
  framework::BlockingQueue<Functor> tasks_;
};

}  // namespace distributed
}  // namespace paddle
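The Enqueue template above is the future-returning variant of QueueInLoop: it wraps the callable in a std::packaged_task and pushes a type-erased runner onto the same blocking queue. A minimal sketch of both submission styles, assuming a loop obtained from TaskLoopThread::StartLoop() (next file); SubmitWork is a hypothetical helper:

void SubmitWork(TaskLoop* loop) {  // hypothetical helper for illustration
  // Fire-and-forget: runs on the loop thread, no result observable.
  loop->QueueInLoop([] { VLOG(3) << "background task ran"; });

  // Future-returning: the caller can block on the result.
  std::future<int> fut = loop->Enqueue([](int x) { return x * 2; }, 21);
  int answer = fut.get();  // 42, once the loop thread has drained the queue
  (void)answer;
}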
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {

TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}

TaskLoopThread::~TaskLoopThread() {
  if (loop_ != nullptr) {
    loop_->Quit();
    thread_.join();
  }
}

TaskLoop* TaskLoopThread::StartLoop() {
  PADDLE_ENFORCE_EQ(start_, false,
                    platform::errors::PreconditionNotMet(
                        "thread is already running."));
  start_ = true;
  thread_ = std::thread([this]() { Loop(); });

  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [=] { return loop_ != nullptr; });
  return loop_;
}

void TaskLoopThread::Loop() {
  TaskLoop loop;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    loop_ = &loop;
    cv_.notify_one();
  }
  loop.Loop();

  std::unique_lock<std::mutex> lock(mutex_);
  loop_ = nullptr;
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {

class TaskLoop;

class TaskLoopThread {
 public:
  TaskLoopThread();
  ~TaskLoopThread();

  TaskLoop* StartLoop();

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoopThread);

  void Loop();

  bool start_;
  TaskLoop* loop_;
  std::thread thread_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {

TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}

TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
    : start_(false), thread_num_(thread_num) {}

TaskLoopThreadPool::~TaskLoopThreadPool() = default;

void TaskLoopThreadPool::Start() {
  PADDLE_ENFORCE_EQ(start_, false,
                    platform::errors::PreconditionNotMet(
                        "thread pool is already started."));
  PADDLE_ENFORCE_GT(
      thread_num_, 0,
      platform::errors::InvalidArgument(
          "thread num must be greater than 0, but now is %d", thread_num_));

  start_ = true;
  for (int i = 0; i < thread_num_; ++i) {
    threads_.emplace_back(new TaskLoopThread());
    loops_.push_back(threads_[i]->StartLoop());
  }
}

TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
  PADDLE_ENFORCE_EQ(start_, true,
                    platform::errors::PreconditionNotMet(
                        "thread pool must start first."));
  PADDLE_ENFORCE_GE(tid, 0,
                    platform::errors::OutOfRange(
                        "tid must >= 0, but now is %d", tid));
  PADDLE_ENFORCE_LT(tid, thread_num_,
                    platform::errors::OutOfRange(
                        "tid must < thread_num, but now tid=%d thread_num=%d",
                        tid, thread_num_));
  return loops_[tid];
}

std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
  PADDLE_ENFORCE_EQ(start_, true,
                    platform::errors::PreconditionNotMet(
                        "thread pool must start first."));
  return loops_;
}

}  // namespace distributed
}  // namespace paddle
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {

class TaskLoop;
class TaskLoopThread;

class TaskLoopThreadPool {
 public:
  TaskLoopThreadPool();
  explicit TaskLoopThreadPool(int thread_num);
  ~TaskLoopThreadPool();

  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }

  void Start();

  TaskLoop* GetLoop(int tid);
  std::vector<TaskLoop*> GetAllLoops();

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoopThreadPool);

  bool start_;
  int thread_num_;
  std::vector<std::unique_ptr<TaskLoopThread>> threads_;
  std::vector<TaskLoop*> loops_;
};

}  // namespace distributed
}  // namespace paddle
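Putting the threading pieces together: the pool owns one TaskLoopThread (and thus one TaskLoop) per worker, matching how carriers hand loops to interceptors via RegisterTaskLoop. A minimal lifetime sketch under those assumptions (RunPool is a hypothetical helper):

void RunPool() {  // hypothetical helper for illustration
  TaskLoopThreadPool pool;
  pool.SetThreadNum(4);  // must precede Start()
  pool.Start();          // spawns 4 TaskLoopThreads, each running one loop

  TaskLoop* loop0 = pool.GetLoop(0);  // valid tids are [0, 4)
  loop0->QueueInLoop([] { /* work runs on thread 0 */ });

  std::vector<TaskLoop*> all = pool.GetAllLoops();
  (void)all;
}  // ~TaskLoopThreadPool joins threads via ~TaskLoopThread / TaskLoop::Quit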
paddle/fluid/distributed/fleet_executor/task_node.cc (new file, mode 100644)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace distributed {
namespace {
using OperatorBase = TaskNode::OperatorBase;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
                   int64_t max_run_times, int64_t max_slot_nums)
    : program_(program),
      rank_(rank),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
  // Should be serially invoked, not thread-safe.
  // NOTE: when instantiating a TaskNode with a program, the task node won't
  // be initialized immediately, since the provided program may be updated
  // later (with high probability) by adding feed/fetch ops or by
  // RuntimeGraph. So the init part is delayed to the Init() function.
  static int64_t task_node_cnt = 0;
  task_id_ = task_node_cnt++;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank,
                   int64_t task_id, int64_t max_run_times,
                   int64_t max_slot_nums)
    : program_(program),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
  // TODO(liyurui): Will be removed when executing the program is supported.
  Init();
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
    : program_(program), rank_(rank), task_id_(rank) {
  max_run_times_ = 1;
  max_slot_nums_ = 1;
  LOG(INFO)
      << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
      << rank
      << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}

void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
  program_ = program;
}

void TaskNode::Init(bool use_feed_fetch_ops) {
  if (!use_feed_fetch_ops) {
    VLOG(3) << "TaskNode will be inited without feed and fetch ops";
  }
  if (ops_.empty()) {
    // Q (for fleet executor dev): should we need another reset funct?
    VLOG(3) << "Task node will be inited by calling Init().";
    for (const auto& op_desc : program_->Block(0).AllOps()) {
      if (!use_feed_fetch_ops &&
          (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
        VLOG(3) << "TaskNode will skip [" << op_desc->Input("X")[0] << "], "
                << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
        continue;
      }
      ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
    }
    for (const auto& op : ops_vec_) {
      ops_.emplace_back(op.get());
    }
  }
}

TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
    : rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}

TaskNode::TaskNode(int32_t role,
                   const std::vector<framework::OpDesc*>& op_descs,
                   int64_t rank, int64_t task_id, int64_t max_run_times,
                   int64_t max_slot_nums)
    : role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
  if (op_descs.empty()) {
    return;
  }
  VLOG(3) << "Task node will be inited by providing a list of ops.";
  for (const auto& desc : op_descs) {
    ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
  }
  for (const auto& op : ops_vec_) {
    ops_.emplace_back(op.get());
  }
}

TaskNode::TaskNode(int32_t role,
                   const std::vector<framework::OperatorBase*>& ops,
                   int64_t rank, int64_t task_id, int64_t max_run_times,
                   int64_t max_slot_nums)
    : ops_(ops),
      role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {}

TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id,
                   int64_t max_run_times, int64_t max_slot_nums)
    : role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {}

bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
  const auto& ret = upstream_.emplace(task_id, buff_size);
  return ret.second;
}

bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
  const auto& ret = downstream_.emplace(task_id, buff_size);
  return ret.second;
}

std::string TaskNode::DebugString() const {
  std::ostringstream os;
  os << "role: " << role_ << ", task_id: " << task_id_ << "\n";
  for (std::size_t i = 0; i < ops_.size(); ++i) {
    os << ops_[i]->Type() << " ";
  }
  os << "\n";
  return os.str();
}

void TaskNode::SetRunPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(value, 1,
                    platform::errors::InvalidArgument(
                        "run_per_steps must >= 1, but received %ld", value));
  run_per_steps_ = value;
}

void TaskNode::SetRunAtOffset(int64_t value) {
  PADDLE_ENFORCE_GE(value, 0,
                    platform::errors::InvalidArgument(
                        "run_at_offset must >= 0, but received %ld", value));
  run_at_offset_ = value;
}

void TaskNode::SetReplyUpPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(
      value, 1, platform::errors::InvalidArgument(
                    "reply_up_per_steps must >= 1, but received %ld", value));
  reply_up_per_steps_ = value;
}

void TaskNode::SetSendDownPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(
      value, 1, platform::errors::InvalidArgument(
                    "send_down_per_steps must >= 1, but received %ld", value));
  send_down_per_steps_ = value;
}

}  // namespace distributed
}  // namespace paddle