Commit d2d32668 authored Apr 26, 2023 by yuguo960516yuguo
2.3.0-dtk-22.04.2
parent ad08b8ce
Pipeline #226 failed with stages in 0 seconds
Showing 20 changed files with 1749 additions and 0 deletions (+1749 -0)
paddle/fluid/distributed/fleet_executor/sink_interceptor.cc   +66 -0
paddle/fluid/distributed/fleet_executor/sink_interceptor.h   +41 -0
paddle/fluid/distributed/fleet_executor/source_interceptor.cc   +58 -0
paddle/fluid/distributed/fleet_executor/source_interceptor.h   +41 -0
paddle/fluid/distributed/fleet_executor/task_loop.cc   +85 -0
paddle/fluid/distributed/fleet_executor/task_loop.h   +84 -0
paddle/fluid/distributed/fleet_executor/task_loop_thread.cc   +60 -0
paddle/fluid/distributed/fleet_executor/task_loop_thread.h   +48 -0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc   +77 -0
paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h   +51 -0
paddle/fluid/distributed/fleet_executor/task_node.cc   +185 -0
paddle/fluid/distributed/fleet_executor/task_node.h   +134 -0
paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt   +72 -0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc   +115 -0
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc   +85 -0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc   +81 -0
paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc   +140 -0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc   +112 -0
paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc   +123 -0
paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc   +91 -0
Too many changes to show. To preserve performance only 268 of 268+ files are displayed.

paddle/fluid/distributed/fleet_executor/sink_interceptor.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // prepare the upstream running status
  for (const auto& up : node->upstream()) {
    upstream_step_.emplace(up.first, 0);
  }
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}

void SinkInterceptor::StopCarrierIfComplete() {
  bool flag = true;
  for (const auto& up : upstream_step_) {
    flag = flag && (up.second == max_run_times_);
  }
  if (flag) {
    VLOG(3) << "Sink Interceptor is stopping carrier";
    StopCarrier();
    for (const auto& up : upstream_step_) {
      upstream_step_.at(up.first) = 0;
    }
  }
}

void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) {
  int64_t micro_step = upstream_step_.at(upstream_id);
  int64_t scope_idx = micro_step % max_run_times_;
  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_USELESS);
  msg.set_scope_idx(scope_idx);
  Send(upstream_id, msg);
  upstream_step_.at(upstream_id) = micro_step + 1;
  if (micro_step == max_run_times_ - 1) {
    StopCarrierIfComplete();
  }
}

void SinkInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY) {
    ReplyCompletedToUpStream(msg.src_id());
  }
}

REGISTER_INTERCEPTOR(Sink, SinkInterceptor);

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/sink_interceptor.h (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
namespace distributed {

/*
 * Sink interceptor
 * There is only one sink in the runtime graph
 * Take charge of:
 * 1. record the num of micro-step
 * 2. check whether to notify carrier the current step is finished
 */
class SinkInterceptor : public Interceptor {
 public:
  SinkInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void ReplyCompletedToUpStream(int64_t up_id);
  void Run(const InterceptorMessage& msg);
  void StopCarrierIfComplete();
  int64_t max_run_times_;
  // upstream_id->cur_step
  std::map<int64_t, int64_t> upstream_step_;
};

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/source_interceptor.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node),
      max_run_times_(node->max_run_times()) {
  // prepare the downstream running status
  for (const auto& down : node->downstream()) {
    downstream_step_.emplace(down.first, 0);
  }
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}

void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) {
  int64_t micro_step = downstream_step_.at(downstream_id);
  if (micro_step >= max_run_times_) {
    return;
  }
  int64_t scope_idx = micro_step % max_run_times_;
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_IS_READY);
  ready_msg.set_scope_idx(scope_idx);
  Send(downstream_id, ready_msg);
  downstream_step_.at(downstream_id) = micro_step + 1;
}

void SourceInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == START) {
    // start run in a new step, reset the previous running status
    for (const auto& down : downstream_step_) {
      downstream_step_.at(down.first) = 0;
      SendDataReadyToDownStream(down.first);
    }
  } else if (msg.message_type() == DATA_IS_USELESS) {
    SendDataReadyToDownStream(msg.src_id());
  }
}

REGISTER_INTERCEPTOR(Source, SourceInterceptor);

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/source_interceptor.h (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
namespace distributed {

/*
 * Source interceptor
 * There is only one source in the runtime graph
 * Take charge of:
 * 1. receive `start` message from carrier
 * 2. send num_of_steps `data_is_ready` message to downstream
 */
class SourceInterceptor : public Interceptor {
 public:
  SourceInterceptor(int64_t interceptor_id, TaskNode* node);

 private:
  void SendDataReadyToDownStream(int64_t down_id);
  void Run(const InterceptorMessage& msg);
  int64_t max_run_times_;
  // downstream_id->cur_step
  std::map<int64_t, int64_t> downstream_step_;
};

}  // namespace distributed
}  // namespace paddle
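
The two headers above define the end-points of the runtime graph: the source turns one START from the carrier into up to max_run_times_ DATA_IS_READY messages (one per micro-step), and the sink answers every DATA_IS_READY with DATA_IS_USELESS and stops the carrier once all upstreams have completed max_run_times_ steps. As a quick orientation, the snippet below is a condensed, illustrative restatement of the wiring done in sink_interceptor_test.cc later in this commit; it is not itself part of the diff:

// Condensed from sink_interceptor_test.cc below; illustrative only.
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);          // the task in the middle (a FakeInterceptor in the test)
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0);
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
node_a->AddDownstreamTask(SINK_ID, 1);
sink->AddUpstreamTask(0, 1);

carrier->SetInterceptor(SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
carrier->SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

InterceptorMessage msg;
msg.set_message_type(START);   // the carrier hands START to the source ...
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();               // ... and the sink stops the carrier after 3 micro-steps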

paddle/fluid/distributed/fleet_executor/task_loop.cc (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace distributed {

thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;

TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }

TaskLoop::TaskLoop()
    : looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
  PADDLE_ENFORCE_EQ(
      thread_local_loop_,
      nullptr,
      platform::errors::AlreadyExists("Another TaskLoop is already init."));
  thread_local_loop_ = this;
}

TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }

void TaskLoop::Loop() {
  PADDLE_ENFORCE_EQ(looping_,
                    false,
                    platform::errors::PreconditionNotMet(
                        "Loop can only execute in one loop thread"));
  AssertInLoopThread();
  looping_ = true;
  quit_ = false;
  while (!quit_) {
    auto tasks = tasks_.PopAll();
    for (auto& task : tasks) {
      task();
    }
  }
  looping_ = false;
}

void TaskLoop::Quit() {
  quit_ = true;
  if (!IsInLoopThread()) WakeUp();
}

void TaskLoop::RunInLoop(Functor cb) {
  if (IsInLoopThread()) {
    cb();
  } else {
    QueueInLoop(cb);
  }
}

void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }

void TaskLoop::WakeUp() {
  Functor task([] {});
  QueueInLoop(task);
}

void TaskLoop::AbortNotInLoopThread() {
  PADDLE_THROW(platform::errors::PreconditionNotMet(
      "This TaskLoop was created in thread %d, but current thread is %d",
      thread_id_,
      std::this_thread::get_id()));
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/task_loop.h (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace distributed {

class TaskLoop {
 public:
  static TaskLoop* GetTaskLoopOfCurrentThread();

  using Functor = std::function<void()>;

  TaskLoop();
  ~TaskLoop();

  void Loop();
  void Quit();

  void RunInLoop(Functor cb);
  void QueueInLoop(Functor cb);

  template <class F, class... Args>
  auto Enqueue(F&& f, Args&&... args)
      -> std::future<typename std::result_of<F(Args...)>::type> {
    using return_type = typename std::result_of<F(Args...)>::type;
    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> task_future = task->get_future();
    tasks_.Push([task]() { (*task)(); });
    return task_future;
  }

  void WakeUp();

  bool IsInLoopThread() const {
    return thread_id_ == std::this_thread::get_id();
  }

  void AssertInLoopThread() {
    if (!IsInLoopThread()) {
      AbortNotInLoopThread();
    }
  }

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoop);

  void AbortNotInLoopThread();

  static thread_local TaskLoop* thread_local_loop_;

  bool looping_;
  std::atomic<bool> quit_;
  std::thread::id thread_id_;
  framework::BlockingQueue<Functor> tasks_;
};

}  // namespace distributed
}  // namespace paddle
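
TaskLoop is a single-threaded run queue: Loop() drains tasks_ until Quit() is observed, QueueInLoop() may be called from any thread, and Enqueue() wraps the callable in a std::packaged_task so the caller gets a std::future for the result. Because the constructor captures the current thread id and Loop() asserts it runs on that same thread, the loop object has to live on the thread that drives it, which is exactly what TaskLoopThread below arranges. A minimal sketch of that contract (not part of the commit; it assumes only the interface declared above):

// Hypothetical usage sketch of the TaskLoop declared above.
std::future<int> answer;
std::thread worker([&answer] {
  paddle::distributed::TaskLoop loop;          // binds to this worker thread
  answer = loop.Enqueue([](int a, int b) { return a + b; }, 40, 2);
  loop.QueueInLoop([&loop] { loop.Quit(); });  // queued last: quit after the work drains
  loop.Loop();                                 // runs both functors, then returns
});
worker.join();
// answer.get() == 42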

paddle/fluid/distributed/fleet_executor/task_loop_thread.cc (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace distributed {

TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}

TaskLoopThread::~TaskLoopThread() {
  if (loop_ != nullptr) {
    loop_->Quit();
    thread_.join();
  }
}

TaskLoop* TaskLoopThread::StartLoop() {
  PADDLE_ENFORCE_EQ(
      start_,
      false,
      platform::errors::PreconditionNotMet("thread is already running."));
  start_ = true;
  thread_ = std::thread([this]() { Loop(); });

  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [=] { return loop_ != nullptr; });
  return loop_;
}

void TaskLoopThread::Loop() {
  TaskLoop loop;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    loop_ = &loop;
    cv_.notify_one();
  }
  loop.Loop();

  std::unique_lock<std::mutex> lock(mutex_);
  loop_ = nullptr;
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/task_loop_thread.h (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace distributed {

class TaskLoop;

class TaskLoopThread {
 public:
  TaskLoopThread();
  ~TaskLoopThread();

  TaskLoop* StartLoop();

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoopThread);

  void Loop();

  bool start_;
  TaskLoop* loop_;
  std::thread thread_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

}  // namespace distributed
}  // namespace paddle
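
TaskLoopThread packages the "loop lives on its own thread" rule: StartLoop() spawns the thread, waits on the condition variable until the TaskLoop exists, and hands back a pointer that the owning thread may post to; the destructor quits the loop and joins. A hedged usage sketch (again, not from the commit):

// Hypothetical sketch of TaskLoopThread; illustrative only.
paddle::distributed::TaskLoopThread loop_thread;
paddle::distributed::TaskLoop* loop = loop_thread.StartLoop();  // blocks until the loop is ready

loop->QueueInLoop([] { /* runs on the loop thread */ });
auto fut = loop->Enqueue([] { return 7; });
// fut.get() == 7 once the loop thread has executed the task.
// ~TaskLoopThread() calls loop->Quit() and joins the thread.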

paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.cc (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

namespace paddle {
namespace distributed {

TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}

TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
    : start_(false), thread_num_(thread_num) {}

TaskLoopThreadPool::~TaskLoopThreadPool() = default;

void TaskLoopThreadPool::Start() {
  PADDLE_ENFORCE_EQ(
      start_,
      false,
      platform::errors::PreconditionNotMet("thread pool is already start."));
  PADDLE_ENFORCE_GT(
      thread_num_,
      0,
      platform::errors::InvalidArgument(
          "thread num must greater than 0, but now is %d", thread_num_));

  start_ = true;
  for (int i = 0; i < thread_num_; ++i) {
    threads_.emplace_back(new TaskLoopThread());
    loops_.push_back(threads_[i]->StartLoop());
  }
}

TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
  PADDLE_ENFORCE_EQ(
      start_,
      true,
      platform::errors::PreconditionNotMet("thread pool must start first."));
  PADDLE_ENFORCE_GE(
      tid, 0, platform::errors::OutOfRange("tid must >= 0, but now is %d", tid));
  PADDLE_ENFORCE_LT(tid,
                    thread_num_,
                    platform::errors::OutOfRange(
                        "tid must < thread_num, but now tid=%d thread_num=%d",
                        tid,
                        thread_num_));
  return loops_[tid];
}

std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
  PADDLE_ENFORCE_EQ(
      start_,
      true,
      platform::errors::PreconditionNotMet("thread pool must start first."));
  return loops_;
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace distributed {

class TaskLoop;
class TaskLoopThread;

class TaskLoopThreadPool {
 public:
  TaskLoopThreadPool();
  explicit TaskLoopThreadPool(int thread_num);
  ~TaskLoopThreadPool();

  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }

  void Start();

  TaskLoop* GetLoop(int tid);
  std::vector<TaskLoop*> GetAllLoops();

 private:
  DISABLE_COPY_AND_ASSIGN(TaskLoopThreadPool);

  bool start_;
  int thread_num_;
  std::vector<std::unique_ptr<TaskLoopThread>> threads_;
  std::vector<TaskLoop*> loops_;
};

}  // namespace distributed
}  // namespace paddle
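
The pool simply owns thread_num TaskLoopThreads and exposes their loops by index, so callers can pin work to a particular worker or broadcast to all of them. An illustrative sketch (not part of the commit):

// Hypothetical sketch of TaskLoopThreadPool; illustrative only.
paddle::distributed::TaskLoopThreadPool pool(/*thread_num=*/4);
pool.Start();  // must be called exactly once before GetLoop/GetAllLoops

pool.GetLoop(0)->QueueInLoop([] { /* work pinned to worker 0 */ });
for (paddle::distributed::TaskLoop* loop : pool.GetAllLoops()) {
  loop->QueueInLoop([] { /* work broadcast to every worker */ });
}
// Destroying the pool destroys the TaskLoopThreads, which quit and join their loops.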

paddle/fluid/distributed/fleet_executor/task_node.cc (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace distributed {
namespace {
using OperatorBase = TaskNode::OperatorBase;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
                   int64_t rank,
                   int64_t max_run_times,
                   int64_t max_slot_nums)
    : program_(program),
      rank_(rank),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
  // Should be serially invoked, not thread-safe
  // NOTE: when instantiate TaskNode with program, won't init task node
  // immediately, since the provided program may be updated later (with
  // high probability) by adding_feed_fetch_ops or by RuntimeGraph.
  // So, delay the init part to the Init() function.
  static int64_t task_node_cnt = 0;
  task_id_ = task_node_cnt++;
}

TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
    : program_(program), rank_(rank), task_id_(rank) {
  max_run_times_ = 1;
  max_slot_nums_ = 1;
  LOG(INFO) << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
            << rank
            << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}

void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
  program_ = program;
}

void TaskNode::Init(bool use_feed_fetch_ops) {
  if (!use_feed_fetch_ops) {
    VLOG(3) << "TaskNode will be inited without feed and fetch ops";
  }
  if (ops_.empty()) {
    // Q (for fleet executor dev): should we need another reset funct?
    VLOG(3) << "Task node will be inited by calling Init().";
    for (const auto& op_desc : program_->Block(0).AllOps()) {
      if (!use_feed_fetch_ops &&
          (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
        VLOG(3) << "TaskNode will skip [" << op_desc->Input("X")[0] << "], "
                << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
        continue;
      }
      ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
    }
    for (const auto& op : ops_vec_) {
      ops_.emplace_back(op.get());
    }
  }
}

TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
    : rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}

TaskNode::TaskNode(int32_t role,
                   const std::vector<framework::OpDesc*>& op_descs,
                   int64_t rank,
                   int64_t task_id,
                   int64_t max_run_times,
                   int64_t max_slot_nums)
    : role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {
  if (op_descs.empty()) {
    return;
  }
  VLOG(3) << "Task node will be inited by providing list of ops.";
  for (const auto& desc : op_descs) {
    ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
  }
  for (const auto& op : ops_vec_) {
    ops_.emplace_back(op.get());
  }
}

TaskNode::TaskNode(int32_t role,
                   const std::vector<framework::OperatorBase*>& ops,
                   int64_t rank,
                   int64_t task_id,
                   int64_t max_run_times,
                   int64_t max_slot_nums)
    : ops_(ops),
      role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {}

TaskNode::TaskNode(int32_t role,
                   int64_t rank,
                   int64_t task_id,
                   int64_t max_run_times,
                   int64_t max_slot_nums)
    : role_(role),
      rank_(rank),
      task_id_(task_id),
      max_run_times_(max_run_times),
      max_slot_nums_(max_slot_nums) {}

bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
  const auto& ret = upstream_.emplace(task_id, buff_size);
  return ret.second;
}

bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
  const auto& ret = downstream_.emplace(task_id, buff_size);
  return ret.second;
}

std::string TaskNode::DebugString() const {
  std::ostringstream os;
  os << "role: " << role_ << ", task_id: " << task_id_ << "\n";
  for (std::size_t i = 0; i < ops_.size(); ++i) {
    os << ops_[i]->Type() << " ";
  }
  os << "\n";
  return os.str();
}

void TaskNode::SetRunPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(value,
                    1,
                    platform::errors::InvalidArgument(
                        "run_per_steps must >= 1, but received %ld", value));
  run_per_steps_ = value;
}

void TaskNode::SetRunAtOffset(int64_t value) {
  PADDLE_ENFORCE_GE(value,
                    0,
                    platform::errors::InvalidArgument(
                        "run_at_offset must >= 0, but received %ld", value));
  run_at_offset_ = value;
}

void TaskNode::SetReplyUpPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(
      value,
      1,
      platform::errors::InvalidArgument(
          "reply_up_per_steps must >= 1, but received %ld", value));
  reply_up_per_steps_ = value;
}

void TaskNode::SetSendDownPerSteps(int64_t value) {
  PADDLE_ENFORCE_GE(
      value,
      1,
      platform::errors::InvalidArgument(
          "send_down_per_steps must >= 1, but received %ld", value));
  send_down_per_steps_ = value;
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/task_node.h (new file, mode 100644)

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace framework {
class OperatorBase;
class OpDesc;
}  // namespace framework
namespace distributed {

class TaskNode final {
 public:
  using OperatorBase = paddle::framework::OperatorBase;
  TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
  TaskNode(int32_t role,
           int64_t rank,
           int64_t task_id,
           int64_t max_run_times,
           int64_t max_slot_nums);
  TaskNode(int32_t role,
           const std::vector<framework::OpDesc*>& op_descs,
           int64_t rank,
           int64_t task_id,
           int64_t max_run_times,
           int64_t max_slot_nums);
  TaskNode(int32_t role,
           const std::vector<framework::OperatorBase*>& ops,
           int64_t rank,
           int64_t task_id,
           int64_t max_run_times,
           int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program,
           int64_t rank,
           int64_t max_run_times,
           int64_t max_slot_nums);
  TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
  ~TaskNode() = default;

  void SetProgram(paddle::framework::ProgramDesc* program);
  void Init(bool use_feed_fetch_ops = true);
  int64_t rank() const { return rank_; }
  int64_t task_id() const { return task_id_; }
  int32_t role() const { return role_; }
  int64_t max_run_times() const { return max_run_times_; }
  int64_t max_slot_nums() const { return max_slot_nums_; }
  int64_t run_per_steps() const { return run_per_steps_; }
  int64_t run_at_offset() const { return run_at_offset_; }
  int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
  int64_t send_down_per_steps() const { return send_down_per_steps_; }
  const std::unordered_map<int64_t, int64_t>& upstream() const {
    return upstream_;
  }
  const std::unordered_map<int64_t, int64_t>& downstream() const {
    return downstream_;
  }
  const std::string& type() const { return type_; }
  const paddle::framework::ProgramDesc* program() const { return program_; }
  const std::vector<OperatorBase*>& ops() const { return ops_; }
  const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
    return ops_vec_;
  }
  const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
  unused_vars() const {
    return unused_vars_;
  }

  void SetRunPerSteps(int64_t value);
  void SetRunAtOffset(int64_t value);
  void SetReplyUpPerSteps(int64_t value);
  void SetSendDownPerSteps(int64_t value);
  void SetType(const std::string& type) { type_ = type; }
  void SetUnusedVars(
      const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
          unused_vars) {
    unused_vars_ = unused_vars;
  }

  // upstream need buffs?
  bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
  bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
  std::string DebugString() const;

 private:
  DISABLE_COPY_AND_ASSIGN(TaskNode);
  TaskNode() = default;
  // ops_ will be removed in the future
  std::vector<OperatorBase*> ops_;
  // task_id-->buff_size
  std::unordered_map<int64_t, int64_t> upstream_;
  std::unordered_map<int64_t, int64_t> downstream_;
  framework::ProgramDesc* program_;
  std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
  std::unordered_map<const OperatorBase*, std::vector<std::string>> unused_vars_;

  int32_t role_;
  int64_t rank_;
  int64_t task_id_;
  int64_t max_run_times_;
  int64_t max_slot_nums_;

  int64_t run_per_steps_{1};
  int64_t run_at_offset_{0};
  // one input produces multi times output
  int64_t reply_up_per_steps_{1};
  // one output need multi times input
  int64_t send_down_per_steps_{1};

  std::string type_;
};

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt (new file, mode 100644)

set_source_files_properties(
  interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS
                                           ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  interceptor_ping_pong_test
  SRCS interceptor_ping_pong_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
                                         ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  compute_interceptor_test
  SRCS compute_interceptor_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  source_interceptor_test.cc PROPERTIES COMPILE_FLAGS
                                        ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  source_interceptor_test
  SRCS source_interceptor_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS
                                      ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  sink_interceptor_test
  SRCS sink_interceptor_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  interceptor_pipeline_short_path_test.cc
  PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  interceptor_pipeline_short_path_test
  SRCS interceptor_pipeline_short_path_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS
                                                    ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  interceptor_pipeline_long_path_test
  SRCS interceptor_pipeline_long_path_test.cc
  DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(
  compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
                                                ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
  compute_interceptor_run_op_test
  SRCS compute_interceptor_run_op_test.cc
  DEPS fleet_executor
       ${BRPC_DEPS}
       op_registry
       fill_constant_op
       elementwise_add_op
       scope
       device_context)

if(WITH_DISTRIBUTE
   AND WITH_PSCORE
   AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
  set_source_files_properties(
    interceptor_ping_pong_with_brpc_test.cc
    PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
  cc_test(
    interceptor_ping_pong_with_brpc_test
    SRCS interceptor_ping_pong_with_brpc_test.cc
    DEPS fleet_executor ${BRPC_DEPS})
endif()

paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"

USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(fill_constant);

PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);

namespace paddle {
namespace distributed {

std::vector<framework::OperatorBase*> GetOps() {
  framework::AttributeMap attrs;
  attrs["dtype"] = framework::proto::VarType::FP32;
  attrs["shape"] = phi::vectorize<int>({2, 3});
  attrs["value"] = 1.0f;

  auto zero_op = framework::OpRegistry::CreateOp(
      "fill_constant", {}, {{"Out", {"x"}}}, attrs);

  auto op = framework::OpRegistry::CreateOp("elementwise_add",
                                            {{"X", {"x"}}, {"Y", {"x"}}},
                                            {{"Out", {"out"}}},
                                            framework::AttributeMap());

  // NOTE: don't delete
  return {zero_op.release(), op.release()};
}

framework::Scope* GetScope() {
  framework::Scope* scope = new framework::Scope();
  scope->Var("x")->GetMutable<framework::LoDTensor>();
  scope->Var("out")->GetMutable<framework::LoDTensor>();
  return scope;
}

TEST(ComputeInterceptor, Compute) {
  std::vector<framework::OperatorBase*> ops = GetOps();
  framework::Scope* scope = GetScope();
  std::vector<framework::Scope*> scopes = {scope, scope};
  platform::Place place = platform::CPUPlace();

  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});

  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // FIXME: don't delete, otherwise interceptor will use undefined node
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, 2);  // rank, task_id, max_run_times
  TaskNode* node_a =
      new TaskNode(0, ops, 0, 0, 2, 0);  // role, ops, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
  TaskNode* sink = new TaskNode(0, SINK_ID, 2);

  // source->a->b->sink
  source->AddDownstreamTask(0);
  node_a->AddUpstreamTask(SOURCE_ID);
  node_a->AddDownstreamTask(1);
  node_b->AddUpstreamTask(0);
  sink->AddUpstreamTask(1);
  node_b->AddDownstreamTask(SINK_ID);

  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  auto* a = carrier->SetInterceptor(
      0, InterceptorFactory::Create("Compute", 0, node_a));
  carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier->SetInterceptor(
      SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

  a->SetPlace(place);
  a->SetMicroBatchScope(scopes);

  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);

  carrier->Wait();
  carrier->Release();
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

class StartInterceptor : public Interceptor {
 public:
  StartInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }

  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() == STOP) {
      stop_ = true;
      InterceptorMessage stop;
      stop.set_message_type(STOP);
      Send(1, stop);  // stop 1, compute
      return;
    }
    std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
              << std::endl;
  }
};

TEST(ComputeInterceptor, Compute) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});

  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);

  // a->b->c
  node_a->AddDownstreamTask(1, 3);
  node_b->AddUpstreamTask(0, 3);
  node_b->AddDownstreamTask(2);
  node_c->AddUpstreamTask(1);

  Interceptor* a = carrier->SetInterceptor(
      0, std::make_unique<StartInterceptor>(0, node_a));
  carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));

  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_READY);
  // test run three times
  a->Send(1, msg);
  a->Send(1, msg);
  a->Send(1, msg);

  carrier->Wait();
  carrier->Release();
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

namespace paddle {
namespace distributed {

class PingPongInterceptor : public Interceptor {
 public:
  PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
  }

  void PingPong(const InterceptorMessage& msg) {
    if (msg.message_type() == STOP) {
      stop_ = true;
      return;
    }
    std::cout << GetInterceptorId() << " recv msg, count=" << count_
              << std::endl;
    ++count_;

    if (count_ == 20) {
      InterceptorMessage stop;
      stop.set_message_type(STOP);
      Send(0, stop);
      Send(1, stop);
      StopCarrier();
      return;
    }

    InterceptorMessage resp;
    Send(msg.src_id(), resp);
  }

 private:
  int count_{0};
};

REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);

TEST(InterceptorTest, PingPong) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0, {{0, 0}, {1, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  Interceptor* a = carrier->SetInterceptor(
      0, InterceptorFactory::Create("PingPong", 0, nullptr));
  carrier->SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));

  InterceptorMessage msg;
  a->Send(1, msg);

  carrier->Wait();
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sys/socket.h>
#include <time.h>
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"

namespace paddle {
namespace distributed {

class PingPongInterceptor : public Interceptor {
 public:
  PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
  }

  void PingPong(const InterceptorMessage& msg) {
    if (msg.message_type() == STOP) {
      stop_ = true;
      StopCarrier();
      return;
    }
    std::cout << GetInterceptorId() << " recv msg, count=" << count_
              << std::endl;
    ++count_;

    if (count_ == 20 && GetInterceptorId() == 0) {
      InterceptorMessage stop;
      stop.set_message_type(STOP);
      Send(0, stop);
      Send(1, stop);
      return;
    }

    InterceptorMessage resp;
    int64_t dst = GetInterceptorId() == 0 ? 1 : 0;
    Send(dst, resp);
  }

 private:
  int count_{0};
};

REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);

TEST(InterceptorTest, PingPong) {
  std::cout << "Ping pong test through brpc" << std::endl;
  unsigned int seed = time(0);
  // random generated two ports in from 6000 to 9000
  int port0 = 6000 + rand_r(&seed) % 3000;
  int port1 = port0 + 1;
  // using socket to check the availability of the port
  int server_fd = -1;
  server_fd = socket(AF_INET, SOCK_STREAM, 0);
  int opt = 1;
  linger ling;
  ling.l_onoff = 1;
  ling.l_linger = 0;
  setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
  setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
  struct sockaddr_in address;
  address.sin_family = AF_INET;
  address.sin_addr.s_addr = INADDR_ANY;
  address.sin_port = htons(port0);
  while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) {
    port0++;
    address.sin_port = htons(port0);
  }
  close(server_fd);
  // use another socket to check another port
  server_fd = socket(AF_INET, SOCK_STREAM, 0);
  setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
  setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
  port1 = port0 + 1;
  address.sin_port = htons(port1);
  while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) {
    port1++;
    address.sin_port = htons(port1);
  }
  close(server_fd);
  std::string ip0 = "127.0.0.1:" + std::to_string(port0);
  std::string ip1 = "127.0.0.1:" + std::to_string(port1);
  std::cout << "ip0: " << ip0 << std::endl;
  std::cout << "ip1: " << ip1 << std::endl;

  std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
                                                                 {1, 1}};
  std::string carrier_id = "0";

  int pid = fork();
  if (pid == 0) {
    Carrier* carrier =
        GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
    GlobalVal<std::string>::Set(new std::string(carrier_id));
    MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
    msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
    carrier->Init(0, interceptor_id_to_rank);
    Interceptor* a = carrier->SetInterceptor(
        0, InterceptorFactory::Create("PingPong", 0, nullptr));
    msg_bus->Barrier();
    InterceptorMessage msg;
    a->Send(1, msg);
    carrier->Wait();
  } else {
    Carrier* carrier =
        GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
    GlobalVal<std::string>::Set(new std::string(carrier_id));
    MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
    msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
    carrier->Init(1, interceptor_id_to_rank);
    carrier->SetInterceptor(
        1, InterceptorFactory::Create("PingPong", 1, nullptr));
    msg_bus->Barrier();
    carrier->Wait();
  }
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

void LinkNodes(const std::vector<TaskNode*>& nodes) {
  size_t size = nodes.size();
  if (size <= 1) return;

  {  // i = 0
    TaskNode* now = nodes[0];
    TaskNode* next = nodes[1];
    now->AddDownstreamTask(next->task_id());
  }
  {  // i = size - 1
    TaskNode* prev = nodes[size - 2];
    TaskNode* now = nodes[size - 1];
    now->AddUpstreamTask(prev->task_id());
  }

  for (size_t i = 1; i < size - 1; ++i) {
    TaskNode* prev = nodes[i - 1];
    TaskNode* now = nodes[i];
    TaskNode* next = nodes[i + 1];

    now->AddUpstreamTask(prev->task_id());
    now->AddDownstreamTask(next->task_id());
  }
}

TEST(AmplifierInterceptor, Amplifier) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0,
                {{SOURCE_ID, 0},
                 {0, 0},
                 {1, 0},
                 {2, 0},
                 {3, 0},
                 {4, 0},
                 {5, 0},
                 {SINK_ID, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

  int64_t micro_steps = 3;

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
  TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0);  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
  TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
  TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
  TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
  TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);

  // source->a->b->c->d->e->f->sink
  LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});

  // LR->b(1:3)->F->B->e(3:1)->U
  node_b->SetReplyUpPerSteps(micro_steps);
  node_e->SetSendDownPerSteps(micro_steps);

  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
  carrier->SetInterceptor(1,
                          InterceptorFactory::Create("Amplifier", 1, node_b));
  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
  carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
  carrier->SetInterceptor(4,
                          InterceptorFactory::Create("Amplifier", 4, node_e));
  carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
  carrier->SetInterceptor(
      SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);

  carrier->Wait();
  carrier->Release();
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

int64_t GetBuffSize(
    const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs,
    TaskNode* from,
    TaskNode* to) {
  if (buffs.find({from, to}) != buffs.end()) {
    return buffs.at({from, to});
  }
  if (buffs.find({to, from}) != buffs.end()) {
    return buffs.at({to, from});
  }
  return 2;  // set default 2
}

void LinkNodes(const std::vector<TaskNode*>& nodes,
               const std::map<std::pair<TaskNode*, TaskNode*>, int64_t> buffs) {
  size_t size = nodes.size();
  if (size <= 1) return;

  {  // i = 0
    TaskNode* now = nodes[0];
    TaskNode* next = nodes[1];
    auto buff_size = GetBuffSize(buffs, now, next);
    now->AddDownstreamTask(next->task_id(), buff_size);
  }
  {  // i = size - 1
    TaskNode* prev = nodes[size - 2];
    TaskNode* now = nodes[size - 1];
    auto buff_size = GetBuffSize(buffs, prev, now);
    now->AddUpstreamTask(prev->task_id(), buff_size);
  }

  for (size_t i = 1; i < size - 1; ++i) {
    TaskNode* prev = nodes[i - 1];
    TaskNode* now = nodes[i];
    TaskNode* next = nodes[i + 1];

    auto buff_size = GetBuffSize(buffs, prev, now);
    now->AddUpstreamTask(prev->task_id(), buff_size);

    buff_size = GetBuffSize(buffs, now, next);
    now->AddDownstreamTask(next->task_id(), buff_size);
  }
}

TEST(AmplifierInterceptor, Amplifier) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(
      0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, ""}}, "");

  int64_t micro_steps = 6;

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
  TaskNode* node_a =
      new TaskNode(0, 0, 0, micro_steps, 0);  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
  TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);

  // source->a->b->c->d->sink
  // LR->F->B->U
  LinkNodes({source, node_a, node_b, node_c, node_d, sink},
            {{{node_b, node_c}, 1}});

  node_a->SetRunPerSteps(micro_steps);
  node_d->SetRunPerSteps(micro_steps);
  node_d->SetRunAtOffset(micro_steps - 1);

  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0,
                          InterceptorFactory::Create("Amplifier", 0, node_a));
  carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
  carrier->SetInterceptor(3,
                          InterceptorFactory::Create("Amplifier", 3, node_d));
  carrier->SetInterceptor(
      SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);

  carrier->Wait();
  carrier->Release();
}

}  // namespace distributed
}  // namespace paddle

paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

class FakeInterceptor : public Interceptor {
 public:
  FakeInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }

  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() == DATA_IS_READY) {
      std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
                << std::endl;
      InterceptorMessage reply;
      reply.set_message_type(DATA_IS_USELESS);
      Send(SOURCE_ID, reply);
      InterceptorMessage ready;
      ready.set_message_type(DATA_IS_READY);
      Send(SINK_ID, ready);
    } else if (msg.message_type() == DATA_IS_USELESS) {
      std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
                << std::endl;
    }
  }

 private:
  int64_t step_;
};

TEST(SourceInterceptor, Source) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});

  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0);  // role, rank, task_id

  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(SOURCE_ID, 1);
  node_a->AddDownstreamTask(SINK_ID, 1);
  sink->AddUpstreamTask(0, 1);

  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
  carrier->SetInterceptor(
      SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));

  // start
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  carrier->Wait();
  carrier->Release();
}

}  // namespace distributed
}  // namespace paddle