Commit d2d32668 authored by yuguo960516yuguo

2.3.0-dtk-22.04.2

parent ad08b8ce
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) {
// prepare the upstream running status
for (const auto& up : node->upstream()) {
upstream_step_.emplace(up.first, 0);
}
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void SinkInterceptor::StopCarrierIfComplete() {
bool flag = true;
for (const auto& up : upstream_step_) {
flag = flag && (up.second == max_run_times_);
}
if (flag) {
VLOG(3) << "Sink Interceptor is stopping carrier";
StopCarrier();
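// reset the upstream step records so the sink is ready for the next run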
for (const auto& up : upstream_step_) {
upstream_step_.at(up.first) = 0;
}
}
}
void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) {
int64_t micro_step = upstream_step_.at(upstream_id);
int64_t scope_idx = micro_step % max_run_times_;
InterceptorMessage msg;
msg.set_message_type(DATA_IS_USELESS);
msg.set_scope_idx(scope_idx);
Send(upstream_id, msg);
upstream_step_.at(upstream_id) = micro_step + 1;
if (micro_step == max_run_times_ - 1) {
StopCarrierIfComplete();
}
}
void SinkInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ReplyCompletedToUpStream(msg.src_id());
}
}
REGISTER_INTERCEPTOR(Sink, SinkInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
 * Sink interceptor
 * There is only one sink in the runtime graph.
 * It is in charge of:
 * 1. recording the number of micro-steps
 * 2. checking whether to notify the carrier that the current step is finished
 */
class SinkInterceptor : public Interceptor {
public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void ReplyCompletedToUpStream(int64_t up_id);
void Run(const InterceptorMessage& msg);
void StopCarrierIfComplete();
int64_t max_run_times_;
// upstream_id->cur_step
std::map<int64_t, int64_t> upstream_step_;
};
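// Usage sketch (illustrative only; mirrors sink_interceptor_test.cc in this
// commit, and assumes the "Sink" factory name registered in
// sink_interceptor.cc):
//
//   TaskNode* sink = new TaskNode(/*rank=*/0, SINK_ID, /*max_run_times=*/2);
//   sink->AddUpstreamTask(/*task_id=*/0);
//   carrier->SetInterceptor(
//       SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink));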
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) {
// prepare the downstream running status
for (const auto& down : node->downstream()) {
downstream_step_.emplace(down.first, 0);
}
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) {
int64_t micro_step = downstream_step_.at(downstream_id);
if (micro_step >= max_run_times_) {
return;
}
int64_t scope_idx = micro_step % max_run_times_;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_idx);
Send(downstream_id, ready_msg);
downstream_step_.at(downstream_id) = micro_step + 1;
}
void SourceInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == START) {
// start run in a new step, reset the previous running status
for (const auto& down : downstream_step_) {
downstream_step_.at(down.first) = 0;
SendDataReadyToDownStream(down.first);
}
} else if (msg.message_type() == DATA_IS_USELESS) {
SendDataReadyToDownStream(msg.src_id());
}
}
REGISTER_INTERCEPTOR(Source, SourceInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
 * Source interceptor
 * There is only one source in the runtime graph.
 * It is in charge of:
 * 1. receiving the `start` message from the carrier
 * 2. sending num_of_steps `data_is_ready` messages downstream
 */
class SourceInterceptor : public Interceptor {
public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream(int64_t down_id);
void Run(const InterceptorMessage& msg);
int64_t max_run_times_;
// downstream_id->cur_step
std::map<int64_t, int64_t> downstream_step_;
};
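// Usage sketch (illustrative only; mirrors source_interceptor_test.cc in this
// commit, and assumes the "Source" factory name registered in
// source_interceptor.cc):
//
//   TaskNode* source = new TaskNode(/*rank=*/0, SOURCE_ID, /*max_run_times=*/3);
//   source->AddDownstreamTask(/*task_id=*/0);
//   carrier->SetInterceptor(
//       SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
//   InterceptorMessage start;
//   start.set_message_type(START);
//   start.set_dst_id(SOURCE_ID);
//   carrier->EnqueueInterceptorMessage(start);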
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;
TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ(
thread_local_loop_,
nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already initialized in this thread."));
thread_local_loop_ = this;
}
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
void TaskLoop::Loop() {
PADDLE_ENFORCE_EQ(looping_,
false,
platform::errors::PreconditionNotMet(
"Loop() is already running; it must not be re-entered."));
AssertInLoopThread();
looping_ = true;
quit_ = false;
while (!quit_) {
auto tasks = tasks_.PopAll();
for (auto& task : tasks) {
task();
}
}
looping_ = false;
}
void TaskLoop::Quit() {
quit_ = true;
if (!IsInLoopThread()) WakeUp();
}
void TaskLoop::RunInLoop(Functor cb) {
if (IsInLoopThread()) {
cb();
} else {
QueueInLoop(cb);
}
}
void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }
void TaskLoop::WakeUp() {
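// Push an empty task so a Loop() waiting in tasks_.PopAll() wakes up and
// re-checks quit_.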
Functor task([] {});
QueueInLoop(task);
}
void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d",
thread_id_,
std::this_thread::get_id()));
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop {
public:
static TaskLoop* GetTaskLoopOfCurrentThread();
using Functor = std::function<void()>;
TaskLoop();
~TaskLoop();
void Loop();
void Quit();
void RunInLoop(Functor cb);
void QueueInLoop(Functor cb);
template <class F, class... Args>
auto Enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> task_future = task->get_future();
tasks_.Push([task]() { (*task)(); });
return task_future;
}
void WakeUp();
bool IsInLoopThread() const {
return thread_id_ == std::this_thread::get_id();
}
void AssertInLoopThread() {
if (!IsInLoopThread()) {
AbortNotInLoopThread();
}
}
private:
DISABLE_COPY_AND_ASSIGN(TaskLoop);
void AbortNotInLoopThread();
static thread_local TaskLoop* thread_local_loop_;
bool looping_;
std::atomic<bool> quit_;
std::thread::id thread_id_;
framework::BlockingQueue<Functor> tasks_;
};
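// Usage sketch (illustrative only): a TaskLoop is normally created on and
// driven by a TaskLoopThread (declared in task_loop_thread.h in this commit);
// Enqueue() wraps the callable in a std::packaged_task and returns a
// std::future for its result.
//
//   TaskLoopThread t;
//   TaskLoop* loop = t.StartLoop();
//   std::future<int> f = loop->Enqueue([] { return 42; });
//   int v = f.get();  // blocks until the loop thread has run the task
//   // TaskLoopThread's destructor quits the loop and joins the thread.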
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}
TaskLoopThread::~TaskLoopThread() {
if (loop_ != nullptr) {
loop_->Quit();
thread_.join();
}
}
TaskLoop* TaskLoopThread::StartLoop() {
PADDLE_ENFORCE_EQ(
start_,
false,
platform::errors::PreconditionNotMet("thread is already running."));
start_ = true;
thread_ = std::thread([this]() { Loop(); });
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return loop_ != nullptr; });
return loop_;
}
void TaskLoopThread::Loop() {
TaskLoop loop;
{
std::unique_lock<std::mutex> lock(mutex_);
loop_ = &loop;
cv_.notify_one();
}
loop.Loop();
std::unique_lock<std::mutex> lock(mutex_);
loop_ = nullptr;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread {
public:
TaskLoopThread();
~TaskLoopThread();
TaskLoop* StartLoop();
private:
DISABLE_COPY_AND_ASSIGN(TaskLoopThread);
void Loop();
bool start_;
TaskLoop* loop_;
std::thread thread_;
std::mutex mutex_;
std::condition_variable cv_;
};
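// Note: StartLoop() spawns the thread and then blocks on cv_ until Loop() has
// constructed the TaskLoop on the new thread, so the returned pointer is
// always non-null; the destructor quits the loop and joins the thread.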
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}
TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
: start_(false), thread_num_(thread_num) {}
TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ(
start_,
false,
platform::errors::PreconditionNotMet("thread pool has already started."));
PADDLE_ENFORCE_GT(
thread_num_,
0,
platform::errors::InvalidArgument(
"thread num must be greater than 0, but now is %d", thread_num_));
start_ = true;
for (int i = 0; i < thread_num_; ++i) {
threads_.emplace_back(new TaskLoopThread());
loops_.push_back(threads_[i]->StartLoop());
}
}
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
PADDLE_ENFORCE_EQ(
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
PADDLE_ENFORCE_GE(
tid,
0,
platform::errors::OutOfRange("tid must be >= 0, but now is %d", tid));
PADDLE_ENFORCE_LT(tid,
thread_num_,
platform::errors::OutOfRange(
"tid must < thread_num, but now tid=%d thread_num=%d",
tid,
thread_num_));
return loops_[tid];
}
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
PADDLE_ENFORCE_EQ(
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
return loops_;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread;
class TaskLoopThreadPool {
public:
TaskLoopThreadPool();
explicit TaskLoopThreadPool(int thread_num);
~TaskLoopThreadPool();
void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
void Start();
TaskLoop* GetLoop(int tid);
std::vector<TaskLoop*> GetAllLoops();
private:
DISABLE_COPY_AND_ASSIGN(TaskLoopThreadPool);
bool start_;
int thread_num_;
std::vector<std::unique_ptr<TaskLoopThread>> threads_;
std::vector<TaskLoop*> loops_;
};
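// Usage sketch (illustrative only):
//
//   TaskLoopThreadPool pool(/*thread_num=*/4);
//   pool.Start();
//   TaskLoop* loop = pool.GetLoop(0);
//   loop->QueueInLoop([] { /* runs on worker thread 0 */ });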
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace distributed {
namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be invoked serially; not thread-safe.
// NOTE: when instantiating a TaskNode with a program, the task node is not
// initialized immediately, since the provided program may be updated later
// (with high probability) by adding feed/fetch ops or by RuntimeGraph.
// So the initialization is deferred to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank
<< ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}
void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_ = program;
}
void TaskNode::Init(bool use_feed_fetch_ops) {
if (!use_feed_fetch_ops) {
VLOG(3) << "TaskNode will be inited without feed and fetch ops";
}
if (ops_.empty()) {
// Q (for fleet executor dev): do we need another reset function?
VLOG(3) << "Task node will be initialized by calling Init().";
for (const auto& op_desc : program_->Block(0).AllOps()) {
if (!use_feed_fetch_ops &&
(op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
VLOG(3) << "TaskNode will skip [" << op_desc->Input("X")[0] << "], "
<< op_desc->Type() << " -> " << op_desc->Output("Out")[0];
continue;
}
ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
}
for (const auto& op : ops_vec_) {
ops_.emplace_back(op.get());
}
}
}
TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
: rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
if (op_descs.empty()) {
return;
}
VLOG(3) << "Task node will be inited by providing list of ops.";
for (const auto& desc : op_descs) {
ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
for (const auto& op : ops_vec_) {
ops_.emplace_back(op.get());
}
}
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = upstream_.emplace(task_id, buff_size);
return ret.second;
}
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = downstream_.emplace(task_id, buff_size);
return ret.second;
}
std::string TaskNode::DebugString() const {
std::ostringstream os;
os << "role: " << role_ << ", task_id: " << task_id_ << "\n";
for (std::size_t i = 0; i < ops_.size(); ++i) {
os << ops_[i]->Type() << " ";
}
os << "\n";
return os.str();
}
void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value,
1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value;
}
void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value,
0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value;
}
void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value,
1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value;
}
void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value,
1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
class OperatorBase;
class OpDesc;
} // namespace framework
namespace distributed {
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
void Init(bool use_feed_fetch_ops = true);
int64_t rank() const { return rank_; }
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::unordered_map<int64_t, int64_t>& upstream() const {
return upstream_;
}
const std::unordered_map<int64_t, int64_t>& downstream() const {
return downstream_;
}
const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc* program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_;
}
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const {
return unused_vars_;
}
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
void SetSendDownPerSteps(int64_t value);
void SetType(const std::string& type) { type_ = type; }
void SetUnusedVars(
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars) {
unused_vars_ = unused_vars;
}
// Does the upstream need buffer sizes?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
std::string DebugString() const;
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
TaskNode() = default;
// ops_ will be removed in the future
std::vector<OperatorBase*> ops_;
// task_id-->buff_size
std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_;
framework::ProgramDesc* program_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
// one input produces output multiple times
int64_t reply_up_per_steps_{1};
// one output needs input multiple times
int64_t send_down_per_steps_{1};
std::string type_;
};
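// Usage sketch (illustrative only; follows the interceptor tests in this
// commit): build a two-node chain with explicit buffer sizes.
//
//   TaskNode* a = new TaskNode(/*rank=*/0, /*task_id=*/0, /*max_run_times=*/2);
//   TaskNode* b = new TaskNode(/*rank=*/0, /*task_id=*/1, /*max_run_times=*/2);
//   a->AddDownstreamTask(b->task_id(), /*buff_size=*/1);
//   b->AddUpstreamTask(a->task_id(), /*buff_size=*/1);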
} // namespace distributed
} // namespace paddle
set_source_files_properties(
interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
interceptor_ping_pong_test
SRCS interceptor_ping_pong_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
compute_interceptor_test
SRCS compute_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
source_interceptor_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
source_interceptor_test
SRCS source_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
sink_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
sink_interceptor_test
SRCS sink_interceptor_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
interceptor_pipeline_short_path_test.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
interceptor_pipeline_short_path_test
SRCS interceptor_pipeline_short_path_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
interceptor_pipeline_long_path_test
SRCS interceptor_pipeline_long_path_test.cc
DEPS fleet_executor ${BRPC_DEPS})
set_source_files_properties(
compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
compute_interceptor_run_op_test
SRCS compute_interceptor_run_op_test.cc
DEPS fleet_executor
${BRPC_DEPS}
op_registry
fill_constant_op
elementwise_add_op
scope
device_context)
if(WITH_DISTRIBUTE
AND WITH_PSCORE
AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(
interceptor_ping_pong_with_brpc_test.cc
PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(
interceptor_ping_pong_with_brpc_test
SRCS interceptor_ping_pong_with_brpc_test.cc
DEPS fleet_executor ${BRPC_DEPS})
endif()
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/kernel_registry.h"
USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(fill_constant);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
namespace paddle {
namespace distributed {
std::vector<framework::OperatorBase*> GetOps() {
framework::AttributeMap attrs;
attrs["dtype"] = framework::proto::VarType::FP32;
attrs["shape"] = phi::vectorize<int>({2, 3});
attrs["value"] = 1.0f;
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", {}, {{"Out", {"x"}}}, attrs);
auto op = framework::OpRegistry::CreateOp("elementwise_add",
{{"X", {"x"}}, {"Y", {"x"}}},
{{"Out", {"out"}}},
framework::AttributeMap());
// NOTE: don't delete; the task node keeps raw pointers to these ops
return {zero_op.release(), op.release()};
}
framework::Scope* GetScope() {
framework::Scope* scope = new framework::Scope();
scope->Var("x")->GetMutable<framework::LoDTensor>();
scope->Var("out")->GetMutable<framework::LoDTensor>();
return scope;
}
TEST(ComputeInterceptor, Compute) {
std::vector<framework::OperatorBase*> ops = GetOps();
framework::Scope* scope = GetScope();
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// FIXME: don't delete; otherwise the interceptors will use undefined nodes
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// source->a->b->sink
source->AddDownstreamTask(0);
node_a->AddUpstreamTask(SOURCE_ID);
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
sink->AddUpstreamTask(1);
node_b->AddDownstreamTask(SINK_ID);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
auto* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
a->SetPlace(place);
a->SetMicroBatchScope(scopes);
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
class StartInterceptor : public Interceptor {
public:
StartInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
}
};
TEST(ComputeInterceptor, Compute) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete; otherwise the interceptors will use undefined nodes
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
// a->b->c
node_a->AddDownstreamTask(1, 3);
node_b->AddUpstreamTask(0, 3);
node_b->AddDownstreamTask(2);
node_c->AddUpstreamTask(1);
Interceptor* a =
carrier->SetInterceptor(0, std::make_unique<StartInterceptor>(0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
// test run three times
a->Send(1, msg);
a->Send(1, msg);
a->Send(1, msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class PingPongInterceptor : public Interceptor {
public:
PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
}
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
if (count_ == 20) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
StopCarrier();
return;
}
InterceptorMessage resp;
Send(msg.src_id(), resp);
}
private:
int count_{0};
};
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{0, 0}, {1, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
carrier->SetInterceptor(1, std::make_unique<PingPongInterceptor>(1, nullptr));
InterceptorMessage msg;
a->Send(1, msg);
carrier->Wait();
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <time.h>
#include <unistd.h>
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
class PingPongInterceptor : public Interceptor {
public:
PingPongInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { PingPong(msg); });
}
void PingPong(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
return;
}
std::cout << GetInterceptorId() << " recv msg, count=" << count_
<< std::endl;
++count_;
if (count_ == 20 && GetInterceptorId() == 0) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
return;
}
InterceptorMessage resp;
int64_t dst = GetInterceptorId() == 0 ? 1 : 0;
Send(dst, resp);
}
private:
int count_{0};
};
REGISTER_INTERCEPTOR(PingPong, PingPongInterceptor);
TEST(InterceptorTest, PingPong) {
std::cout << "Ping pong test through brpc" << std::endl;
unsigned int seed = time(0);
// randomly generate two ports in the range 6000 to 9000
int port0 = 6000 + rand_r(&seed) % 3000;
int port1 = port0 + 1;
// use a socket to check the availability of the port
int server_fd = -1;
server_fd = socket(AF_INET, SOCK_STREAM, 0);
int opt = 1;
linger ling;
ling.l_onoff = 1;
ling.l_linger = 0;
setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY;
address.sin_port = htons(port0);
while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) {
port0++;
address.sin_port = htons(port0);
}
close(server_fd);
// use another socket to check another port
server_fd = socket(AF_INET, SOCK_STREAM, 0);
setsockopt(server_fd, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling));
setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
port1 = port0 + 1;
address.sin_port = htons(port1);
while (bind(server_fd, (struct sockaddr*)&address, sizeof(address)) == -1) {
port1++;
address.sin_port = htons(port1);
}
close(server_fd);
std::string ip0 = "127.0.0.1:" + std::to_string(port0);
std::string ip1 = "127.0.0.1:" + std::to_string(port1);
std::cout << "ip0: " << ip0 << std::endl;
std::cout << "ip1: " << ip1 << std::endl;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank = {{0, 0},
{1, 1}};
std::string carrier_id = "0";
int pid = fork();
if (pid == 0) {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(new std::string(carrier_id));
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ip0}, {1, ip1}}, ip0);
carrier->Init(0, interceptor_id_to_rank);
Interceptor* a = carrier->SetInterceptor(
0, InterceptorFactory::Create("PingPong", 0, nullptr));
msg_bus->Barrier();
InterceptorMessage msg;
a->Send(1, msg);
carrier->Wait();
} else {
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
GlobalVal<std::string>::Set(new std::string(carrier_id));
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(1, {{0, ip0}, {1, ip1}}, ip1);
carrier->Init(1, interceptor_id_to_rank);
carrier->SetInterceptor(1,
InterceptorFactory::Create("PingPong", 1, nullptr));
msg_bus->Barrier();
carrier->Wait();
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
void LinkNodes(const std::vector<TaskNode*>& nodes) {
size_t size = nodes.size();
if (size <= 1) return;
{ // i = 0
TaskNode* now = nodes[0];
TaskNode* next = nodes[1];
now->AddDownstreamTask(next->task_id());
}
{ // i = size - 1
TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1];
now->AddUpstreamTask(prev->task_id());
}
for (size_t i = 1; i < size - 1; ++i) {
TaskNode* prev = nodes[i - 1];
TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1];
now->AddUpstreamTask(prev->task_id());
now->AddDownstreamTask(next->task_id());
}
}
TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0,
{{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
// NOTE: don't delete; otherwise the interceptors will use undefined nodes
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink
LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});
// LR->b(1:3)->F->B->e(3:1)->U
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1,
InterceptorFactory::Create("Amplifier", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
int64_t GetBuffSize(
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>& buffs,
TaskNode* from,
TaskNode* to) {
if (buffs.find({from, to}) != buffs.end()) {
return buffs.at({from, to});
}
if (buffs.find({to, from}) != buffs.end()) {
return buffs.at({to, from});
}
return 2; // default buff_size
}
void LinkNodes(const std::vector<TaskNode*>& nodes,
const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>& buffs) {
size_t size = nodes.size();
if (size <= 1) return;
{ // i = 0
TaskNode* now = nodes[0];
TaskNode* next = nodes[1];
auto buff_size = GetBuffSize(buffs, now, next);
now->AddDownstreamTask(next->task_id(), buff_size);
}
{ // i = size - 1
TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1];
auto buff_size = GetBuffSize(buffs, prev, now);
now->AddUpstreamTask(prev->task_id(), buff_size);
}
for (size_t i = 1; i < size - 1; ++i) {
TaskNode* prev = nodes[i - 1];
TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1];
auto buff_size = GetBuffSize(buffs, prev, now);
now->AddUpstreamTask(prev->task_id(), buff_size);
buff_size = GetBuffSize(buffs, now, next);
now->AddDownstreamTask(next->task_id(), buff_size);
}
}
TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0,
{{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ""}}, "");
int64_t micro_steps = 6;
// NOTE: don't delete; otherwise the interceptors will use undefined nodes
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink
// LR->F->B->U
LinkNodes({source, node_a, node_b, node_c, node_d, sink},
{{{node_b, node_c}, 1}});
node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0,
InterceptorFactory::Create("Amplifier", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
class FakeInterceptor : public Interceptor {
public:
FakeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
<< std::endl;
InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS);
Send(SOURCE_ID, reply);
InterceptorMessage ready;
ready.set_message_type(DATA_IS_READY);
Send(SINK_ID, ready);
} else if (msg.message_type() == DATA_IS_USELESS) {
std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
<< std::endl;
}
}
private:
int64_t step_;
};
TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete; otherwise the interceptors will use undefined nodes
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id, max_run_times, max_slot_nums
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
node_a->AddDownstreamTask(SINK_ID, 1);
sink->AddUpstreamTask(0, 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle