Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
de2e6515
Commit
de2e6515
authored
Apr 26, 2023
by
yuguo960516yuguo
Browse files
2.4.1-dtk-23.04
parent
ad08b8ce
Pipeline
#228
failed with stages
in 0 seconds
Changes
272
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2963 additions
and
0 deletions
+2963
-0
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
...fluid/distributed/fleet_executor/amplifier_interceptor.cc
+60
-0
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
.../fluid/distributed/fleet_executor/amplifier_interceptor.h
+43
-0
paddle/fluid/distributed/fleet_executor/carrier.cc
paddle/fluid/distributed/fleet_executor/carrier.cc
+313
-0
paddle/fluid/distributed/fleet_executor/carrier.h
paddle/fluid/distributed/fleet_executor/carrier.h
+124
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
...e/fluid/distributed/fleet_executor/compute_interceptor.cc
+280
-0
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
...le/fluid/distributed/fleet_executor/compute_interceptor.h
+62
-0
paddle/fluid/distributed/fleet_executor/dist_model.cc
paddle/fluid/distributed/fleet_executor/dist_model.cc
+653
-0
paddle/fluid/distributed/fleet_executor/dist_model.h
paddle/fluid/distributed/fleet_executor/dist_model.h
+107
-0
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.cc
...d/distributed/fleet_executor/dist_model_tensor_wrapper.cc
+102
-0
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h
...id/distributed/fleet_executor/dist_model_tensor_wrapper.h
+84
-0
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+196
-0
paddle/fluid/distributed/fleet_executor/fleet_executor.h
paddle/fluid/distributed/fleet_executor/fleet_executor.h
+67
-0
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
...luid/distributed/fleet_executor/fleet_executor_desc.proto
+26
-0
paddle/fluid/distributed/fleet_executor/global.h
paddle/fluid/distributed/fleet_executor/global.h
+117
-0
paddle/fluid/distributed/fleet_executor/interceptor.cc
paddle/fluid/distributed/fleet_executor/interceptor.cc
+122
-0
paddle/fluid/distributed/fleet_executor/interceptor.h
paddle/fluid/distributed/fleet_executor/interceptor.h
+162
-0
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
...luid/distributed/fleet_executor/interceptor_message.proto
+43
-0
paddle/fluid/distributed/fleet_executor/message_bus.cc
paddle/fluid/distributed/fleet_executor/message_bus.cc
+256
-0
paddle/fluid/distributed/fleet_executor/message_bus.h
paddle/fluid/distributed/fleet_executor/message_bus.h
+95
-0
paddle/fluid/distributed/fleet_executor/message_service.cc
paddle/fluid/distributed/fleet_executor/message_service.cc
+51
-0
No files found.
Too many changes to show.
To preserve performance only
272 of 272+
files are displayed.
Plain diff
Email patch
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
distributed
{
// Caches the amplification schedule (run cadence/offset and up/down reply
// cadences) from the task node so the hot-path overrides below avoid
// repeated TaskNode accessor calls.
AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
                                           TaskNode* node)
    : ComputeInterceptor(interceptor_id, node) {
  run_per_steps_ = node->run_per_steps();
  run_at_offset_ = node->run_at_offset();
  reply_up_per_steps_ = node->reply_up_per_steps();
  send_down_per_steps_ = node->send_down_per_steps();
}
// Forwards to ComputeInterceptor::RunOps only on the configured step offset.
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
void AmplifierInterceptor::RunOps() {
  const bool on_schedule = (step_ % run_per_steps_) == run_at_offset_;
  if (!on_schedule) {
    return;
  }
  ComputeInterceptor::RunOps();
}
// Amplifying N inputs into one output: only notify downstream once every
// send_down_per_steps_ steps.
// run multi times, send ready one times to downstream, that is
// input multi times, output one times
void AmplifierInterceptor::SendDataReadyToDownStream() {
  const bool should_notify = (step_ % send_down_per_steps_) == 0;
  if (!should_notify) {
    return;
  }
  ComputeInterceptor::SendDataReadyToDownStream();
}
// One input fans out to multiple runs: only release the upstream slot once
// every reply_up_per_steps_ steps.
// run multi times, reply one times to upstream, that is
// input one times, output multi times
void AmplifierInterceptor::ReplyCompletedToUpStream() {
  const bool should_reply = (step_ % reply_up_per_steps_) == 0;
  if (!should_reply) {
    return;
  }
  ComputeInterceptor::ReplyCompletedToUpStream();
}
// Registers this class with the interceptor factory under the type name
// "Amplifier" (looked up by Carrier::CreateInterceptors via task_node->type()).
REGISTER_INTERCEPTOR(Amplifier, AmplifierInterceptor);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
namespace
paddle
{
namespace
distributed
{
// An interceptor that decouples how often its ops run from how often its
// neighbors are notified: it runs ops only on a configured step cadence and
// amplifies/collapses the ready/completed handshakes with up/downstream.
// All schedule parameters are read from the TaskNode at construction.
class AmplifierInterceptor : public ComputeInterceptor {
 public:
  AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  // Runs ops only when (step_ % run_per_steps_) == run_at_offset_.
  void RunOps() override;
  // Notifies downstream only every send_down_per_steps_ steps.
  void SendDataReadyToDownStream() override;
  // Replies upstream only every reply_up_per_steps_ steps.
  void ReplyCompletedToUpStream() override;

  // Cadence (and offset within it) at which ops actually execute.
  int64_t run_per_steps_{1};
  int64_t run_at_offset_{0};
  // one input produces multi times output
  int64_t reply_up_per_steps_{1};
  // one output need multi times input
  int64_t send_down_per_steps_{1};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/carrier.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
namespace
paddle
{
namespace
distributed
{
USE_INTERCEPTOR
(
Source
);
USE_INTERCEPTOR
(
Compute
);
USE_INTERCEPTOR
(
Amplifier
);
USE_INTERCEPTOR
(
Sink
);
// Lightweight init used when this rank hosts no task nodes: records the
// rank/topology mapping and starts the task-loop thread pool, but creates
// no scopes or interceptors (is_init_ is deliberately left unset).
void Carrier::Init(
    int64_t rank,
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
  rank_ = rank;
  interceptor_id_to_rank_ = interceptor_id_to_rank;

  // TODO(fleet_exe dev): thread pool
  thread_num_ = 1;
  thread_pool_.SetThreadNum(thread_num_);
  thread_pool_.Start();
}
// Full initialization: records topology, builds the scope hierarchy
// (root -> minibatch -> one scope per micro batch), copies/creates program
// variables per micro batch, starts the thread pool, and instantiates all
// interceptors for the task nodes assigned to this rank.
void Carrier::Init(
    int64_t rank,
    const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
    const framework::ProgramDesc& program,
    framework::Scope* scope,
    int64_t num_micro_batches,
    const platform::Place& place,
    const std::vector<std::string>& inference_root_scope_vars) {
  rank_ = rank;
  interceptor_id_to_rank_ = interceptor_id_to_rank;
  interceptor_id_to_node_ = interceptor_id_to_node;
  place_ = place;
  root_scope_ = scope;
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::InvalidArgument("root_scope can not be nullptr"));
  // One minibatch scope under root; each micro batch gets its own child so
  // per-batch variables never collide across pipeline stages.
  minibatch_scope_ = &root_scope_->NewScope();
  microbatch_scopes_.resize(num_micro_batches);
  for (int i = 0; i < num_micro_batches; ++i) {
    microbatch_scopes_[i] = &minibatch_scope_->NewScope();
    CopyParameters(i, program, inference_root_scope_vars);
  }

  // TODO(fleet_exe dev): thread pool
  thread_num_ = 1;
  thread_pool_.SetThreadNum(thread_num_);
  thread_pool_.Start();

  CreateInterceptors();
  is_init_ = true;
}
// Drops every child scope of the root scope (minibatch and microbatch
// scopes included). Safe to call when Init was never run.
void Carrier::Release() {
  if (root_scope_ == nullptr) {
    return;
  }
  root_scope_->DropKids();
}
// Destructor only logs; scope cleanup is explicit via Release().
Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
// Creates the program's variables in the right scope for one micro batch:
// persistable vars (and vars forced by inference_root_scope_vars) go into
// the root scope exactly once (micro batch 0); every non-persistable var
// gets a fresh instance in this micro batch's scope.
//
// microbatch_id: index into microbatch_scopes_.
// program: source of the variable descriptions (block 0 only).
// inference_root_scope_vars: names that must live in the root scope even if
//   not persistable.
void Carrier::CopyParameters(
    int microbatch_id,
    const framework::ProgramDesc& program,
    const std::vector<std::string>& inference_root_scope_vars) {
  auto& global_block = program.Block(0);

  std::map<std::string, int> inference_root_scope_var_map;
  // NOTE: iterate by const reference to avoid copying each std::string.
  for (const auto& var_name : inference_root_scope_vars) {
    inference_root_scope_var_map.insert({var_name, 1});
  }
  for (auto& var : global_block.AllVars()) {
    std::string var_name = var->Name();
    bool force_root = inference_root_scope_var_map.find(var_name) !=
                      inference_root_scope_var_map.end();
    if (force_root) {
      VLOG(4) << var_name << " will be forced to be created in the root scope.";
    }
    if ((var->Persistable() || force_root) && microbatch_id == 0) {
      // Root-scope vars are shared across micro batches; create them once.
      auto* ptr = root_scope_->Var(var_name);
      InitializeVariable(ptr, var->GetType());
      VLOG(5) << "Create persistable var: " << var_name
              << ", which pointer is " << ptr;
    } else if (!var->Persistable()) {
      auto* ptr = microbatch_scopes_[microbatch_id]->Var(var_name);
      VLOG(5) << "Create variable " << var_name << " for microbatch "
              << microbatch_id << ", which pointer is " << ptr << ".";
      InitializeVariable(ptr, var->GetType());
    }
  }
}
// Delivers a non-control message to the destination interceptor hosted by
// this carrier. Control messages must travel through the message bus, so
// receiving one here is a fatal error.
bool Carrier::EnqueueInterceptorMessage(
    const InterceptorMessage& interceptor_message) {
  PADDLE_ENFORCE_EQ(
      interceptor_message.ctrl_message(),
      false,
      platform::errors::Fatal(
          "Control message should be only send inter rank using message bus."));
  int64_t dst_id = interceptor_message.dst_id();
  Interceptor* dst_interceptor = GetInterceptor(dst_id);
  // Hands the message to the interceptor's own queue; GetInterceptor already
  // enforces that the id exists.
  dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
  return true;
}
// Looks up the interceptor instance for the given logical id.
// Returns a non-owning pointer; the carrier retains ownership.
// Enforces (fatal) that the id was previously registered via SetInterceptor.
Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
  auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_NE(iter,
                    interceptor_idx_to_interceptor_.end(),
                    platform::errors::InvalidArgument(
                        "Cannot find interceptor instance for interceptor "
                        "id %lld. Wrong dst? Call before init?",
                        interceptor_id));
  return iter->second.get();
}
// Blocks the caller until WakeUp() notifies cond_var_.
// NOTE(review): this is a predicate-less condition_variable::wait, so a
// spurious wakeup would release it early — presumably acceptable for the
// current start/stop handshake; confirm before reusing elsewhere.
void Carrier::Wait() {
  std::unique_lock<std::mutex> lock(running_mutex_);
  cond_var_.wait(lock);
}
// Releases any thread blocked in Wait().
void Carrier::WakeUp() {
  // probably double notify, but ok for ut
  cond_var_.notify_all();
}
// Kicks off one run of the pipeline: sends DATA_IS_READY to every source
// interceptor, blocks until a sink interceptor stops the carrier (WakeUp),
// waits for the device to drain, then drops per-microbatch child scopes.
// Requires the full Init() to have completed.
void Carrier::Start() {
  PADDLE_ENFORCE_EQ(is_init_,
                    true,
                    platform::errors::PreconditionNotMet(
                        "Using carrier before initialized."));
  for (int64_t id : source_interceptor_ids_) {
    VLOG(3) << "Carrier Start is sending start to source interceptor " << id
            << ".";
    InterceptorMessage start_msg;
    // source node data_is_ready is send by carrier, so set src_id=-1
    start_msg.set_src_id(-1);
    start_msg.set_dst_id(id);
    start_msg.set_message_type(DATA_IS_READY);
    Send(start_msg);
  }
  // TODO(wangxi): async step
  Wait();
  dev_ctx_->Wait();
  for (auto* micro_scope : microbatch_scopes_) {
    // By default, we should delete all kid scopes after run executor because
    // some operators may create local scope when running, such as while_op.
    // But when while_op also create a local executor to run it's sub block,
    // the sub scopes it created should not be dropped immediately, because
    // while_grad_op will use some variables created during while_op run, so
    // we need to keep the kids and wait for the outer executor to drop them.
    micro_scope->DropKids();
  }
}
// True once the full Init() overload has completed.
bool Carrier::IsInit() const { return is_init_; }
// Maps an interceptor id to the rank that hosts it.
// Fails (NotFound) if the id is absent from the topology table.
// NOTE: uses a single find() instead of the original find()+at(), which
// performed the same hash lookup twice.
int64_t Carrier::GetRank(int64_t interceptor_id) const {
  const auto iter = interceptor_id_to_rank_.find(interceptor_id);
  PADDLE_ENFORCE_NE(
      iter,
      interceptor_id_to_rank_.end(),
      platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
                                 interceptor_id));
  return iter->second;
}
// Routes a message: same-rank destinations are enqueued locally, cross-rank
// destinations go through the global MessageBus. The sender must live on
// this carrier's rank (fatal otherwise).
bool Carrier::Send(const InterceptorMessage& msg) {
  int64_t src_id = msg.src_id();
  // TODO(liyurui): compatible solution, will be removed completely in the
  // future
  // A SOURCE_ID sender that is not in the topology table is rewritten to the
  // destination id so rank resolution below still works.
  if (interceptor_id_to_rank_.find(src_id) == interceptor_id_to_rank_.end() &&
      src_id == SOURCE_ID) {
    src_id = msg.dst_id();
  }
  int64_t dst_id = msg.dst_id();
  int64_t src_rank = GetRank(src_id);
  int64_t dst_rank = GetRank(dst_id);
  PADDLE_ENFORCE_EQ(
      src_rank,
      rank_,
      platform::errors::Fatal("The source rank id %lld, which is not equal to "
                              "the carrier rank id %lld.",
                              src_rank,
                              rank_));
  if (src_rank == dst_rank) {
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id
            << ", which are in the same ranks.";
    return EnqueueInterceptorMessage(msg);
  } else {
    VLOG(3) << "Send a message from interceptor " << src_id
            << " to interceptor " << dst_id
            << ", which are in different ranks.";
    return GlobalVal<MessageBus>::Get()->Send(dst_rank, msg);
  }
}
// Registers an interceptor under a unique id: wires it to this carrier,
// assigns it a task loop from the pool (round-robin by id), and takes
// ownership. Returns a non-owning pointer to the stored instance.
// Fails (AlreadyExists) if the id was registered before.
Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
                                     std::unique_ptr<Interceptor> interceptor) {
  auto iter = interceptor_idx_to_interceptor_.find(interceptor_id);
  PADDLE_ENFORCE_EQ(iter,
                    interceptor_idx_to_interceptor_.end(),
                    platform::errors::AlreadyExists(
                        "The interceptor id %lld has already been created! "
                        "The interceptor id should be unique.",
                        interceptor_id));
  interceptor->RegisterCarrier(this);

  // TODO(fleet_exe dev): get loop
  auto* loop = thread_pool_.GetLoop(interceptor_id % thread_num_);
  PADDLE_ENFORCE_NOT_NULL(
      loop, platform::errors::Fatal("thread task loop must not null"));
  interceptor->RegisterTaskLoop(loop);

  // Grab the raw pointer before ownership moves into the map.
  auto* ptr = interceptor.get();
  interceptor_idx_to_interceptor_.insert(
      std::make_pair(interceptor_id, std::move(interceptor)));
  return ptr;
}
// Builds the garbage collector used by interceptors to free unused tensors.
// Returns nullptr when eager deletion is disabled (threshold < 0), when not
// compiled with CUDA/HIP, when the place is not a GPU, or when fast eager
// deletion mode is off — callers must handle a null result.
static std::shared_ptr<framework::GarbageCollector> GetGC(
    const platform::Place& place) {
  int64_t max_memory_size = framework::GetEagerDeletionThreshold();
  std::shared_ptr<framework::GarbageCollector> gc;
  if (max_memory_size >= 0) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (platform::is_gpu_place(place)) {
      if (framework::IsFastEagerDeletionModeEnabled()) {
        gc.reset(new framework::UnsafeFastGPUGarbageCollector(
            place, max_memory_size));
      }
    }
#endif
  }  // max_memory_size >= 0

  return gc;
}
// Instantiates one interceptor per task node assigned to this rank:
// validates the node's schedule, creates the interceptor via the factory
// keyed by task_node->type(), injects place/scopes/GC, registers it, and
// records nodes with no upstream as sources (started by Carrier::Start).
void Carrier::CreateInterceptors() {
  if (interceptor_id_to_node_.empty()) return;

  auto gc = GetGC(place_);

  // create each Interceptor
  // no auto init since there is no config
  for (const auto& item : interceptor_id_to_node_) {
    int64_t interceptor_id = item.first;
    TaskNode* task_node = item.second;

    PADDLE_ENFORCE_LT(
        task_node->run_at_offset(),
        task_node->run_per_steps(),
        platform::errors::InvalidArgument(
            "Interceptor's run_at_offset must < run_per_steps, must now "
            "run_at_offset=%ld run_per_steps=%ld",
            task_node->run_at_offset(),
            task_node->run_per_steps()));

    std::unique_ptr<Interceptor> interceptor;
    PADDLE_ENFORCE_NE(task_node->type().empty(),
                      true,
                      platform::errors::NotFound(
                          "Cannot found type for task node with id %lld",
                          task_node->task_id()));
    interceptor = InterceptorFactory::Create(
        task_node->type(), interceptor_id, task_node);
    interceptor->SetPlace(place_);
    interceptor->SetMiniBatchScope(minibatch_scope_);
    interceptor->SetMicroBatchScope(microbatch_scopes_);
    interceptor->SetRootScope(root_scope_);
    interceptor->SetGC(gc);

    SetInterceptor(interceptor_id, std::move(interceptor));
    VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id
            << " with type: " << task_node->type() << ".";

    if (task_node->upstream().empty()) {
      source_interceptor_ids_.emplace_back(interceptor_id);
    }
  }
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/carrier.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
ProgramDesc
;
}
// namespace framework
namespace
distributed
{
class
TaskNode
;
class
InterceptorMessageServiceImpl
;
class
RuntimeGraph
;
class
MessageBus
;
// TODO(liyurui): Add CarrierId instead of std::string
// TODO(liyurui): Add CarrierId instead of std::string
//
// Owns all interceptors for one rank of a fleet-executor run: builds the
// scope hierarchy, hosts the task-loop thread pool, and routes messages
// between local interceptors (or out through the MessageBus for remote
// ranks). Non-copyable; must be constructed with a carrier id.
class Carrier final {
 public:
  explicit Carrier(const std::string& carrier_id) : carrier_id_(carrier_id) {}
  ~Carrier();

  // Topology-only init (no task nodes on this rank): records the rank map
  // and starts the thread pool.
  void Init(int64_t rank,
            const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank);
  // Full init: additionally builds scopes per micro batch, copies program
  // variables, and creates interceptors for the given task nodes.
  void Init(
      int64_t rank,
      const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank,
      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
      const framework::ProgramDesc& program,
      framework::Scope* scope,
      int64_t num_micro_batches,
      const platform::Place& place,
      const std::vector<std::string>& inference_root_scope_vars = {});

  // Creates program variables in the root or per-microbatch scope.
  void CopyParameters(
      int microbatch_id,
      const framework::ProgramDesc& program,
      const std::vector<std::string>& inference_root_scope_vars);

  // Drops all child scopes of the root scope.
  void Release();
  // Blocks until WakeUp() is called (used to wait for a run to finish).
  void Wait();
  void WakeUp();

  // Enqueue a message to corresponding interceptor id
  bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);

  // get interceptor based on the interceptor id
  Interceptor* GetInterceptor(int64_t interceptor_id);

  // set interceptor with interceptor id
  Interceptor* SetInterceptor(int64_t interceptor_id,
                              std::unique_ptr<Interceptor>);

  // Starts one pipeline run: signals source interceptors and blocks until
  // completion. Requires full Init().
  void Start();

  bool IsInit() const;

  // Routes a message locally or via the MessageBus depending on dst rank.
  bool Send(const InterceptorMessage& msg);

 private:
  DISABLE_COPY_AND_ASSIGN(Carrier);
  Carrier() = delete;

  // create each Interceptor
  void CreateInterceptors();

  // Resolves the hosting rank for an interceptor id (NotFound on miss).
  int64_t GetRank(int64_t interceptor_id) const;

  // interceptor logic id to actually interceptor
  std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
      interceptor_idx_to_interceptor_;

  // Ids of task nodes with no upstream; these get the initial start message.
  std::vector<int64_t> source_interceptor_ids_;

  bool is_init_{false};

  // Guard + signal for the Wait()/WakeUp() run-completion handshake.
  std::mutex running_mutex_;
  std::condition_variable cond_var_;
  std::vector<framework::Scope*> microbatch_scopes_;
  framework::Scope* root_scope_{nullptr};
  framework::Scope* minibatch_scope_{nullptr};
  paddle::platform::Place place_;
  paddle::platform::DeviceContext* dev_ctx_{nullptr};
  int64_t rank_;
  std::string carrier_id_;
  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
  std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
  int thread_num_;
  TaskLoopThreadPool thread_pool_;
  std::unordered_set<int64_t> interceptor_ids_;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
distributed
{
// Builds the up/downstream bookkeeping tables from the task node and
// installs Compute() as the handler for all incoming messages.
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  PrepareDeps();
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); });
}
// Initializes flow-control state from the task node's topology:
// - in_readys_:  per-upstream (capacity, currently-ready count)
// - in_stops_:   per-upstream stop-received flag
// - out_buffs_:  per-downstream (buffer capacity, used count)
// A node with no upstream is marked as a source and gets a pseudo upstream
// with id -1 (messages from the carrier) and effectively unbounded capacity.
// A node with no downstream is marked as the last one for this rank.
void ComputeInterceptor::PrepareDeps() {
  auto& upstream = node_->upstream();
  auto& downstream = node_->downstream();

  for (auto up : upstream) {
    in_readys_.emplace(up.first, std::make_pair(up.second, 0));
    in_stops_.emplace(up.first, false);
  }
  for (auto down : downstream) {
    out_buffs_.emplace(down.first, std::make_pair(down.second, 0));
  }

  // source compute node, should we add a new SourceInterceptor?
  if (upstream.empty()) {
    is_source_ = true;
    PADDLE_ENFORCE_GT(node_->max_run_times(),
                      0,
                      platform::errors::InvalidArgument(
                          "Source ComputeInterceptor must run at least one "
                          "times, but now max_run_times=%ld",
                          node_->max_run_times()));
    in_readys_.emplace(-1,
                       std::make_pair(std::numeric_limits<int64_t>::max(), 0));
  }

  // If there is no downstream or every downstream is in different rank,
  // then this interceptor is the last one for current rank.
  // This can be get during init, can be cached for later use.
  is_last_ = downstream.empty();
}
// Records one DATA_IS_READY from upstream up_id, enforcing the per-upstream
// capacity. For the source pseudo-upstream (-1) the carrier grants a whole
// round at once, so max_run_times() credits are added in one call.
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
  auto it = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(it,
                    in_readys_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_readys.", up_id));

  // source node has no upstream, data_is_ready is send by carrier or others
  if (is_source_ && up_id == -1) {
    it->second.second += GetTaskNode()->max_run_times();
    return;
  }

  auto max_ready_size = it->second.first;
  auto ready_size = it->second.second;
  ready_size += 1;
  PADDLE_ENFORCE_LE(ready_size,
                    max_ready_size,
                    platform::errors::OutOfRange(
                        "upstream=%lld ready_size must <= max_ready_size, but "
                        "now ready_size=%lld, max_ready_size=%lld",
                        up_id,
                        ready_size,
                        max_ready_size));
  it->second.second = ready_size;
}
// Records one DATA_IS_USELESS from downstream down_id: frees one slot in
// that downstream's output buffer, enforcing it never goes negative.
void ComputeInterceptor::DecreaseBuff(int64_t down_id) {
  auto it = out_buffs_.find(down_id);
  PADDLE_ENFORCE_NE(it,
                    out_buffs_.end(),
                    platform::errors::NotFound(
                        "Cannot find downstream=%lld in out_buffs.", down_id));
  auto used_size = it->second.second;
  used_size -= 1;
  PADDLE_ENFORCE_GE(
      used_size,
      0,
      platform::errors::OutOfRange(
          "downstream=%lld used buff size must >= 0, but now equal %lld",
          down_id,
          used_size));
  it->second.second = used_size;
}
// True when every upstream has at least one ready credit pending.
bool ComputeInterceptor::IsInputReady() {
  for (auto& ins : in_readys_) {
    auto ready_size = ins.second.second;
    // not ready, return false
    if (ready_size == 0) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << "'s upstreams aren't all ready.";
      return false;
    }
  }
  return true;
}
// True when no downstream output buffer is at capacity.
bool ComputeInterceptor::CanWriteOutput() {
  for (auto& outs : out_buffs_) {
    auto max_buffer_size = outs.second.first;
    auto used_size = outs.second.second;
    // full, return false
    if (used_size == max_buffer_size) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << "'s out buffer is full.";
      return false;
    }
  }
  return true;
}
// Claims one buffer slot per downstream (capacity-checked) and sends each
// a DATA_IS_READY message for the step just completed.
void ComputeInterceptor::SendDataReadyToDownStream() {
  for (auto& outs : out_buffs_) {
    auto down_id = outs.first;
    auto max_buff_size = outs.second.first;
    auto used_size = outs.second.second;
    used_size += 1;
    PADDLE_ENFORCE_LE(used_size,
                      max_buff_size,
                      platform::errors::OutOfRange(
                          "downstream=%lld used buff size must <= "
                          "max_buff_size, but now used_size=%lld, "
                          "max_buff_size=%lld",
                          down_id,
                          used_size,
                          max_buff_size));
    outs.second.second = used_size;

    InterceptorMessage ready_msg;
    ready_msg.set_message_type(DATA_IS_READY);
    VLOG(3) << "ComputeInterceptor " << interceptor_id_
            << " Send data_is_ready msg to " << down_id
            << " for step: " << step_;
    Send(down_id, ready_msg);
  }
}
// Consumes one ready credit per upstream (floor-checked) and replies with
// DATA_IS_USELESS so the upstream can reuse its buffer slot. The source
// pseudo-upstream (-1) is decremented but never messaged: there is no real
// interceptor behind it.
void ComputeInterceptor::ReplyCompletedToUpStream() {
  for (auto& ins : in_readys_) {
    auto up_id = ins.first;
    auto ready_size = ins.second.second;
    ready_size -= 1;
    PADDLE_ENFORCE_GE(
        ready_size,
        0,
        platform::errors::OutOfRange(
            "upstream=%lld ready_size must >= 0, but now got %lld",
            up_id,
            ready_size));
    ins.second.second = ready_size;

    VLOG(3) << "ComputeInterceptor " << interceptor_id_
            << " Reply data_is_useless msg to " << up_id
            << " for step: " << step_;
    if (is_source_ && up_id == -1) return;

    InterceptorMessage reply_msg;
    reply_msg.set_message_type(DATA_IS_USELESS);
    Send(up_id, reply_msg);
  }
}
// Runs every op of this task node once against the scope of the current
// micro batch (step_ modulo max_run_times selects the scope), deleting
// unused tensors after each op when a garbage collector is configured.
void ComputeInterceptor::RunOps() {
  VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
          << step_ + 1 << " time.";
  for (auto op : node_->ops()) {
    op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
    if (gc_) {
      framework::DeleteUnusedTensors(
          *microbatch_scopes_[step_ % node_->max_run_times()],
          op,
          node_->unused_vars(),
          gc_.get());
    }
  }
}
// Main scheduling loop: while all inputs are ready and all output buffers
// have room, run one step, then propagate the ready/useless handshakes.
// When this is the last interceptor of the rank and a full round of
// micro batches has completed, it stops the carrier.
void ComputeInterceptor::Run() {
  while (IsInputReady() && CanWriteOutput()) {
    VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";

    RunOps();
    ++step_;

    // send to downstream and increase buff used
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data
    ReplyCompletedToUpStream();
    // Try to stop Carrier
    if (is_last_ && (step_ % node_->max_run_times() == 0)) {
      VLOG(3) << "Interceptor " << GetInterceptorId()
              << " is stopping carrier.";
      // FIXME(wangxi): with multi sink interceptor
      StopCarrier();
    }
  }
}
// Marks a STOP received from upstream up_id; a second STOP from the same
// upstream is an error. For a source node, up_id == -1 means the stop came
// from the carrier and there is no per-upstream flag to set.
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
  received_stop_ = true;

  // source node has no upstream, stop is send by carrier or others
  if (is_source_ && up_id == -1) return;

  auto it = in_stops_.find(up_id);
  PADDLE_ENFORCE_NE(it,
                    in_stops_.end(),
                    platform::errors::NotFound(
                        "Cannot find upstream=%lld in in_stops.", up_id));
  PADDLE_ENFORCE_EQ(
      it->second,
      false,
      platform::errors::AlreadyExists("Already received stop from %lld, stop "
                                      "cannot be send more than once."));
  it->second = true;
}
// Stops this interceptor once it is safe: a stop has been requested, every
// upstream has sent its STOP, and every downstream has drained its buffer.
// Then the STOP is forwarded to all downstreams and stop_ is latched.
void ComputeInterceptor::TryStop() {
  if (!received_stop_) return;

  // can stop only when all upstream is stop and
  // downstream complete
  for (auto& in_stop : in_stops_) {
    if (!in_stop.second) return;
  }
  for (auto& out_buff : out_buffs_) {
    auto used_size = out_buff.second.second;
    if (used_size != 0) return;
  }

  // send stop to downstream
  for (auto& out : out_buffs_) {
    auto down_id = out.first;
    InterceptorMessage stop;
    stop.set_message_type(STOP);
    Send(down_id, stop);
  }
  stop_ = true;
}
// Message handler: dispatches on the message type, then always attempts a
// graceful stop. DATA_IS_READY / DATA_IS_USELESS update the flow-control
// counters and re-enter the run loop; STOP only records the stop request.
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  switch (msg.message_type()) {
    case DATA_IS_READY:
      IncreaseReady(msg.src_id());
      Run();
      break;
    case DATA_IS_USELESS:
      DecreaseBuff(msg.src_id());
      Run();
      break;
    case STOP:
      ReceivedStop(msg.src_id());
      break;
    default:
      break;
  }

  TryStop();
}
// Registers this class with the interceptor factory under the type name
// "Compute" (looked up by Carrier::CreateInterceptors via task_node->type()).
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <utility>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace
paddle
{
namespace
distributed
{
// Base interceptor for running a task node's ops with credit-based flow
// control: it tracks ready credits from each upstream and buffer occupancy
// toward each downstream, running ops only when every input is ready and
// every output has room. Subclasses (e.g. AmplifierInterceptor) override
// the protected hooks to change how often ops run and handshakes fire.
class ComputeInterceptor : public Interceptor {
 public:
  ComputeInterceptor(int64_t interceptor_id, TaskNode* node);

 protected:
  // Runs the node's ops for the current step.
  virtual void RunOps();
  // Claims a downstream buffer slot and sends DATA_IS_READY.
  virtual void SendDataReadyToDownStream();
  // Consumes an upstream credit and sends DATA_IS_USELESS.
  virtual void ReplyCompletedToUpStream();

  // Number of completed steps; also selects the microbatch scope.
  int64_t step_{0};

 private:
  // Builds in_readys_/in_stops_/out_buffs_ from the task node topology.
  void PrepareDeps();

  void IncreaseReady(int64_t up_id);
  void DecreaseBuff(int64_t down_id);
  bool IsInputReady();
  bool CanWriteOutput();

  // Main loop: run steps while inputs are ready and outputs have room.
  void Run();
  // Message handler installed in the constructor.
  void Compute(const InterceptorMessage& msg);

  void ReceivedStop(int64_t up_id);
  void TryStop();

  // True when the task node has no upstream (started by the carrier).
  bool is_source_{false};
  // True when the task node has no downstream (stops the carrier).
  bool is_last_{false};

  // upstream_id-->(max_ready_size, ready_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
  // downstream_id-->(max_buffer_size, used_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs_{};

  bool received_stop_{false};
  // Per-upstream flag: has that upstream sent its STOP yet?
  std::map<int64_t, bool> in_stops_{};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/dist_model.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include <glog/logging.h>
#include <chrono> // NOLINT
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
distributed
{
namespace
{
// Returns true for variables that should be loaded from the checkpoint:
// persistable variables that are not framework bookkeeping types
// (feed/fetch slots and RAW placeholders are skipped).
bool IsPersistable(const framework::VarDesc* var) {
  // Idiomatic form: return the condition directly instead of
  // `if (cond) return true; return false;`.
  return var->Persistable() &&
         var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
         var->GetType() != framework::proto::VarType::FETCH_LIST &&
         var->GetType() != framework::proto::VarType::RAW;
}
// Copies one user-provided DistModelTensor into a framework LoDTensor placed
// on `place` (CPU, GPU or XPU). Returns false for unsupported dtypes; throws
// when the binary was not built for the requested device.
bool LoadDataFromDistModelTensor(const DistModelTensor& input_data,
                                 framework::LoDTensor* input_tensor,
                                 const platform::Place& place) {
  VLOG(3) << "Loading data from DistModelTensor for " << input_data.name;
  framework::DDim dims = phi::make_ddim(input_data.shape);
  void* input_tensor_ptr;
  // Allocate the destination buffer with the dtype matching the input.
  if (input_data.dtype == DistModelDataType::INT64) {
    input_tensor_ptr = input_tensor->mutable_data<int64_t>(dims, place);
  } else if (input_data.dtype == DistModelDataType::FLOAT32) {
    input_tensor_ptr = input_tensor->mutable_data<float>(dims, place);
  } else if (input_data.dtype == DistModelDataType::INT32) {
    input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
  } else if (input_data.dtype == DistModelDataType::FLOAT16) {
    input_tensor_ptr = input_tensor->mutable_data<float16>(dims, place);
  } else {
    // INT8 (and anything else) is not supported as a feed type here.
    LOG(ERROR) << "unsupported feed type " << input_data.dtype;
    return false;
  }
  PADDLE_ENFORCE_NOT_NULL(
      input_tensor_ptr,
      paddle::platform::errors::Fatal(
          "LoDTensor creation failed. DistModel loaded data failed."));
  PADDLE_ENFORCE_NOT_NULL(input_data.data.data(),
                          paddle::platform::errors::InvalidArgument(
                              "DistModelTensor contains no data."));
  if (platform::is_cpu_place(place)) {
    VLOG(3) << "Loading data for CPU.";
    std::memcpy(static_cast<void*>(input_tensor_ptr),
                input_data.data.data(),
                input_data.data.length());
  } else if (platform::is_gpu_place(place)) {
    VLOG(3) << "Loading data for GPU.";
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    // Async host-to-device copy on the device context's stream.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* dev_ctx = dynamic_cast<const phi::GPUContext*>(pool.Get(place));
    auto gpu_place = place;
    memory::Copy(gpu_place,
                 static_cast<void*>(input_tensor_ptr),
                 platform::CPUPlace(),
                 input_data.data.data(),
                 input_data.data.length(),
                 dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Paddle wasn't compiled with CUDA, but place is GPU."));
#endif
  } else if (platform::is_xpu_place(place)) {
    VLOG(3) << "Loading data for XPU.";
#if defined(PADDLE_WITH_XPU)
    auto xpu_place = place;
    memory::Copy(xpu_place,
                 static_cast<void*>(input_tensor_ptr),
                 platform::CPUPlace(),
                 input_data.data.data(),
                 input_data.data.length());
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Paddle wasn't compiled with XPU, but place is XPU."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "DistModel only supports CPU and GPU and XPU."));
  }
  // Transfer the level-of-detail (LoD) info onto the destination tensor.
  framework::LoD dst_lod;
  for (auto& src_lod : input_data.lod) {
    dst_lod.emplace_back(src_lod);
  }
  input_tensor->set_lod(dst_lod);
  return true;
}
// Human-readable name for a DistModelDataType, used in error messages.
// Unknown values map to the sentinel string "NOT SUPPORT DTYPE".
std::string DistModelDTypeToString(DistModelDataType dtype) {
  if (dtype == DistModelDataType::FLOAT32) return "float32";
  if (dtype == DistModelDataType::FLOAT16) return "float16";
  if (dtype == DistModelDataType::INT64) return "int64";
  if (dtype == DistModelDataType::INT32) return "int32";
  if (dtype == DistModelDataType::INT8) return "int8";
  return "NOT SUPPORT DTYPE";
}
// Minimal stopwatch: call tic() to start, toc() to read the elapsed wall
// time in milliseconds since the last tic().
class DistModelTimer {
 private:
  using Clock = std::chrono::high_resolution_clock;
  Clock::time_point start_;

 public:
  // Records the starting instant.
  void tic() { start_ = Clock::now(); }

  // Milliseconds elapsed since the last tic().
  double toc() {
    const std::chrono::duration<double, std::milli> elapsed =
        Clock::now() - start_;
    return elapsed.count();
  }
};
}
// namespace
// Top-level initialization: validates the config, prepares the place, the
// scope and program (loaded from disk or adopted from the config), the
// feed/fetch bookkeeping, optional multi-rank communication, and finally the
// fleet executor. Returns false on any failure.
bool DistModel::Init() {
  carrier_id_ = "inference";
  // Exactly one model source is required: a model dir or a ProgramDesc.
  bool init_method = (!config_.model_dir.empty() || config_.program_desc);
  PADDLE_ENFORCE_EQ(init_method,
                    true,
                    platform::errors::InvalidArgument(
                        "One of model dir or program desc must be provided to "
                        "dist model inference."));
  if (config_.program_desc) {
    // A caller-supplied program must come with a caller-supplied scope.
    PADDLE_ENFORCE_NOT_NULL(
        config_.scope,
        platform::errors::InvalidArgument(
            "Scope must be provided to dist model inference if "
            "program desc has been provided."));
  }
  if (!PreparePlace()) {
    return false;
  }
  if (!config_.program_desc) {
    if (config_.scope) {
      LOG(WARNING) << "The provided scope will be ignored if model dir has "
                      "also been provided.";
    }
    if (!PrepareScope()) {
      return false;
    }
    if (!PrepareProgram()) {
      return false;
    }
  } else {
    // NOTE(review): reset() takes ownership of raw pointers owned by the
    // caller's config here — double-free risk if the caller also frees them.
    program_.reset(config_.program_desc);
    scope_.reset(config_.scope);
  }
  if (!PrepareFeedAndFetch()) {
    return false;
  }
  // Communication setup is only needed for multi-rank inference.
  if (config_.nranks > 1 && !CommInit()) {
    return false;
  }
  if (!PrepareFleetExe()) {
    return false;
  }
  return true;
}
// Resolves config_.place ("GPU"/"CPU"/"XPU") into the platform place used for
// all tensor allocation and execution; throws on any other value.
bool DistModel::PreparePlace() {
  const std::string& target = config_.place;
  if (target == "GPU") {
    place_ = paddle::platform::CUDAPlace(config_.device_id);
  } else if (target == "CPU") {
    place_ = paddle::platform::CPUPlace();
  } else if (target == "XPU") {
    place_ = paddle::platform::XPUPlace(config_.device_id);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place must be choosen from GPU or CPU or XPU, but got %s.",
        config_.place));
  }
  return true;
}
// Builds a temporary program containing the communicator-initialization ops
// for every ring this rank participates in, then runs it once with a
// NaiveExecutor. Called only when nranks > 1.
bool DistModel::CommInit() {
  std::unique_ptr<framework::ProgramDesc> comm_init_program(
      new framework::ProgramDesc());
  framework::BlockDesc* comm_init_block = comm_init_program->MutableBlock(0);
  std::vector<int64_t>& ring_ids =
      config_.rank_to_ring_ids_[config_.local_rank];
  int64_t order = 0;
  std::string var_name_base = "comm_init_";
  for (int64_t ring_id : ring_ids) {
    VLOG(3) << "Init comm for ring id: " << ring_id;
    int64_t ranks_in_group = config_.ring_id_to_ranks_[ring_id].size();
    // Position of the local rank inside this ring's rank list.
    int64_t rank_in_group = 0;
    std::vector<int64_t>& ranks = config_.ring_id_to_ranks_[ring_id];
    for (int64_t rank : ranks) {
      if (config_.local_rank == rank) {
        break;
      }
      rank_in_group += 1;
    }
    // Endpoints of every other rank in the ring.
    std::vector<std::string> peer_endpoints;
    for (int64_t rank : ranks) {
      if (config_.local_rank == rank) {
        continue;
      }
      peer_endpoints.emplace_back(config_.trainer_endpoints[rank]);
    }
    // One unique temp var per ring holds the generated communicator id.
    InsertCommOp(var_name_base + std::to_string(order),
                 ranks_in_group,
                 rank_in_group,
                 peer_endpoints,
                 comm_init_block,
                 ring_id);
    order += 1;
  }
  framework::NaiveExecutor e(place_);
  e.CreateVariables(*comm_init_program, 0, true, scope_.get());
  e.Prepare(scope_.get(), *comm_init_program, 0, false);
  e.Run();
  VLOG(3) << "Comm init successful.";
  return true;
}
void
DistModel
::
InsertCommOp
(
std
::
string
tmp_var_name
,
int
nranks
,
int
rank
,
const
std
::
vector
<
std
::
string
>
&
peer_endpoints
,
framework
::
BlockDesc
*
block
,
int
ring_id
)
{
/*
* tmp_var_name: the var name for var comm_id
* nranks: number of total ranks
* rank: the rank of local rank in the comm group
* peer_endpoints: peer's endpoints
* block: the block where to insert the comm ops
* ring_id: the ring_id to be inited
*/
std
::
string
&
endpoint
=
config_
.
current_endpoint
;
std
::
stringstream
ss
;
ss
<<
"Init comm with tmp var: "
<<
tmp_var_name
<<
". The ring id is: "
<<
ring_id
<<
". The group has: "
<<
nranks
<<
" ranks. Current rank in the group is: "
<<
rank
<<
". The endpoint is: "
<<
endpoint
<<
". Peer endpoints are: "
;
for
(
auto
ep
:
peer_endpoints
)
{
ss
<<
ep
<<
", "
;
}
VLOG
(
3
)
<<
ss
.
str
();
if
(
config_
.
place
==
"GPU"
)
{
framework
::
VarDesc
*
new_var
=
block
->
Var
(
tmp_var_name
);
new_var
->
SetType
(
framework
::
proto
::
VarType
::
RAW
);
new_var
->
SetPersistable
(
true
);
framework
::
OpDesc
*
gen_nccl_id_op
=
block
->
AppendOp
();
gen_nccl_id_op
->
SetType
(
"c_gen_nccl_id"
);
gen_nccl_id_op
->
SetOutput
(
"Out"
,
{
tmp_var_name
});
gen_nccl_id_op
->
SetAttr
(
"rank"
,
rank
);
gen_nccl_id_op
->
SetAttr
(
"endpoint"
,
config_
.
current_endpoint
);
gen_nccl_id_op
->
SetAttr
(
"other_endpoints"
,
peer_endpoints
);
gen_nccl_id_op
->
SetAttr
(
"ring_id"
,
ring_id
);
gen_nccl_id_op
->
SetAttr
(
"op_role"
,
static_cast
<
int
>
(
framework
::
OpRole
::
kForward
));
gen_nccl_id_op
->
CheckAttrs
();
framework
::
OpDesc
*
comm_init_op
=
block
->
AppendOp
();
comm_init_op
->
SetType
(
"c_comm_init"
);
comm_init_op
->
SetInput
(
"X"
,
{
tmp_var_name
});
comm_init_op
->
SetAttr
(
"rank"
,
rank
);
comm_init_op
->
SetAttr
(
"nranks"
,
nranks
);
comm_init_op
->
SetAttr
(
"ring_id"
,
ring_id
);
comm_init_op
->
SetAttr
(
"op_role"
,
static_cast
<
int
>
(
framework
::
OpRole
::
kForward
));
comm_init_op
->
CheckAttrs
();
}
else
if
(
config_
.
place
==
"XPU"
)
{
framework
::
VarDesc
*
new_var
=
block
->
Var
(
tmp_var_name
);
new_var
->
SetType
(
framework
::
proto
::
VarType
::
RAW
);
new_var
->
SetPersistable
(
true
);
framework
::
OpDesc
*
gen_bkcl_id_op
=
block
->
AppendOp
();
gen_bkcl_id_op
->
SetType
(
"c_gen_bkcl_id"
);
gen_bkcl_id_op
->
SetOutput
(
"Out"
,
{
tmp_var_name
});
gen_bkcl_id_op
->
SetAttr
(
"rank"
,
rank
);
gen_bkcl_id_op
->
SetAttr
(
"endpoint"
,
config_
.
current_endpoint
);
gen_bkcl_id_op
->
SetAttr
(
"other_endpoints"
,
peer_endpoints
);
gen_bkcl_id_op
->
SetAttr
(
"ring_id"
,
ring_id
);
gen_bkcl_id_op
->
SetAttr
(
"op_role"
,
static_cast
<
int
>
(
framework
::
OpRole
::
kForward
));
gen_bkcl_id_op
->
CheckAttrs
();
framework
::
OpDesc
*
comm_init_op
=
block
->
AppendOp
();
comm_init_op
->
SetType
(
"c_comm_init"
);
comm_init_op
->
SetInput
(
"X"
,
{
tmp_var_name
});
comm_init_op
->
SetAttr
(
"rank"
,
rank
);
comm_init_op
->
SetAttr
(
"nranks"
,
nranks
);
comm_init_op
->
SetAttr
(
"ring_id"
,
ring_id
);
comm_init_op
->
SetAttr
(
"op_role"
,
static_cast
<
int
>
(
framework
::
OpRole
::
kForward
));
comm_init_op
->
CheckAttrs
();
}
else
{
LOG
(
WARNING
)
<<
"DistModelInf doesn't init comm."
;
// TODO(fleet exe dev): comm init for more devices
}
}
// Creates a fresh root scope; used only when the model is loaded from disk
// (a caller-supplied scope is adopted directly in Init()).
bool DistModel::PrepareScope() {
  scope_.reset(new framework::Scope());
  return true;
}
// Loads the program description and then its parameters from
// config_.model_dir; short-circuits on the first failure.
bool DistModel::PrepareProgram() {
  return LoadProgram() && LoadParameters();
}
// Reads "<model_dir>.pdmodel" as a binary ProgramDesc protobuf and stores the
// parsed program in program_.
bool DistModel::LoadProgram() {
  VLOG(3) << "Loading program from " << config_.model_dir;
  PADDLE_ENFORCE_NE(
      config_.model_dir,
      "",
      platform::errors::InvalidArgument("Model dir must be provided."));
  std::string model_path = config_.model_dir + ".pdmodel";
  framework::proto::ProgramDesc program_proto;
  std::string pb_content;
  // Read binary
  std::ifstream fin(model_path, std::ios::in | std::ios::binary);
  PADDLE_ENFORCE_EQ(
      static_cast<bool>(fin.is_open()),
      true,
      platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file is normal.",
          model_path));
  // Size the buffer to the file length, then slurp the whole file.
  fin.seekg(0, std::ios::end);
  pb_content.resize(fin.tellg());
  fin.seekg(0, std::ios::beg);
  fin.read(&(pb_content.at(0)), pb_content.size());
  fin.close();
  // NOTE(review): ParseFromString's return value is not checked — a corrupt
  // .pdmodel silently yields an empty program.
  program_proto.ParseFromString(pb_content);
  VLOG(5) << pb_content;
  program_.reset(new framework::ProgramDesc(program_proto));
  return true;
}
// Loads all persistable parameters from "<model_dir>.pdiparams" into the root
// scope by building and running a tiny program containing one load_combine op.
bool DistModel::LoadParameters() {
  VLOG(3) << "Loading parameters from " << config_.model_dir;
  PADDLE_ENFORCE_NOT_NULL(program_.get(),
                          platform::errors::PreconditionNotMet(
                              "The program should be loaded first."));
  const auto& global_block = program_->MutableBlock(0);
  // create a temporary program to load parameters.
  std::unique_ptr<framework::ProgramDesc> load_program(
      new framework::ProgramDesc());
  framework::BlockDesc* load_block = load_program->MutableBlock(0);
  std::vector<std::string> params;
  // Mirror every persistable var of the inference program into the loader
  // program so load_combine has destinations to fill.
  for (auto* var : global_block->AllVars()) {
    if (IsPersistable(var)) {
      VLOG(3) << "persistable variable's name: " << var->Name();
      framework::VarDesc* new_var = load_block->Var(var->Name());
      new_var->SetShape(var->GetShape());
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);
      params.push_back(new_var->Name());
      // NOTE: if the params are stored in different files, 'load' op should be
      // added here
    }
  }
  std::string param_path = config_.model_dir + ".pdiparams";
  // sort paramlist to have consistent ordering
  std::sort(params.begin(), params.end());
  // append just the load_combine op
  framework::OpDesc* op = load_block->AppendOp();
  op->SetType("load_combine");
  op->SetOutput("Out", params);
  op->SetAttr("file_path", {param_path});
  op->CheckAttrs();
  framework::NaiveExecutor e(place_);
  // Create all persistable variables in root scope to load them from ckpt.
  // Other non-persistable variables will be created in the micro scope
  // managed by fleet executor.
  e.CreateVariables(*program_, 0, true, scope_.get());
  e.Prepare(scope_.get(), *load_program, 0, false);
  e.Run();
  VLOG(3) << "After loading there are " << scope_->LocalVarNames().size()
          << " vars.";
  return true;
}
// Wraps the whole inference program into a single compute TaskNode, builds the
// executor descriptor with the cluster's rank/endpoint table, and initializes
// the FleetExecutor under carrier_id_.
bool DistModel::PrepareFleetExe() {
  task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
  // With auto cut, there is no concept of pp, no need to add dependency.
  task_node_->SetType("Compute");
  task_node_->Init();
  executor_desc_ = FleetExecutorDesc();
  executor_desc_.set_cur_rank(config_.local_rank);
  std::unordered_map<int64_t, int64_t> id_to_rank;
  for (int i = 0; i < config_.nranks; ++i) {
    RankInfo* rank_info = executor_desc_.add_cluster_info();
    rank_info->set_rank(i);
    rank_info->set_ip_port(config_.trainer_endpoints[i]);
    // Task id i is pinned to rank i (one task node per rank).
    id_to_rank.insert({i, i});
  }
  fleet_exe.reset(new FleetExecutor(executor_desc_));
  // One micro batch, one task node: plain single-stage inference.
  fleet_exe->Init(carrier_id_,
                  *(program_.get()),
                  scope_.get(),
                  place_,
                  1,
                  {task_node_.get()},
                  id_to_rank);
  return true;
}
// Scans block 0 for feed/fetch ops and builds the index/name/dtype lookup
// tables used by FeedData and FetchResults. Fails if the program contains no
// feed or no fetch ops, or if a feed var has an unsupported dtype.
bool DistModel::PrepareFeedAndFetch() {
  for (auto* op : program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      VLOG(3) << "feed op with feed var: " << op->Output("Out")[0];
      // "col" is the feed slot index; the vectors are grown on demand.
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      std::string var_name = op->Output("Out")[0];
      feed_names_[var_name] = idx;
      idx_to_feeds_[idx] = var_name;
      framework::VarDesc* real_var = program_->Block(0).FindVar(var_name);
      if (!real_var) {
        LOG(ERROR)
            << "The output of feed ops [" << var_name
            << "] cannot be found in the program. Check the inference program.";
        return false;
      }
      // Record the expected dtype so FeedData can validate user input.
      if (real_var->GetDataType() == framework::proto::VarType::FP32) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT32});
      } else if (real_var->GetDataType() == framework::proto::VarType::INT32) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::INT32});
      } else if (real_var->GetDataType() == framework::proto::VarType::INT64) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::INT64});
      } else if (real_var->GetDataType() == framework::proto::VarType::FP16) {
        feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT16});
      } else {
        LOG(ERROR) << "Don't support feed var dtype for: "
                   << real_var->GetDataType();
        return false;
      }
    } else if (op->Type() == "fetch") {
      VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
      if (fetches_.size() <= static_cast<size_t>(idx)) {
        fetches_.resize(idx + 1);
      }
      fetches_[idx] = op;
      idx_to_fetches_[idx] = op->Input("X")[0];
    }
  }
  if (feeds_.size() == 0) {
    LOG(ERROR) << "No feed ops in the inf program, please check the program.";
    return false;
  }
  if (fetches_.size() == 0) {
    LOG(ERROR) << "No fetch op in the inf program, please check the program.";
    return false;
  }
  return true;
}
// Validates and copies every user input tensor into the feed slots of `scope`.
// Checks count, name membership, and dtype against the tables built in
// PrepareFeedAndFetch.
bool DistModel::FeedData(const std::vector<DistModelTensor>& input_data,
                         framework::Scope* scope) {
  VLOG(3) << "DistModel is feeding data.";
  if (input_data.size() != feeds_.size()) {
    LOG(ERROR) << "Should provide " << feeds_.size() << " feeds, but got "
               << input_data.size() << " data.";
    return false;
  }
  feed_tensors_.resize(feeds_.size());
  for (size_t i = 0; i < input_data.size(); ++i) {
    // feed each data separately
    framework::LoDTensor* input_tensor = &(feed_tensors_[i]);
    if (!LoadDataFromDistModelTensor(input_data[i], input_tensor, place_)) {
      LOG(ERROR) << "Fail to load data from tensor " << input_data[i].name;
      return false;
    }
    std::string target_name = input_data[i].name;
    if (feed_names_.find(target_name) == feed_names_.end()) {
      LOG(ERROR) << "The input name [" << target_name
                 << "] cannot be found in the program."
                 << " DistModel loads data failed.";
      return false;
    }
    if (input_data[i].dtype != feeds_to_dtype_[target_name]) {
      LOG(ERROR) << "Feed var [" << target_name << "] expected dtype is: "
                 << DistModelDTypeToString(feeds_to_dtype_[target_name])
                 << ". But received dtype is: "
                 << DistModelDTypeToString(input_data[i].dtype) << ".";
      return false;
    }
    // Inputs are matched by name, so the i-th input may land in any slot.
    int feed_idx = feed_names_[target_name];
    framework::SetFeedVariable(scope, *input_tensor, "feed", feed_idx);
  }
  return true;
}
// Copies every fetch variable out of `scope` into `output_data`, converting
// each LoDTensor to a DistModelTensor with the matching dtype tag.
bool DistModel::FetchResults(std::vector<DistModelTensor>* output_data,
                             framework::Scope* scope) {
  VLOG(3) << "DistModel is fetch results.";
  output_data->resize(fetches_.size());
  for (size_t i = 0; i < fetches_.size(); ++i) {
    int idx = PADDLE_GET_CONST(int, fetches_[i]->GetAttr("col"));
    VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]";
    // fetches_ is indexed by col, so idx must equal i by construction.
    PADDLE_ENFORCE_EQ(
        static_cast<size_t>(idx),
        i,
        platform::errors::InvalidArgument(
            "Fetch op's col attr(%d) should be equal to the index(%d)",
            idx,
            i));
    framework::FetchType& fetch_var =
        framework::GetFetchVariable(*scope, "fetch", idx);
    auto& fetch = PADDLE_GET(framework::LoDTensor, fetch_var);
    auto type = framework::TransToProtoVarType(fetch.dtype());
    auto output = &(output_data->at(i));
    output->name = idx_to_fetches_[idx];
    bool rst = false;
    // Dispatch the element copy on the runtime dtype of the fetched tensor.
    if (type == framework::proto::VarType::FP32) {
      rst = FetchResult<float>(fetch, output);
      output->dtype = DistModelDataType::FLOAT32;
    } else if (type == framework::proto::VarType::INT64) {
      rst = FetchResult<int64_t>(fetch, output);
      output->dtype = DistModelDataType::INT64;
    } else if (type == framework::proto::VarType::INT32) {
      rst = FetchResult<int32_t>(fetch, output);
      output->dtype = DistModelDataType::INT32;
    } else if (type == framework::proto::VarType::FP16) {
      rst = FetchResult<float16>(fetch, output);
      output->dtype = DistModelDataType::FLOAT16;
    } else {
      LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only "
                    "supports float32, float16, int64 and int32 fetch type "
                    "for now.";
    }
    if (!rst) {
      LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx];
      return false;
    }
  }
  return true;
}
// Copies one fetched LoDTensor (shape, raw data, and LoD) into a
// DistModelTensor. T must match the tensor's element type.
template <typename T>
bool DistModel::FetchResult(const framework::LoDTensor& fetch,
                            DistModelTensor* output_data) {
  auto shape = phi::vectorize(fetch.dims());
  output_data->shape.assign(shape.begin(), shape.end());
  const T* data = fetch.data<T>();
  int64_t num_elems = fetch.numel();
  output_data->data.Resize(num_elems * sizeof(T));
  // The output of fetch op is always on the cpu, no need switch on place
  memcpy(output_data->data.data(), data, num_elems * sizeof(T));
  output_data->lod.clear();
  for (auto& level : fetch.lod()) {
    output_data->lod.emplace_back(level.begin(), level.end());
  }
  return true;
}
// One inference pass: feed inputs into the root scope, run the fleet executor
// carrier, fetch outputs. When config_.enable_timer is set, per-phase timings
// are logged (toc() values are cumulative since the single tic()).
bool DistModel::Run(const std::vector<DistModelTensor>& input_data,
                    std::vector<DistModelTensor>* output_data) {
  VLOG(3) << "DistModel run for once.";
  DistModelTimer timer;
  timer.tic();
  double feed_elapse = 0;
  double fleet_exe_elapse = 0;
  double fetch_elapse = 0;
  if (!FeedData(input_data, scope_.get())) {
    LOG(ERROR) << "DistModel failed at feeding data.";
    return false;
  }
  if (config_.enable_timer) {
    feed_elapse = timer.toc();
    LOG(INFO) << "Finish loading data, cost " << feed_elapse << "ms.";
  } else {
    VLOG(3) << "Finish loading data.";
  }
  fleet_exe->Run(carrier_id_);
  if (config_.enable_timer) {
    fleet_exe_elapse = timer.toc();
    // Phase cost = cumulative elapsed minus the previous phase's mark.
    LOG(INFO) << "Finish FleetExe running, cost "
              << fleet_exe_elapse - feed_elapse << "ms.";
  } else {
    VLOG(3) << "Finish FleetExe running.";
  }
  if (!FetchResults(output_data, scope_.get())) {
    LOG(ERROR) << "DistModel failed at fetching result.";
    return false;
  }
  if (config_.enable_timer) {
    fetch_elapse = timer.toc();
    LOG(INFO) << "Finish fetching data, cost "
              << fetch_elapse - fleet_exe_elapse << "ms.";
    LOG(INFO) << "DistModel finish inf, cost " << fetch_elapse << "ms";
  } else {
    VLOG(3) << "Finish fetching data.";
    VLOG(3) << "DistModel finish inf.";
  }
  return true;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/dist_model.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
class
ProgramDesc
;
class
Scope
;
class
BlockDesc
;
}
// namespace framework
namespace
distributed
{
class
TaskNode
;
class
FleetExecutor
;
// Configuration for DistModel inference. Exactly one of model_dir or
// program_desc must be set (enforced in DistModel::Init).
struct DistModelConfig {
  // Path prefix of the model: "<model_dir>.pdmodel" / "<model_dir>.pdiparams".
  std::string model_dir{};
  // Alternatively, a caller-owned program; requires `scope` to be set too.
  framework::ProgramDesc* program_desc{nullptr};
  // Caller-owned scope used together with program_desc.
  framework::Scope* scope{nullptr};
  // Device kind: "GPU", "CPU" or "XPU".
  std::string place{};
  // Device ordinal for GPU/XPU places.
  int64_t device_id{0};
  // Endpoints of all ranks, indexed by rank id.
  std::vector<std::string> trainer_endpoints{};
  // Endpoint of this rank.
  std::string current_endpoint{};
  // Total number of ranks; comm init runs only when > 1.
  int64_t nranks{1};
  // Rank of this process.
  int64_t local_rank{0};
  // When true, Run() logs per-phase wall-clock timings.
  bool enable_timer{false};
  // ring_id --> ranks participating in that ring.
  std::map<int64_t, std::vector<int64_t>> ring_id_to_ranks_{};
  // rank --> ring ids that rank participates in.
  std::map<int64_t, std::vector<int64_t>> rank_to_ring_ids_{};
};
// Distributed-inference facade: loads (or adopts) a program, initializes
// cross-rank communication when needed, and executes inference through a
// FleetExecutor. Non-copyable.
class DistModel {
 public:
  explicit DistModel(const DistModelConfig& config) : config_(config) {}
  // Full initialization; must be called (and succeed) before Run().
  bool Init();
  // One inference pass: feeds input_data, runs the carrier, fills output_data.
  bool Run(const std::vector<DistModelTensor>& input_data,
           std::vector<DistModelTensor>* output_data);
  ~DistModel() = default;

 private:
  DISABLE_COPY_AND_ASSIGN(DistModel);

  // --- Init() helpers, each returns false on failure ---
  bool PrepareScope();
  bool PrepareProgram();
  bool LoadProgram();
  bool LoadParameters();
  bool PreparePlace();
  bool CommInit();
  bool PrepareFeedAndFetch();
  bool PrepareFleetExe();
  // Appends communicator-init ops for one ring into `block`.
  void InsertCommOp(std::string tmp_var_name,
                    int nranks,
                    int rank,
                    const std::vector<std::string>& peer_endpoints,
                    framework::BlockDesc* block,
                    int ring_id);
  // --- Run() helpers ---
  bool FeedData(const std::vector<DistModelTensor>& input_data,
                framework::Scope* scope);
  bool FetchResults(std::vector<DistModelTensor>* output_data,
                    framework::Scope* scope);
  // Copies one fetched tensor of element type T into output_data.
  template <typename T>
  bool FetchResult(const framework::LoDTensor& fetch,
                   DistModelTensor* output_data);

  std::string carrier_id_;
  // Staging tensors reused across Run() calls, one per feed slot.
  std::vector<framework::LoDTensor> feed_tensors_;
  std::vector<framework::OpDesc*> feeds_;
  // feed var name --> feed slot index, and the reverse mapping.
  std::map<std::string, int64_t> feed_names_;
  std::map<int64_t, std::string> idx_to_feeds_;
  // feed var name --> expected input dtype.
  std::map<std::string, DistModelDataType> feeds_to_dtype_;
  std::vector<framework::OpDesc*> fetches_;
  // fetch slot index --> fetched var name.
  std::map<int64_t, std::string> idx_to_fetches_;
  DistModelConfig config_;
  FleetExecutorDesc executor_desc_;
  std::shared_ptr<FleetExecutor> fleet_exe;
  std::shared_ptr<TaskNode> task_node_;
  std::shared_ptr<framework::Scope> scope_;
  paddle::platform::Place place_;
  std::shared_ptr<framework::ProgramDesc> program_;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// Points the buffer at caller-owned storage of `length` bytes. Any buffer we
// currently own is released first; the new storage is never freed by us.
void DistModelDataBuf::Reset(void* data, size_t length) {
  Free();
  data_ = data;
  length_ = length;
  memory_owned_ = false;
}
// Releases the buffer if (and only if) this object owns it; resets to the
// empty state. Safe to call repeatedly.
void DistModelDataBuf::Free() {
  if (!memory_owned_ || !data_) {
    return;  // nothing owned, nothing to do
  }
  PADDLE_ENFORCE_GT(length_,
                    0UL,
                    platform::errors::PreconditionNotMet(
                        "Error occurred when deconstruct DistModelDataBuf: "
                        "it contains no data!"));
  // NOTE: if own the memory, it must be char* type
  delete[] static_cast<char*>(data_);
  data_ = nullptr;
  length_ = 0;
}
// Grows the owned buffer to at least `length` bytes (existing contents are
// NOT preserved). No-op when already large enough; throws when the buffer
// wraps external, non-owned memory.
void DistModelDataBuf::Resize(size_t length) {
  if (length_ >= length) {
    return;  // already large enough
  }
  if (!memory_owned_) {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "The memory is allocated externally, can not Resized"));
  }
  Free();
  data_ = new char[length];
  length_ = length;
  memory_owned_ = true;
}
// Copy assignment. When `other` wraps external memory we alias the same
// pointer (shallow copy); otherwise we deep-copy into our own buffer.
DistModelDataBuf& DistModelDataBuf::operator=(const DistModelDataBuf& other) {
  if (this == &other) {
    return *this;  // self-assignment guard
  }
  if (!other.memory_owned_) {
    // BUGFIX: previously any buffer we owned was leaked when overwritten
    // with an external (non-owned) pointer; release it first.
    Free();
    data_ = other.data_;
    length_ = other.length_;
    memory_owned_ = other.memory_owned_;
  } else {
    Resize(other.length());
    if (other.length() && other.data()) {
      std::memcpy(data_, other.data(), other.length());
    } else if (other.length()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Invalid argument, null pointer data with length %u is passed",
          other.length()));
    }
    length_ = other.length_;
    memory_owned_ = true;
  }
  return *this;
}
// Move assignment: steals `other`'s buffer and ownership flag, leaving
// `other` empty and non-owning.
DistModelDataBuf& DistModelDataBuf::operator=(DistModelDataBuf&& other) {
  if (this != &other) {  // self-move guard
    // BUGFIX: previously a buffer we owned was leaked on move-assignment;
    // release it before taking over other's storage.
    Free();
    data_ = other.data_;
    memory_owned_ = other.memory_owned_;
    length_ = other.length_;
    other.data_ = nullptr;
    other.length_ = 0;
    other.memory_owned_ = false;
  }
  return *this;
}
// Move constructor: takes over other's storage and ownership, leaving `other`
// empty and non-owning so its destructor frees nothing.
DistModelDataBuf::DistModelDataBuf(DistModelDataBuf&& other)
    : data_(other.data_),
      length_(other.length_),
      memory_owned_(other.memory_owned_) {
  other.memory_owned_ = false;
  other.data_ = nullptr;
  other.length_ = 0;
}
// Copy constructor: delegates to copy assignment (deep copy for owned
// buffers, pointer aliasing for externally-owned ones).
DistModelDataBuf::DistModelDataBuf(const DistModelDataBuf& other) {
  *this = other;
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
distributed
{
enum
DistModelDataType
{
FLOAT16
,
FLOAT32
,
INT64
,
INT32
,
INT8
};
// Maps a C++ element type to its DistModelDataType tag at compile time.
// Only the specializations below are defined; other types fail to link.
template <typename T>
constexpr DistModelDataType DistModelGetDtype();

template <>
constexpr DistModelDataType DistModelGetDtype<int32_t>() {
  return DistModelDataType::INT32;
}

template <>
constexpr DistModelDataType DistModelGetDtype<int64_t>() {
  return DistModelDataType::INT64;
}

template <>
constexpr DistModelDataType DistModelGetDtype<float>() {
  return DistModelDataType::FLOAT32;
}

template <>
constexpr DistModelDataType DistModelGetDtype<platform::float16>() {
  return DistModelDataType::FLOAT16;
}
// Byte buffer with optional ownership: either owns a heap allocation
// (memory_owned_ == true, freed in the destructor) or wraps caller-owned
// memory that it never frees.
class DistModelDataBuf {
 public:
  // Allocates an owned buffer of `length` bytes (uninitialized).
  explicit DistModelDataBuf(size_t length)
      : data_(new char[length]), length_(length), memory_owned_(true) {}
  // Wraps external memory; the caller keeps ownership.
  DistModelDataBuf(void* data, size_t length)
      : data_(data), length_(length), memory_owned_(false) {}
  // Re-points at external memory, releasing any owned buffer first.
  void Reset(void* data, size_t length);
  // Size of the buffer in bytes.
  size_t length() const { return length_; }
  // Raw pointer to the buffer (may be nullptr for an empty buffer).
  void* data() const { return data_; }
  ~DistModelDataBuf() { Free(); }
  DistModelDataBuf() = default;
  // Grows an owned buffer; throws when wrapping external memory.
  void Resize(size_t length);
  DistModelDataBuf& operator=(const DistModelDataBuf& other);
  DistModelDataBuf& operator=(DistModelDataBuf&& other);
  DistModelDataBuf(DistModelDataBuf&& other);
  DistModelDataBuf(const DistModelDataBuf& other);

 private:
  // Frees the buffer only when owned; see dist_model_tensor_wrapper.cc.
  void Free();
  void* data_{nullptr};
  size_t length_{0};
  bool memory_owned_{true};
};
// User-facing tensor for DistModel inference: a named, shaped byte buffer
// with a dtype tag and optional level-of-detail (LoD) information.
struct DistModelTensor {
  std::string name;           // must match a feed/fetch var name
  std::vector<int> shape;     // dims of the tensor
  DistModelDataBuf data;      // raw element bytes
  DistModelDataType dtype;    // element type of `data`
  std::vector<std::vector<size_t>> lod;  // optional LoD levels
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/fleet_executor.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
distributed
{
// Builds a FleetExecutor from a serialized FleetExecutorDesc proto string,
// then creates and initializes the process-wide MessageBus singleton.
FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
  bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
  PADDLE_ENFORCE(parse_flag,
                 platform::errors::PreconditionNotMet(
                     "Error occurs while parsing string to proto"));
  // Message bus will be created and inited only once
  GlobalVal<MessageBus>::Create();
  InitMessageBus();
}
// Builds a FleetExecutor from an already-parsed FleetExecutorDesc,
// then creates and initializes the process-wide MessageBus singleton.
FleetExecutor::FleetExecutor(const FleetExecutorDesc& exe_desc)
    : exe_desc_(exe_desc) {
  // Message bus will be created and inited only once
  GlobalVal<MessageBus>::Create();
  InitMessageBus();
}
// Releases every carrier that this executor registered in the global carrier
// map during Init().
FleetExecutor::~FleetExecutor() {
  for (const auto& id : carrier_ids_) {
    Carrier* carrier = GlobalMap<std::string, Carrier>::Get(id);
    carrier->Release();
  }
}
// Builds the runtime graph for one carrier and registers/initializes that
// carrier. `task_nodes` own their ops; the temporary unique_ptr wrapping below
// exists only to satisfy GetUnusedVars() and is undone with release().
void FleetExecutor::Init(
    const std::string& carrier_id,
    const framework::ProgramDesc& program_desc,
    framework::Scope* scope,
    const platform::Place& place,
    int64_t num_micro_batches,
    const std::vector<TaskNode*>& task_nodes,
    const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
    const std::vector<std::string>& inference_root_scope_vars) {
  PADDLE_ENFORCE_GT(task_nodes.size(),
                    0,
                    platform::errors::InvalidArgument(
                        "Fleet executor is inited with empty task node"));
  // TODO(fleet_exe devs): the unused_vars should be got from run time graph
  // Collect all ops from all task nodes. Ownership is NOT transferred here:
  // the unique_ptrs are released again before returning (see below).
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (auto task_node : task_nodes) {
    for (auto op : task_node->ops()) {
      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
    }
  }
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
  // NOTE: For inference, the vars in inference_root_scope_vars
  // shouldn't be deleted during inf, for that they may be the result of the
  // inf. If they are GCed, it will cause error during ZeroCopy the result.
  std::vector<const framework::OperatorBase*> changed_ops;
  for (auto pair : unused_vars) {
    const framework::OperatorBase* op = pair.first;
    std::vector<std::string> unused = pair.second;
    // Strip every inference root-scope var from this op's unused list so the
    // GC never frees it.
    for (auto name : inference_root_scope_vars) {
      auto iter = std::find(unused.begin(), unused.end(), name);
      if (iter != unused.end()) {
        VLOG(3) << "Removing var: [" << name
                << "] from the unused vars list of op: [" << op->Type() << "]";
        unused.erase(iter);
        if (std::find(changed_ops.begin(), changed_ops.end(), op) ==
            changed_ops.end()) {
          // record the op whose unused vars have been updated
          changed_ops.emplace_back(op);
        }
      }
    }
    // update the unused vars list in the map
    unused_vars[op] = unused;
  }
  for (auto op : changed_ops) {
    auto iter = unused_vars.find(op);
    if (iter->second.empty()) {
      // remove those ops in the map that have empty unused vars list
      VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map.";
      unused_vars.erase(iter);
    }
  }
  runtime_graph_ = std::make_shared<RuntimeGraph>();
  // interceptor id == task id; build the id -> node map for the graph.
  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
  for (auto task_node : task_nodes) {
    task_node->SetUnusedVars(unused_vars);
    int64_t interceptor_id = task_node->task_id();
    interceptor_id_to_task.emplace(interceptor_id, task_node);
  }
  runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
  runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
  // Give op ownership back to the task nodes (see NOTE above): without this
  // release the ops would be double-freed.
  for (auto& unique_op : ops) {
    unique_op.release();
  }
  VLOG(5) << runtime_graph_->DebugString();
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier_ids_.insert(carrier_id);
  // Set current running carrier
  GlobalVal<std::string>::Set(new std::string(carrier_id));
  InitCarrier(carrier,
              scope,
              place,
              num_micro_batches,
              program_desc,
              inference_root_scope_vars);
  // All ranks must reach this point before any of them starts running.
  GlobalVal<MessageBus>::Get()->Barrier();
}
// Forwards executor-level state (rank, runtime graph, program, scope, place)
// into the given carrier's own Init().
void FleetExecutor::InitCarrier(
    Carrier* carrier,
    framework::Scope* scope,
    const platform::Place& place,
    int64_t num_micro_batches,
    const framework::ProgramDesc& program_desc,
    const std::vector<std::string>& inference_root_scope_vars) {
  carrier->Init(exe_desc_.cur_rank(),
                runtime_graph_->interceptor_id_to_rank(),
                runtime_graph_->interceptor_id_to_node(),
                program_desc,
                scope,
                num_micro_batches,
                place,
                inference_root_scope_vars);
}
// Builds the rank -> "ip:port" DNS table from the executor desc and
// initializes the MessageBus singleton with it. An empty address is only
// legal for single-process (rank 0, one entry) runs.
void FleetExecutor::InitMessageBus() {
  std::stringstream ss;
  ss << "\nThe DNS table of the message bus is:\n";
  int64_t cur_rank = exe_desc_.cur_rank();
  std::unordered_map<int64_t, std::string> rank_to_addr;
  std::string addr;
  for (const auto& rank_info : exe_desc_.cluster_info()) {
    // init the dns map
    int64_t rank = rank_info.rank();
    std::string ip_port = rank_info.ip_port();
    ss << rank << "\t->\t" << ip_port << "\n";
    rank_to_addr.insert(std::make_pair(rank, ip_port));
    if (rank == cur_rank) {
      addr = ip_port;
    }
  }
  if (addr == "") {
    // No address for the current rank: only valid when this is a
    // single-member cluster and we are rank 0.
    PADDLE_ENFORCE_EQ(
        rank_to_addr.size(),
        1,
        platform::errors::NotFound("Empty address is not valid for "
                                   "paddle.distributed.launch method."));
    PADDLE_ENFORCE_EQ(
        cur_rank,
        0,
        platform::errors::NotFound("Address is empty but cur rank is not 0."));
  }
  VLOG(3) << "Current rank is " << cur_rank << " and the ip_port is "
          << (addr == "" ? "empty" : addr) << ".";
  VLOG(3) << "The number of ranks are "
          << (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
  VLOG(5) << ss.str();
  GlobalVal<MessageBus>::Get()->Init(cur_rank, rank_to_addr, addr);
}
// Starts one run of the carrier identified by `carrier_id`. If that carrier
// is not the globally-recorded current one, the record is switched first and
// all ranks re-synchronize via a message-bus barrier.
void FleetExecutor::Run(const std::string& carrier_id) {
  auto* target = GlobalMap<std::string, Carrier>::Get(carrier_id);
  // Set current running carrier
  const bool switching = (*GlobalVal<std::string>::Get() != carrier_id);
  if (switching) {
    GlobalVal<std::string>::Set(new std::string(carrier_id));
    GlobalVal<MessageBus>::Get()->Barrier();
  }
  target->Start();
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/fleet_executor.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
class
ProgramDesc
;
class
Scope
;
}
// namespace framework
namespace
distributed
{
class
RuntimeGraph
;
class
MessageBus
;
class
TaskNode
;
// Per-process entry point of the fleet (pipeline) executor. Owns the
// executor description, the runtime graph, and the set of carriers it has
// created; the MessageBus singleton is created in the constructors.
class FleetExecutor final {
 public:
  FleetExecutor() = delete;
  // Parses `exe_desc_str` as a serialized FleetExecutorDesc proto.
  explicit FleetExecutor(const std::string& exe_desc_str);
  explicit FleetExecutor(const FleetExecutorDesc& exe_desc);
  // Releases all carriers created via Init().
  ~FleetExecutor();
  // Builds the runtime graph from `task_nodes` and creates/initializes the
  // carrier named `carrier_id`. `inference_root_scope_vars` lists vars that
  // must never be garbage-collected (inference results).
  void Init(const std::string& carrier_id,
            const framework::ProgramDesc& program_desc,
            framework::Scope* scope,
            const platform::Place& place,
            int64_t num_micro_batches,
            const std::vector<TaskNode*>& task_nodes,
            const std::unordered_map<int64_t, int64_t>& task_id_to_rank,
            const std::vector<std::string>& inference_root_scope_vars = {});
  // Starts the named carrier (switching the global "current carrier" record
  // and barrier-syncing if needed).
  void Run(const std::string& carrier_id);

 private:
  DISABLE_COPY_AND_ASSIGN(FleetExecutor);
  void InitMessageBus();
  void InitCarrier(
      Carrier* carrier,
      framework::Scope* scope,
      const platform::Place& place,
      int64_t num_micro_batches,
      const framework::ProgramDesc& program_desc,
      const std::vector<std::string>& inference_root_scope_vars = {});
  FleetExecutorDesc exe_desc_;
  std::shared_ptr<RuntimeGraph> runtime_graph_;
  // ids of carriers created by this executor (released in the destructor)
  std::unordered_set<std::string> carrier_ids_;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax
=
"proto2"
;
package
paddle
.
distributed
;
// One entry of the cluster address book: rank id -> "ip:port" endpoint.
message RankInfo {
  required int64 rank = 1;
  required string ip_port = 2;
}

// Static per-process description of the fleet executor.
message FleetExecutorDesc {
  optional int64 cur_rank = 1 [ default = 0 ]; // Rank id of current processor
  repeated RankInfo cluster_info = 2;          // endpoints of every rank
}
paddle/fluid/distributed/fleet_executor/global.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
// Process-wide singleton holder for one value of type T.
// NOTE(review): access is not synchronized -- presumably Create/Set/Get are
// called from a single controlling thread; confirm against callers.
template <typename T>
class GlobalVal final {
 public:
  // Returns the stored value; enforces that one exists.
  static T* Get() {
    T* ptr = GetPPtr()->get();
    PADDLE_ENFORCE_NOT_NULL(
        ptr, platform::errors::NotFound("This value is not global value."));
    return ptr;
  }
  // Constructs the value in place; must be called at most once per T.
  template <typename... Args>
  static T* Create(Args&&... args) {
    auto* ptr = GetPPtr();
    PADDLE_ENFORCE_EQ(ptr->get(),
                      nullptr,
                      platform::errors::AlreadyExists(
                          "This value is already a global value."));
    T* item = new T(std::forward<Args>(args)...);
    ptr->reset(item);
    return item;
  }
  // Replaces the stored value, destroying any previous one. Takes ownership
  // of `new_item` (callers pass `new T(...)`).
  static T* Set(T* new_item) {
    auto* ptr = GetPPtr();
    ptr->reset(new_item);
    return ptr->get();
  }

 private:
  // Function-local static: one slot per T for the whole process.
  static std::unique_ptr<T>* GetPPtr() {
    static std::unique_ptr<T> ptr;
    return &ptr;
  }
};
// Process-wide map of singletons keyed by KeyT. Unsynchronized -- see
// ThreadSafeGlobalMap below for the locked variant.
template <typename KeyT, typename ValueT>
class GlobalMap final {
 public:
  // Returns the value stored under `id`; enforces that it was created.
  static ValueT* Get(KeyT id) {
    ValueT* item = GetPPtr(id)->get();
    PADDLE_ENFORCE_NOT_NULL(
        item, platform::errors::NotFound("This value is not in global map."));
    return item;
  }
  // Constructs the value under `id`; at most once per id.
  template <typename... Args>
  static ValueT* Create(KeyT id, Args&&... args) {
    auto* ptr = GetPPtr(id);
    PADDLE_ENFORCE_EQ(ptr->get(),
                      nullptr,
                      platform::errors::AlreadyExists(
                          "This value has already in global map."));
    ValueT* item = new ValueT(std::forward<Args>(args)...);
    ptr->reset(item);
    return item;
  }

 private:
  // operator[] default-inserts an empty unique_ptr slot on first access.
  static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
    static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
    return &id_to_ptr[id];
  }
};
// Like GlobalMap, but slot lookup/insertion is guarded by a mutex.
// NOTE(review): only the map access is locked; the returned pointer is used
// after the lock is released (unordered_map references are not invalidated
// by later insertions, so the slot itself stays valid).
template <typename KeyT, typename ValueT>
class ThreadSafeGlobalMap final {
 public:
  static ValueT* Get(KeyT id) {
    ValueT* item = GetPPtr(id)->get();
    PADDLE_ENFORCE_NOT_NULL(
        item,
        platform::errors::NotFound(
            "This value is not in thread safe global map."));
    return item;
  }
  template <typename... Args>
  static ValueT* Create(KeyT id, Args&&... args) {
    auto* ptr = GetPPtr(id);
    PADDLE_ENFORCE_EQ(ptr->get(),
                      nullptr,
                      platform::errors::AlreadyExists(
                          "This value has already in thread safe global map."));
    ValueT* item = new ValueT(std::forward<Args>(args)...);
    ptr->reset(item);
    return item;
  }

 private:
  static std::unique_ptr<ValueT>* GetPPtr(KeyT id) {
    static std::mutex mutex;
    static std::unordered_map<KeyT, std::unique_ptr<ValueT>> id_to_ptr;
    std::unique_lock<std::mutex> lock(mutex);
    return &id_to_ptr[id];
  }
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace
paddle
{
namespace
distributed
{
// Binds this interceptor to its id and the task node it will execute.
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
    : interceptor_id_(interceptor_id), node_(node) {}
// Intentionally empty: the messages-empty invariant check is disabled.
Interceptor::~Interceptor() {
  // FIXME(wangxi): throw in stop function
  // std::lock_guard<std::mutex> lock(mutex_);
  // PADDLE_ENFORCE_EQ(messages_.empty(), true,
  //                   platform::errors::PreconditionNotMet(
  //                       "Interceptor must destruct with messages empty"));
}
// Installs the callback invoked by Handle() for every incoming message.
// FIX: move the by-value std::function into place instead of copying its
// captured state a second time.
void Interceptor::RegisterMsgHandle(MsgHandle handle) {
  handle_ = std::move(handle);
}
// Dispatches one message to the registered handler; it is a precondition
// violation to receive a message before RegisterMsgHandle() was called.
void Interceptor::Handle(const InterceptorMessage& msg) {
  PADDLE_ENFORCE_NOT_NULL(handle_,
                          platform::errors::PreconditionNotMet(
                              "Message handle is not registered."));
  handle_(msg);
}
// Drains the mailbox under the lock, then handles the drained messages
// outside of it (so Handle() may Send() without deadlocking on mutex_).
void Interceptor::LoopOnce() {
  std::deque<InterceptorMessage> tmp_messages;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    messages_.swap(tmp_messages);
  }
  // LoopOnce is only queued when a message was enqueued (see
  // EnqueueRemoteInterceptorMessage), so an empty batch is a logic error.
  PADDLE_ENFORCE_EQ(tmp_messages.empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "tmp_messages must not empty in task loop"));
  for (auto& msg : tmp_messages) {
    const MessageType message_type = msg.message_type();
    VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
            << " from interceptor " << msg.src_id()
            << " with message: " << message_type << ".";
    Handle(msg);
  }
}
// Wakes the owning carrier (e.g. so it can observe a stop condition).
void Interceptor::StopCarrier() {
  PADDLE_ENFORCE_NOT_NULL(carrier_,
                          platform::errors::PreconditionNotMet(
                              "Carrier is not registered."));
  carrier_->WakeUp();
}
void Interceptor::EnqueueRemoteInterceptorMessage(
    const InterceptorMessage& message) {
  // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
  VLOG(3) << "Enqueue message: " << message.message_type() << " into "
          << interceptor_id_ << "'s remote mailbox.";
  bool empty = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    empty = messages_.empty();
    messages_.emplace_back(message);
  }
  // Only the transition empty -> non-empty schedules a LoopOnce: a pending
  // LoopOnce will drain everything enqueued meanwhile.
  if (empty) {
    loop_->QueueInLoop([this]() { LoopOnce(); });
  }
}
// Stamps src/dst ids onto `msg` and forwards it through the carrier.
// Returns the carrier's send result.
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
  PADDLE_ENFORCE_NOT_NULL(carrier_,
                          platform::errors::PreconditionNotMet(
                              "Carrier is not registered."));
  msg.set_src_id(interceptor_id_);
  msg.set_dst_id(dst_id);
  return carrier_->Send(msg);
}
// Meyers-singleton registry: interceptor type name -> creator function.
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
  static InterceptorFactory::CreateInterceptorMap type_to_creator;
  return type_to_creator;
}
// Instantiates the interceptor registered under `type` for (id, node).
// Enforces that the type was registered via REGISTER_INTERCEPTOR.
std::unique_ptr<Interceptor> InterceptorFactory::Create(
    const std::string& type, int64_t id, TaskNode* node) {
  auto& interceptor_map = GetInterceptorMap();
  auto iter = interceptor_map.find(type);
  PADDLE_ENFORCE_NE(
      iter,
      interceptor_map.end(),
      platform::errors::NotFound("interceptor %s is not register", type));
  return iter->second(id, node);
}
// Records a creator function under `type`.
// NOTE(review): emplace is a silent no-op when `type` is already registered;
// presumably duplicate registration never happens -- confirm if it should
// enforce instead.
void InterceptorFactory::Register(
    const std::string& type, InterceptorFactory::CreateInterceptorFunc func) {
  auto& interceptor_map = GetInterceptorMap();
  interceptor_map.emplace(type, func);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
class
Scope
;
class
GarbageCollector
;
}
// namespace framework
namespace
distributed
{
class
TaskNode
;
class
Carrier
;
class
TaskLoop
;
// Reserved interceptor ids for the virtual source/sink nodes of the graph.
constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;

// Base class of all interceptors: one mailbox-driven actor bound to a
// TaskNode, running on a TaskLoop and communicating through its Carrier.
class Interceptor {
 public:
  using MsgHandle = std::function<void(const InterceptorMessage&)>;

 public:
  Interceptor() = delete;
  Interceptor(int64_t interceptor_id, TaskNode* node);
  virtual ~Interceptor();
  // register interceptor handle
  void RegisterMsgHandle(MsgHandle handle);
  // Dispatches one message to the registered handle.
  void Handle(const InterceptorMessage& msg);
  // return the interceptor id
  int64_t GetInterceptorId() const { return interceptor_id_; }
  // Called by Carrier, enqueue an InterceptorMessage to remote mailbox
  void EnqueueRemoteInterceptorMessage(
      const InterceptorMessage& interceptor_message);
  // Stamps src/dst ids on `msg` and sends it via the carrier.
  bool Send(int64_t dst_id, InterceptorMessage& msg);  // NOLINT
  void SetPlace(const platform::Place& place) { place_ = place; }
  void SetRootScope(framework::Scope* scope) { root_scope_ = scope; }
  void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; }
  void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
    microbatch_scopes_ = scopes;
  }
  void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
    gc_ = gc;
  }
  void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
  void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }
  TaskNode* GetTaskNode() const { return node_; }
  DISABLE_COPY_AND_ASSIGN(Interceptor);

 protected:
  // interceptor id, handed from above layer
  int64_t interceptor_id_;
  // node need to be handled by this interceptor
  TaskNode* node_;
  // for stop
  bool stop_{false};
  void StopCarrier();
  // for runtime
  platform::Place place_;
  framework::Scope* root_scope_{nullptr};
  framework::Scope* minibatch_scope_{nullptr};
  std::vector<framework::Scope*> microbatch_scopes_{};
  std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
  // FIX: brace-initialize to nullptr. These were left uninitialized, so the
  // PADDLE_ENFORCE_NOT_NULL checks in Send()/StopCarrier() would read an
  // indeterminate pointer (UB) if Register{Carrier,TaskLoop} was never
  // called; now they fail with a clear error instead.
  Carrier* carrier_{nullptr};
  TaskLoop* loop_{nullptr};

 private:
  // Drains the mailbox and handles the drained messages.
  void LoopOnce();
  // interceptor handle which process message
  MsgHandle handle_{nullptr};
  // guards messages_
  std::mutex mutex_;
  std::deque<InterceptorMessage> messages_;
  int64_t already_run_times_{0};
  int64_t used_slot_nums_{0};
};
// Static factory: maps interceptor type names (registered via
// REGISTER_INTERCEPTOR) to creator functions.
class InterceptorFactory {
 public:
  // Creator signature: (interceptor id, task node) -> interceptor instance.
  using CreateInterceptorFunc = std::unique_ptr<Interceptor> (*)(int64_t,
                                                                 TaskNode*);
  using CreateInterceptorMap =
      std::unordered_map<std::string, CreateInterceptorFunc>;

  static void Register(const std::string& type, CreateInterceptorFunc func);

  static std::unique_ptr<Interceptor> Create(const std::string& type,
                                             int64_t id,
                                             TaskNode* node);
};
// Generic creator used by REGISTER_INTERCEPTOR: constructs the concrete
// interceptor subclass from (id, node).
template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
  return std::make_unique<InterceptorClass>(id, node);
}
// Defines a file-scope registrar whose constructor records
// `interceptor_class` in the factory under the name `interceptor_type`, plus
// a Touch* function that USE_INTERCEPTOR references to keep the registering
// translation unit from being dropped by the linker.
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
  class __RegisterInterceptor_##interceptor_type { \
   public: \
    __RegisterInterceptor_##interceptor_type() { \
      InterceptorFactory::Register(#interceptor_type, \
                                   CreatorInterceptor<interceptor_class>); \
    } \
    void Touch() {} \
  }; \
  __RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \
  int TouchRegisterInterceptor_##interceptor_type() { \
    g_register_##interceptor_type.Touch(); \
    return 0; \
  }

// Pulls in the registration of `interceptor_type` from another translation
// unit (forces the registrar's object file to be linked).
#define USE_INTERCEPTOR(interceptor_type) \
  extern int TouchRegisterInterceptor_##interceptor_type(); \
  UNUSED static int use_interceptor_##interceptor_type = \
      TouchRegisterInterceptor_##interceptor_type();
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/interceptor_message.proto
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax
=
"proto2"
;
package
paddle
.
distributed
;
option
cc_generic_services
=
true
;
option
cc_enable_arenas
=
true
;
// Control-plane message types exchanged between interceptors.
enum MessageType {
  STOP = 1;            // STOP an Interceptor
  DATA_IS_READY = 2;   // upstream data is ready
  DATA_IS_USELESS = 3; // downstream has used the data
  ERR = 4;             // current Interceptor encounters error
  RESET = 5;           // reset the status
  START = 6;
}

// One message between two interceptors (or, when ctrl_message is true, a
// barrier control message between ranks).
message InterceptorMessage {
  optional sint64 src_id = 1 [ default = 0 ];
  optional sint64 dst_id = 2 [ default = 0 ];
  optional MessageType message_type = 3 [ default = RESET ];
  optional bool ctrl_message = 4 [ default = false ];
  optional int64 scope_idx = 5 [ default = 0 ];
}

// RPC result: rst is true when the receiver accepted the message.
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }

// brpc service carrying interceptor messages and barrier counts across ranks.
service MessageService {
  rpc ReceiveInterceptorMessage(InterceptorMessage)
      returns (InterceptorResponse);
  rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
paddle/fluid/distributed/fleet_executor/message_bus.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include <chrono>
#include <memory>
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace
paddle
{
namespace
distributed
{
// One-time initialization: records this process's rank, the rank -> addr
// table, and its own address, then starts listening.
// FIX: the inner `const auto& addr = GetAddr(rank_)` shadowed the `addr`
// parameter -- renamed to `registered_addr`; string emptiness checks use
// .empty() instead of comparison with "".
void MessageBus::Init(
    int64_t rank,
    const std::unordered_map<int64_t, std::string>& rank_to_addr,
    const std::string& addr) {
  PADDLE_ENFORCE_EQ(
      is_init_,
      false,
      platform::errors::AlreadyExists("MessageBus is already init."));
  rank_ = rank;
  is_init_ = true;
  rank_to_addr_ = rank_to_addr;
  addr_ = addr;

  if (!addr_.empty()) {
    // Sanity check: the table entry for our own rank must match `addr`.
    const auto& registered_addr = GetAddr(rank_);
    PADDLE_ENFORCE_EQ(registered_addr,
                      addr_,
                      platform::errors::Fatal(
                          "The current rank's addr is %s, while the "
                          "message bus's addr is %s, which are different. "
                          "Init error.",
                          registered_addr,
                          addr_));
  }

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
  // NOTE: To make the brpc is compatible with collective,
  // need release the handler holding the ip address.
  if (!addr_.empty()) {
    VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
    paddle::platform::SocketServer& socket_server =
        paddle::platform::SocketServer::GetInstance(addr_);
    int server_fd = socket_server.socket();
    if (server_fd != -1) {
      socket_server.Release();
    }
  }
#endif

  ListenPort();
}
// True once Init() has completed.
bool MessageBus::IsInit() const { return is_init_; }
// Shuts down the brpc server (when built with distributed support).
MessageBus::~MessageBus() {
  VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  server_.Stop(1000);
  server_.Join();
#endif
}
const
std
::
string
&
MessageBus
::
GetAddr
(
int64_t
rank
)
const
{
PADDLE_ENFORCE_NE
(
rank_to_addr_
.
find
(
rank
),
rank_to_addr_
.
end
(),
platform
::
errors
::
NotFound
(
"Cannot find addr rank id %lld."
,
rank
));
return
rank_to_addr_
.
at
(
rank
);
}
// Sends one message to `dst_rank`, retrying up to 10 times with a 1s pause.
// Returns false after all retries fail; throws when Paddle was built without
// distributed support.
bool MessageBus::Send(int64_t dst_rank,
                      const InterceptorMessage& interceptor_message) {
  PADDLE_ENFORCE_EQ(
      IsInit(),
      true,
      platform::errors::PreconditionNotMet(
          "Using message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  int retry_time = 0;
  // message bus will retry sending for 10 times
  while (retry_time < 10) {
    ++retry_time;
    if (SendInterRank(dst_rank, interceptor_message)) {
      VLOG(3) << "Message bus sends inter rank successfully with "
              << retry_time << " times retries.";
      return true;
    }
    VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
  return false;
#else
  PADDLE_THROW(platform::errors::Unavailable(
      "Fleet executor does not support sending message between different "
      "ranks when Paddle is compiled with npu or "
      "isn't compiled with distributed for now."));
#endif
  // unreachable: kept to satisfy compilers on the non-distributed path
  return true;
}
// Called (via RPC on rank 0, locally otherwise) when a barrier control
// message arrives; bumps the counter Barrier() is waiting on.
void MessageBus::IncreaseBarrierCount() {
  VLOG(3) << "IncreaseBarrierCount";
  {
    std::unique_lock<std::mutex> lock(mutex_);
    ++count_;
    cv_.notify_one();
  }
  VLOG(3) << "End IncreaseBarrierCount";
}
// Two-phase barrier over all ranks: first every non-zero rank sends a ctrl
// message to rank 0 (gather), then rank 0 sends one back to each rank
// (scatter). Sends are retried every second until they succeed.
void MessageBus::Barrier() {
  // gather to root
  if (rank_ != 0) {
    InterceptorMessage ctrl_msg;
    ctrl_msg.set_ctrl_message(true);
    ctrl_msg.set_src_id(rank_);
    ctrl_msg.set_dst_id(0);
    VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
    while (!Send(0, ctrl_msg)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  } else {
    // rank 0 waits until every other rank has checked in
    VLOG(3) << "Barrier 0 wait others rank ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] {
      return count_ == static_cast<int>(rank_to_addr_.size() - 1);
    });
    count_ = 0;
  }
  // scatter from root
  if (rank_ == 0) {
    for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
      InterceptorMessage ctrl_msg;
      ctrl_msg.set_ctrl_message(true);
      ctrl_msg.set_src_id(0);
      ctrl_msg.set_dst_id(i);
      VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
      while (!Send(i, ctrl_msg)) {
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
      }
    }
  } else {
    // non-zero ranks wait for rank 0's release message
    VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return count_ == 1; });
    count_ = 0;
  }
}
bool
MessageBus
::
DispatchMsgToCarrier
(
const
InterceptorMessage
&
interceptor_message
)
{
const
std
::
string
&
carrier_id
=
*
GlobalVal
<
std
::
string
>::
Get
();
return
GlobalMap
<
std
::
string
,
Carrier
>::
Get
(
carrier_id
)
->
EnqueueInterceptorMessage
(
interceptor_message
);
}
// Registers the message service with brpc and starts the server on addr_,
// retrying with a growing backoff until the port binds. No-op on a single
// card (empty addr_) or when built without distributed support.
void MessageBus::ListenPort() {
  if (addr_ == "") {
    LOG(INFO) << "No need listen to port since training on single card.";
    return;
  }
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // function keep listen the port and handle the message
  PADDLE_ENFORCE_EQ(
      server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
      0,
      platform::errors::Unavailable("Message bus: init brpc service error."));
  // start the server
  const char* ip_for_brpc = addr_.c_str();
  brpc::ServerOptions options;
  options.idle_timeout_sec = -1;
  int retry_times = 0;
  int interval = 100;
  while (server_.Start(ip_for_brpc, &options) != 0) {
    ++retry_times;
    LOG(INFO) << "Message bus is retring for starting brpc for "
              << retry_times << " times. And will retry after "
              << interval / 1000 << " seconds.";
    std::this_thread::sleep_for(std::chrono::milliseconds(interval));
    interval += 500;
  }
  LOG(INFO) << "Message bus's listen port thread starts successful.";
#else
  LOG(WARNING)
      << "Fleet executor's ListenPort() is a fake function when Paddle is "
         "compiled with npu or Paddle isn't compiled "
         "with distributed for now.";
#endif
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
// Sends one InterceptorMessage to `dst_rank` synchronously over a fresh brpc
// channel. Barrier control messages go to IncreaseBarrierCount, everything
// else to ReceiveInterceptorMessage. Returns true iff the RPC succeeded and
// the peer reported rst == true.
// FIX: pass nullptr (not NULL) as the brpc "done" closure for the
// synchronous calls.
bool MessageBus::SendInterRank(int64_t dst_rank,
                               const InterceptorMessage& interceptor_message) {
  const auto& dst_addr = GetAddr(dst_rank);
  VLOG(3) << "Message bus sending to addr: " << dst_addr;
  const char* dst_addr_for_brpc = dst_addr.c_str();
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.connect_timeout_ms = 1000;
  options.timeout_ms = 1000;
  options.max_retry = 5;
  PADDLE_ENFORCE_EQ(
      channel.Init(dst_addr_for_brpc, &options),
      0,
      platform::errors::Unavailable("Message bus: init brpc channel error."));
  MessageService_Stub stub(&channel);
  InterceptorResponse response;
  brpc::Controller ctrl;
  ctrl.set_log_id(0);
  if (interceptor_message.ctrl_message()) {
    stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, nullptr);
  } else {
    stub.ReceiveInterceptorMessage(
        &ctrl, &interceptor_message, &response, nullptr);
  }
  if (!ctrl.Failed()) {
    if (response.rst()) {
      VLOG(3) << "Message bus: brpc sends success.";
      return true;
    } else {
      VLOG(4) << "Message bus: InterceptorMessageService error.";
      return false;
    }
  } else {
    VLOG(4) << "Message bus: brpc sends failed with error text: "
            << ctrl.ErrorText();
    return false;
  }
}
#endif
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/message_bus.h
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
distributed
{
// Forward declaration: carriers consume messages dispatched by the bus.
class Carrier;

// A singleton MessageBus
//
// Routes InterceptorMessages between ranks. Messages whose destination is
// the local rank are handed to a Carrier; when compiled with
// PADDLE_WITH_DISTRIBUTE + PADDLE_WITH_PSCORE, remote delivery goes through
// a brpc server listening on `addr_`. Also implements a cross-rank barrier
// built on a counter guarded by `mutex_`/`cv_`.
class MessageBus final {
 public:
  MessageBus() = default;
  ~MessageBus();

  // Records this process's rank, the rank -> address map, and the local
  // address to listen on, then starts the listen thread.
  void Init(int64_t rank,
            const std::unordered_map<int64_t, std::string>& rank_to_addr,
            const std::string& addr);

  // True once Init() has completed.
  bool IsInit() const;

  // called by Interceptor, send InterceptorMessage to dst
  bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);

  // Bumps the barrier counter and wakes any rank blocked in Barrier().
  void IncreaseBarrierCount();
  // Blocks until every participating rank has reached the barrier.
  void Barrier();
  // Delivers a locally-destined message to the owning Carrier.
  bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message);

 private:
  DISABLE_COPY_AND_ASSIGN(MessageBus);

  // function keep listen the port and handle the message
  void ListenPort();

  // Looks up the network address registered for `rank` in rank_to_addr_.
  const std::string& GetAddr(int64_t rank) const;

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // send the message inter rank (dst is different rank with src)
  bool SendInterRank(int64_t dst_rank,
                     const InterceptorMessage& interceptor_message);
#endif

  // Set by Init(); guards against use before initialization.
  bool is_init_{false};

  // This process's rank id.
  int64_t rank_;

  // handed by above layer, save the info mapping rank id to addr
  std::unordered_map<int64_t, std::string> rank_to_addr_;

  // the ip needs to be listened
  std::string addr_;

#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
  // RPC service implementation registered with the brpc server.
  MessageServiceImpl message_service_;
  // brpc server
  brpc::Server server_;
#endif

  // for barrier
  std::mutex mutex_;
  std::condition_variable cv_;
  int count_{0};
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/fleet_executor/message_service.cc
0 → 100644
View file @
de2e6515
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace
paddle
{
namespace
distributed
{
// brpc handler for ordinary interceptor messages: forwards the request to
// the local MessageBus singleton and reports whether dispatch succeeded.
void MessageServiceImpl::ReceiveInterceptorMessage(
    google::protobuf::RpcController* control_base,
    const InterceptorMessage* request,
    InterceptorResponse* response,
    google::protobuf::Closure* done) {
  // RAII guard: brpc requires `done` to be run exactly once when the
  // handler finishes.
  brpc::ClosureGuard done_guard(done);
  VLOG(3) << "Message Service receives a message from interceptor "
          << request->src_id() << " to interceptor " << request->dst_id()
          << ", with the message: " << request->message_type();
  // Hand the message to the carrier owning the destination interceptor.
  const bool dispatched =
      GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
  response->set_rst(dispatched);
}
// brpc handler for barrier control messages: bumps the local barrier
// counter on the MessageBus singleton and always acknowledges success.
void MessageServiceImpl::IncreaseBarrierCount(
    google::protobuf::RpcController* control_base,
    const InterceptorMessage* request,
    InterceptorResponse* response,
    google::protobuf::Closure* done) {
  // RAII guard: ensures brpc's `done` closure runs when we return.
  brpc::ClosureGuard done_guard(done);
  VLOG(3) << "Barrier Service receives a message from rank "
          << request->src_id() << " to rank " << request->dst_id();
  GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
  // Barrier bookkeeping is unconditional, so the reply is always success.
  response->set_rst(true);
}
}
// namespace distributed
}
// namespace paddle
#endif
Prev
1
…
5
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment