// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) {}
Interceptor::~Interceptor() {
// FIXME(wangxi): throw in stop function
// std::lock_guard<std::mutex> lock(mutex_);
// PADDLE_ENFORCE_EQ(messages_.empty(), true,
// platform::errors::PreconditionNotMet(
// "Interceptor must destruct with messages empty"));
}
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(handle_,
platform::errors::PreconditionNotMet(
"Message handle is not registered."));
handle_(msg);
}
void Interceptor::LoopOnce() {
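// Drain the shared mailbox into a local deque under the lock, so the
// message handlers below run without holding mutex_.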
std::deque<InterceptorMessage> tmp_messages;
{
std::lock_guard<std::mutex> lock(mutex_);
messages_.swap(tmp_messages);
}
PADDLE_ENFORCE_EQ(tmp_messages.empty(),
false,
platform::errors::PreconditionNotMet(
"tmp_messages must not be empty in the task loop"));
for (auto& msg : tmp_messages) {
const MessageType message_type = msg.message_type();
VLOG(3) << "Interceptor " << interceptor_id_ << " has received a message"
<< " from interceptor " << msg.src_id()
<< " with message: " << message_type << ".";
Handle(msg);
}
}
void Interceptor::StopCarrier() {
PADDLE_ENFORCE_NOT_NULL(
carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
carrier_->WakeUp();
}
void Interceptor::EnqueueRemoteInterceptorMessage(
const InterceptorMessage& message) {
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
VLOG(3) << "Enqueue message: " << message.message_type() << " into "
<< interceptor_id_ << "'s remote mailbox.";
bool empty = false;
{
std::lock_guard<std::mutex> lock(mutex_);
empty = messages_.empty();
messages_.emplace_back(message);
}
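// Schedule LoopOnce() only when the mailbox transitions from empty to
// non-empty: a LoopOnce() that is already queued will also drain any
// messages appended in the meantime, so re-queueing would be redundant.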
if (empty) {
loop_->QueueInLoop([this]() { LoopOnce(); });
}
}
bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
PADDLE_ENFORCE_NOT_NULL(
carrier_,
platform::errors::PreconditionNotMet("Carrier is not registered."));
msg.set_src_id(interceptor_id_);
msg.set_dst_id(dst_id);
return carrier_->Send(msg);
}
static InterceptorFactory::CreateInterceptorMap& GetInterceptorMap() {
static InterceptorFactory::CreateInterceptorMap interceptorMap;
return interceptorMap;
}
std::unique_ptr<Interceptor> InterceptorFactory::Create(const std::string& type,
int64_t id,
TaskNode* node) {
auto& interceptor_map = GetInterceptorMap();
auto iter = interceptor_map.find(type);
PADDLE_ENFORCE_NE(
iter,
interceptor_map.end(),
platform::errors::NotFound("interceptor %s is not register", type));
return iter->second(id, node);
}
void InterceptorFactory::Register(
const std::string& type, InterceptorFactory::CreateInterceptorFunc func) {
auto& interceptor_map = GetInterceptorMap();
interceptor_map.emplace(type, func);
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <deque>
#include <functional>
#include <map>
#include <memory>
#include <thread>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
class Scope;
class GarbageCollector;
} // namespace framework
namespace distributed {
class TaskNode;
class Carrier;
class TaskLoop;
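// Reserved interceptor ids for the unique source and sink nodes of the
// runtime graph (see source_interceptor.h and sink_interceptor.h).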
constexpr int64_t SOURCE_ID = -1;
constexpr int64_t SINK_ID = -2;
class Interceptor {
public:
using MsgHandle = std::function<void(const InterceptorMessage&)>;
public:
Interceptor() = delete;
Interceptor(int64_t interceptor_id, TaskNode* node);
virtual ~Interceptor();
// register interceptor handle
void RegisterMsgHandle(MsgHandle handle);
void Handle(const InterceptorMessage& msg);
// return the interceptor id
int64_t GetInterceptorId() const { return interceptor_id_; }
// Called by Carrier, enqueue an InterceptorMessage to remote mailbox
void EnqueueRemoteInterceptorMessage(
const InterceptorMessage& interceptor_message);
bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT
void SetPlace(const platform::Place& place) { place_ = place; }
void SetRootScope(framework::Scope* scope) { root_scope_ = scope; }
void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; }
void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
microbatch_scopes_ = scopes;
}
void SetGC(const std::shared_ptr<framework::GarbageCollector>& gc) {
gc_ = gc;
}
void RegisterCarrier(Carrier* carrier) { carrier_ = carrier; }
void RegisterTaskLoop(TaskLoop* loop) { loop_ = loop; }
TaskNode* GetTaskNode() const { return node_; }
DISABLE_COPY_AND_ASSIGN(Interceptor);
protected:
// interceptor id, handed down from the layer above
int64_t interceptor_id_;
// the node to be handled by this interceptor
TaskNode* node_;
// for stop
bool stop_{false};
void StopCarrier();
// for runtime
platform::Place place_;
framework::Scope* root_scope_{nullptr};
framework::Scope* minibatch_scope_{nullptr};
std::vector<framework::Scope*> microbatch_scopes_{};
std::shared_ptr<framework::GarbageCollector> gc_{nullptr};
Carrier* carrier_;
TaskLoop* loop_;
private:
void LoopOnce();
// interceptor handle that processes messages
MsgHandle handle_{nullptr};
std::mutex mutex_;
std::deque<InterceptorMessage> messages_;
int64_t already_run_times_{0};
int64_t used_slot_nums_{0};
};
class InterceptorFactory {
public:
using CreateInterceptorFunc = std::unique_ptr<Interceptor> (*)(int64_t,
TaskNode*);
using CreateInterceptorMap =
std::unordered_map<std::string, CreateInterceptorFunc>;
static void Register(const std::string& type, CreateInterceptorFunc func);
static std::unique_ptr<Interceptor> Create(const std::string& type,
int64_t id,
TaskNode* node);
};
template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
return std::make_unique<InterceptorClass>(id, node);
}
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
class __RegisterInterceptor_##interceptor_type { \
public: \
__RegisterInterceptor_##interceptor_type() { \
InterceptorFactory::Register(#interceptor_type, \
CreatorInterceptor<interceptor_class>); \
} \
void Touch() {} \
}; \
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \
int TouchRegisterInterceptor_##interceptor_type() { \
g_register_##interceptor_type.Touch(); \
return 0; \
}
#define USE_INTERCEPTOR(interceptor_type) \
extern int TouchRegisterInterceptor_##interceptor_type(); \
UNUSED static int use_interceptor_##interceptor_type = \
TouchRegisterInterceptor_##interceptor_type();
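// Usage sketch for the two macros above; EchoInterceptor is a
// hypothetical subclass used purely for illustration:
//
//   // echo_interceptor.cc
//   class EchoInterceptor : public Interceptor {
//    public:
//     EchoInterceptor(int64_t id, TaskNode* node) : Interceptor(id, node) {
//       RegisterMsgHandle(
//           [](const InterceptorMessage& msg) { VLOG(3) << msg.src_id(); });
//     }
//   };
//   REGISTER_INTERCEPTOR(Echo, EchoInterceptor);
//
//   // in any translation unit that must link in the registration:
//   USE_INTERCEPTOR(Echo);
//   std::unique_ptr<Interceptor> echo =
//       InterceptorFactory::Create("Echo", /*id=*/7, task_node);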
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
enum MessageType {
STOP = 1; // STOP an Interceptor
DATA_IS_READY = 2; // upstream data is ready
DATA_IS_USELESS = 3; // downstream has used the data
ERR = 4; // current Interceptor encounters error
RESET = 5; // reset the status
START = 6; // start a new step (sent by the carrier to the source)
}
message InterceptorMessage {
optional sint64 src_id = 1 [ default = 0 ];
optional sint64 dst_id = 2 [ default = 0 ];
optional MessageType message_type = 3 [ default = RESET ];
optional bool ctrl_message = 4 [ default = false ];
optional int64 scope_idx = 5 [ default = 0 ];
}
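// Construction sketch using the C++ API that protoc generates for this
// message (setter names follow standard protoc conventions):
//   InterceptorMessage msg;
//   msg.set_src_id(0);
//   msg.set_dst_id(1);
//   msg.set_message_type(DATA_IS_READY);
//   msg.set_scope_idx(0);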
message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
service MessageService {
rpc ReceiveInterceptorMessage(InterceptorMessage)
returns (InterceptorResponse);
rpc IncreaseBarrierCount(InterceptorMessage) returns (InterceptorResponse);
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include <chrono>
#include <memory>
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
namespace paddle {
namespace distributed {
void MessageBus::Init(
int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr) {
PADDLE_ENFORCE_EQ(
is_init_,
false,
platform::errors::AlreadyExists("MessageBus is already init."));
rank_ = rank;
is_init_ = true;
rank_to_addr_ = rank_to_addr;
addr_ = addr;
if (addr_ != "") {
const auto& addr = GetAddr(rank_);
PADDLE_ENFORCE_EQ(addr,
addr_,
platform::errors::Fatal(
"The current rank's addr is %s, while the "
"message bus's addr is %s, which are different. "
"Init error.",
addr,
addr_));
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
// NOTE: To make brpc compatible with collective communication,
// release the socket handler holding the IP address.
if (addr_ != "") {
VLOG(3) << "Message bus is releasing the fd held by gen_comm_id.";
paddle::platform::SocketServer& socket_server =
paddle::platform::SocketServer::GetInstance(addr_);
int server_fd = socket_server.socket();
if (server_fd != -1) {
socket_server.Release();
}
}
#endif
ListenPort();
}
bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
server_.Stop(1000);
server_.Join();
#endif
}
const std::string& MessageBus::GetAddr(int64_t rank) const {
PADDLE_ENFORCE_NE(
rank_to_addr_.find(rank),
rank_to_addr_.end(),
platform::errors::NotFound("Cannot find addr rank id %lld.", rank));
return rank_to_addr_.at(rank);
}
bool MessageBus::Send(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
PADDLE_ENFORCE_EQ(
IsInit(),
true,
platform::errors::PreconditionNotMet(
"Cannot use the message bus since it has not been initialized."));
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
int retry_time = 0; // the message bus retries sending up to 10 times
while (retry_time < 10) {
++retry_time;
if (SendInterRank(dst_rank, interceptor_message)) {
VLOG(3) << "Message bus sends inter rank successfully with " << retry_time
<< " times retries.";
return true;
}
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
return false;
#else
PADDLE_THROW(platform::errors::Unavailable(
"Fleet executor does not support sending message between different "
"ranks when Paddle is compiled with npu or "
"isn't compiled with distributed for now."));
#endif
return true;
}
void MessageBus::IncreaseBarrierCount() {
VLOG(3) << "IncreaseBarrierCount";
{
std::unique_lock<std::mutex> lock(mutex_);
++count_;
cv_.notify_one();
}
VLOG(3) << "End IncreaseBarrierCount";
}
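// Barrier() implements a two-phase rendezvous over the message bus:
// first every non-zero rank sends a ctrl message to rank 0 (gather),
// then rank 0 replies to every other rank (scatter). Both phases block
// on the condition variable that IncreaseBarrierCount() signals.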
void MessageBus::Barrier() {
// gather to root
if (rank_ != 0) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(rank_);
ctrl_msg.set_dst_id(0);
VLOG(3) << "Barrier Gather ctrl message from " << rank_ << " to 0";
while (!Send(0, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
} else {
VLOG(3) << "Barrier 0 wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
return count_ == static_cast<int>(rank_to_addr_.size() - 1);
});
count_ = 0;
}
// scatter from root
if (rank_ == 0) {
for (int i = 1; i < static_cast<int>(rank_to_addr_.size()); ++i) {
InterceptorMessage ctrl_msg;
ctrl_msg.set_ctrl_message(true);
ctrl_msg.set_src_id(0);
ctrl_msg.set_dst_id(i);
VLOG(3) << "Barrier Scatter ctrl message from 0 to " << i;
while (!Send(i, ctrl_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
}
} else {
VLOG(3) << "Barrier " << rank_ << " wait others rank ready";
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] { return count_ == 1; });
count_ = 0;
}
}
bool MessageBus::DispatchMsgToCarrier(
const InterceptorMessage& interceptor_message) {
const std::string& carrier_id = *GlobalVal<std::string>::Get();
return GlobalMap<std::string, Carrier>::Get(carrier_id)
->EnqueueInterceptorMessage(interceptor_message);
}
void MessageBus::ListenPort() {
if (addr_ == "") {
LOG(INFO) << "No need listen to port since training on single card.";
return;
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
// Keep listening on the port and handle incoming messages.
PADDLE_ENFORCE_EQ(
server_.AddService(&message_service_, brpc::SERVER_DOESNT_OWN_SERVICE),
0,
platform::errors::Unavailable("Message bus: init brpc service error."));
// start the server
const char* ip_for_brpc = addr_.c_str();
brpc::ServerOptions options;
options.idle_timeout_sec = -1;
int retry_times = 0;
int interval = 100;
while (server_.Start(ip_for_brpc, &options) != 0) {
++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000
<< " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 500;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
#else
LOG(WARNING)
<< "Fleet executor's ListenPort() is a no-op when Paddle is "
"compiled with NPU support or is not compiled "
"with distributed support for now.";
#endif
}
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
bool MessageBus::SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message) {
const auto& dst_addr = GetAddr(dst_rank);
VLOG(3) << "Message bus sending to addr: " << dst_addr;
const char* dst_addr_for_brpc = dst_addr.c_str();
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connect_timeout_ms = 1000;
options.timeout_ms = 1000;
options.max_retry = 5;
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options),
0,
platform::errors::Unavailable("Message bus: init brpc channel error."));
MessageService_Stub stub(&channel);
InterceptorResponse response;
brpc::Controller ctrl;
ctrl.set_log_id(0);
if (interceptor_message.ctrl_message()) {
stub.IncreaseBarrierCount(&ctrl, &interceptor_message, &response, NULL);
} else {
stub.ReceiveInterceptorMessage(
&ctrl, &interceptor_message, &response, NULL);
}
if (!ctrl.Failed()) {
if (response.rst()) {
VLOG(3) << "Message bus: brpc sends success.";
return true;
} else {
VLOG(4) << "Message bus: InterceptorMessageService error.";
return false;
}
} else {
VLOG(4) << "Message bus: brpc sends failed with error text: "
<< ctrl.ErrorText();
return false;
}
}
#endif
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "brpc/channel.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#endif
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class Carrier;
// A singleton MessageBus
class MessageBus final {
public:
MessageBus() = default;
~MessageBus();
void Init(int64_t rank,
const std::unordered_map<int64_t, std::string>& rank_to_addr,
const std::string& addr);
bool IsInit() const;
// called by Interceptor, send InterceptorMessage to dst
bool Send(int64_t dst_rank, const InterceptorMessage& interceptor_message);
void IncreaseBarrierCount();
void Barrier();
bool DispatchMsgToCarrier(const InterceptorMessage& interceptor_message);
private:
DISABLE_COPY_AND_ASSIGN(MessageBus);
// Keep listening on the port and handle incoming messages.
void ListenPort();
const std::string& GetAddr(int64_t rank) const;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
// send the message across ranks (dst is on a different rank than src)
bool SendInterRank(int64_t dst_rank,
const InterceptorMessage& interceptor_message);
#endif
bool is_init_{false};
int64_t rank_;
// handed in by the layer above; maps rank id to address
std::unordered_map<int64_t, std::string> rank_to_addr_;
// the address this rank listens on
std::string addr_;
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
MessageServiceImpl message_service_;
// brpc server
brpc::Server server_;
#endif
// for barrier
std::mutex mutex_;
std::condition_variable cv_;
int count_{0};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#include "paddle/fluid/distributed/fleet_executor/message_service.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
namespace paddle {
namespace distributed {
void MessageServiceImpl::ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Message Service receives a message from interceptor "
<< request->src_id() << " to interceptor " << request->dst_id()
<< ", with the message: " << request->message_type();
bool flag = GlobalVal<MessageBus>::Get()->DispatchMsgToCarrier(*request);
response->set_rst(flag);
}
void MessageServiceImpl::IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
VLOG(3) << "Barrier Service receives a message from rank "
<< request->src_id() << " to rank " << request->dst_id();
GlobalVal<MessageBus>::Get()->IncreaseBarrierCount();
response->set_rst(true);
}
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
#pragma once
#include "brpc/server.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h"
namespace paddle {
namespace distributed {
class MessageServiceImpl : public MessageService {
public:
MessageServiceImpl() {}
virtual ~MessageServiceImpl() {}
virtual void ReceiveInterceptorMessage(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done);
virtual void IncreaseBarrierCount(
google::protobuf::RpcController* control_base,
const InterceptorMessage* request,
InterceptorResponse* response,
google::protobuf::Closure* done);
};
} // namespace distributed
} // namespace paddle
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
std::string RuntimeGraph::DebugString() const {
std::ostringstream os;
os << "\nRuntime Graph Debug: \n";
for (const auto& pair : interceptor_id_to_node_) {
os << pair.second->DebugString();
os << "\n";
}
return os.str();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskNode;
class RuntimeGraph final {
public:
RuntimeGraph() = default;
~RuntimeGraph() = default;
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node() const {
return interceptor_id_to_node_;
}
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank() const {
return interceptor_id_to_rank_;
}
void SetInterceptorIdToRank(
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank) {
interceptor_id_to_rank_ = interceptor_id_to_rank;
}
void SetInterceptorIdToNode(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
interceptor_id_to_node_ = interceptor_id_to_node;
}
std::string DebugString() const;
private:
DISABLE_COPY_AND_ASSIGN(RuntimeGraph);
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/sink_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
SinkInterceptor::SinkInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) {
// prepare the upstream running status
for (const auto& up : node->upstream()) {
upstream_step_.emplace(up.first, 0);
}
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void SinkInterceptor::StopCarrierIfComplete() {
bool flag = true;
for (const auto& up : upstream_step_) {
flag = flag && (up.second == max_run_times_);
}
if (flag) {
VLOG(3) << "Sink Interceptor is stopping carrier";
StopCarrier();
for (const auto& up : upstream_step_) {
upstream_step_.at(up.first) = 0;
}
}
}
void SinkInterceptor::ReplyCompletedToUpStream(int64_t upstream_id) {
int64_t micro_step = upstream_step_.at(upstream_id);
int64_t scope_idx = micro_step % max_run_times_;
InterceptorMessage msg;
msg.set_message_type(DATA_IS_USELESS);
msg.set_scope_idx(scope_idx);
Send(upstream_id, msg);
upstream_step_.at(upstream_id) = micro_step + 1;
if (micro_step == max_run_times_ - 1) {
StopCarrierIfComplete();
}
}
void SinkInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ReplyCompletedToUpStream(msg.src_id());
}
}
REGISTER_INTERCEPTOR(Sink, SinkInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
* Sink interceptor
* There is only one sink in the runtime graph
* Take charge of:
* 1. recording the number of micro-steps
* 2. checking whether to notify the carrier that the current step is finished
*/
class SinkInterceptor : public Interceptor {
public:
SinkInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void ReplyCompletedToUpStream(int64_t up_id);
void Run(const InterceptorMessage& msg);
void StopCarrierIfComplete();
int64_t max_run_times_;
// upstream_id->cur_step
std::map<int64_t, int64_t> upstream_step_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/source_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
SourceInterceptor::SourceInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node), max_run_times_(node->max_run_times()) {
// prepare the downstream running status
for (const auto& down : node->downstream()) {
downstream_step_.emplace(down.first, 0);
}
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void SourceInterceptor::SendDataReadyToDownStream(int64_t downstream_id) {
int64_t micro_step = downstream_step_.at(downstream_id);
if (micro_step >= max_run_times_) {
return;
}
int64_t scope_idx = micro_step % max_run_times_;
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(scope_idx);
Send(downstream_id, ready_msg);
downstream_step_.at(downstream_id) = micro_step + 1;
}
void SourceInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == START) {
// start running a new step; reset the previous running status
for (const auto& down : downstream_step_) {
downstream_step_.at(down.first) = 0;
SendDataReadyToDownStream(down.first);
}
} else if (msg.message_type() == DATA_IS_USELESS) {
SendDataReadyToDownStream(msg.src_id());
}
}
REGISTER_INTERCEPTOR(Source, SourceInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/*
* Source interceptor
* There is only one source in the runtime graph
* Take charge of:
* 1. receiving the `start` message from the carrier
* 2. sending num_of_steps `data_is_ready` messages downstream
*/
class SourceInterceptor : public Interceptor {
public:
SourceInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void SendDataReadyToDownStream(int64_t down_id);
void Run(const InterceptorMessage& msg);
int64_t max_run_times_;
// downstream_id->cur_step
std::map<int64_t, int64_t> downstream_step_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
thread_local TaskLoop* TaskLoop::thread_local_loop_ = nullptr;
TaskLoop* TaskLoop::GetTaskLoopOfCurrentThread() { return thread_local_loop_; }
TaskLoop::TaskLoop()
: looping_(false), quit_(false), thread_id_(std::this_thread::get_id()) {
PADDLE_ENFORCE_EQ(
thread_local_loop_,
nullptr,
platform::errors::AlreadyExists("Another TaskLoop is already init."));
thread_local_loop_ = this;
}
TaskLoop::~TaskLoop() { thread_local_loop_ = nullptr; }
void TaskLoop::Loop() {
PADDLE_ENFORCE_EQ(looping_,
false,
platform::errors::PreconditionNotMet(
"Loop can only execute in one loop thread"));
AssertInLoopThread();
looping_ = true;
quit_ = false;
while (!quit_) {
auto tasks = tasks_.PopAll();
for (auto& task : tasks) {
task();
}
}
looping_ = false;
}
void TaskLoop::Quit() {
quit_ = true;
if (!IsInLoopThread()) WakeUp();
}
void TaskLoop::RunInLoop(Functor cb) {
if (IsInLoopThread()) {
cb();
} else {
QueueInLoop(cb);
}
}
void TaskLoop::QueueInLoop(Functor cb) { tasks_.Push(cb); }
void TaskLoop::WakeUp() {
Functor task([] {});
QueueInLoop(task);
}
void TaskLoop::AbortNotInLoopThread() {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This TaskLoop was created in thread %d, but current thread is %d",
thread_id_,
std::this_thread::get_id()));
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future>
#include <map>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop {
public:
static TaskLoop* GetTaskLoopOfCurrentThread();
using Functor = std::function<void()>;
TaskLoop();
~TaskLoop();
void Loop();
void Quit();
void RunInLoop(Functor cb);
void QueueInLoop(Functor cb);
template <class F, class... Args>
auto Enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
std::future<return_type> task_future = task->get_future();
tasks_.Push([task]() { (*task)(); });
return task_future;
}
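// Usage sketch for Enqueue (the loop_thread below is an assumption for
// illustration, not part of this header):
//   TaskLoopThread loop_thread;
//   TaskLoop* loop = loop_thread.StartLoop();
//   std::future<int> f = loop->Enqueue([](int x) { return x * 2; }, 21);
//   int v = f.get();  // 42, computed on the loop thread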
void WakeUp();
bool IsInLoopThread() const {
return thread_id_ == std::this_thread::get_id();
}
void AssertInLoopThread() {
if (!IsInLoopThread()) {
AbortNotInLoopThread();
}
}
private:
DISABLE_COPY_AND_ASSIGN(TaskLoop);
void AbortNotInLoopThread();
static thread_local TaskLoop* thread_local_loop_;
bool looping_;
std::atomic<bool> quit_;
std::thread::id thread_id_;
framework::BlockingQueue<Functor> tasks_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThread::TaskLoopThread() : start_(false), loop_(nullptr) {}
TaskLoopThread::~TaskLoopThread() {
if (loop_ != nullptr) {
loop_->Quit();
thread_.join();
}
}
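// StartLoop() spawns the worker thread, then blocks until that thread has
// constructed its stack-local TaskLoop and published it through loop_, so
// callers never observe a half-initialized loop.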
TaskLoop* TaskLoopThread::StartLoop() {
PADDLE_ENFORCE_EQ(
start_,
false,
platform::errors::PreconditionNotMet("thread is already running."));
start_ = true;
thread_ = std::thread([this]() { Loop(); });
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return loop_ != nullptr; });
return loop_;
}
void TaskLoopThread::Loop() {
TaskLoop loop;
{
std::unique_lock<std::mutex> lock(mutex_);
loop_ = &loop;
cv_.notify_one();
}
loop.Loop();
std::unique_lock<std::mutex> lock(mutex_);
loop_ = nullptr;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <condition_variable>
#include <mutex>
#include <thread>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread {
public:
TaskLoopThread();
~TaskLoopThread();
TaskLoop* StartLoop();
private:
DISABLE_COPY_AND_ASSIGN(TaskLoopThread);
void Loop();
bool start_;
TaskLoop* loop_;
std::thread thread_;
std::mutex mutex_;
std::condition_variable cv_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop.h"
#include "paddle/fluid/distributed/fleet_executor/task_loop_thread.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace distributed {
TaskLoopThreadPool::TaskLoopThreadPool() : TaskLoopThreadPool(1) {}
TaskLoopThreadPool::TaskLoopThreadPool(int thread_num)
: start_(false), thread_num_(thread_num) {}
TaskLoopThreadPool::~TaskLoopThreadPool() = default;
void TaskLoopThreadPool::Start() {
PADDLE_ENFORCE_EQ(
start_,
false,
platform::errors::PreconditionNotMet("thread pool is already start."));
PADDLE_ENFORCE_GT(
thread_num_,
0,
platform::errors::InvalidArgument(
"thread num must be greater than 0, but is %d", thread_num_));
start_ = true;
for (int i = 0; i < thread_num_; ++i) {
threads_.emplace_back(new TaskLoopThread());
loops_.push_back(threads_[i]->StartLoop());
}
}
TaskLoop* TaskLoopThreadPool::GetLoop(int tid) {
PADDLE_ENFORCE_EQ(
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
PADDLE_ENFORCE_GE(
tid,
0,
platform::errors::OutOfRange("tid must >= 0, but now is %d", tid));
PADDLE_ENFORCE_LT(tid,
thread_num_,
platform::errors::OutOfRange(
"tid must be < thread_num, but tid=%d thread_num=%d",
tid,
thread_num_));
return loops_[tid];
}
std::vector<TaskLoop*> TaskLoopThreadPool::GetAllLoops() {
PADDLE_ENFORCE_EQ(
start_,
true,
platform::errors::PreconditionNotMet("thread pool must start first."));
return loops_;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace distributed {
class TaskLoop;
class TaskLoopThread;
class TaskLoopThreadPool {
public:
TaskLoopThreadPool();
explicit TaskLoopThreadPool(int thread_num);
~TaskLoopThreadPool();
void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
void Start();
TaskLoop* GetLoop(int tid);
std::vector<TaskLoop*> GetAllLoops();
private:
DISABLE_COPY_AND_ASSIGN(TaskLoopThreadPool);
bool start_;
int thread_num_;
std::vector<std::unique_ptr<TaskLoopThread>> threads_;
std::vector<TaskLoop*> loops_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace distributed {
namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be invoked serially; not thread-safe.
// NOTE: when instantiating a TaskNode with a program, the task node is
// not initialized immediately, since the provided program may be updated
// later (with high probability) by adding feed/fetch ops or by
// RuntimeGraph. So the init part is delayed to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// TODO(liyurui): Will be removed when execute program is supported.
Init();
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank
<< ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}
void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
program_ = program;
}
void TaskNode::Init(bool use_feed_fetch_ops) {
if (!use_feed_fetch_ops) {
VLOG(3) << "TaskNode will be inited without feed and fetch ops";
}
if (ops_.empty()) {
// Q (for fleet executor dev): do we need another reset function?
VLOG(3) << "Task node will be initialized by calling Init().";
for (const auto& op_desc : program_->Block(0).AllOps()) {
if (!use_feed_fetch_ops &&
(op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
VLOG(3) << "TaskNode will skip [" << op_desc->Input("X")[0] << "], "
<< op_desc->Type() << " -> " << op_desc->Output("Out")[0];
continue;
}
ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
}
for (const auto& op : ops_vec_) {
ops_.emplace_back(op.get());
}
}
}
TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times)
: rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
if (op_descs.empty()) {
return;
}
VLOG(3) << "Task node will be inited by providing list of ops.";
for (const auto& desc : op_descs) {
ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
for (const auto& op : ops_vec_) {
ops_.emplace_back(op.get());
}
}
TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = upstream_.emplace(task_id, buff_size);
return ret.second;
}
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
const auto& ret = downstream_.emplace(task_id, buff_size);
return ret.second;
}
std::string TaskNode::DebugString() const {
std::ostringstream os;
os << "role: " << role_ << ", task_id: " << task_id_ << "\n";
for (std::size_t i = 0; i < ops_.size(); ++i) {
os << ops_[i]->Type() << " ";
}
os << "\n";
return os.str();
}
void TaskNode::SetRunPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(value,
1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but received %ld", value));
run_per_steps_ = value;
}
void TaskNode::SetRunAtOffset(int64_t value) {
PADDLE_ENFORCE_GE(value,
0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but received %ld", value));
run_at_offset_ = value;
}
void TaskNode::SetReplyUpPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value,
1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but received %ld", value));
reply_up_per_steps_ = value;
}
void TaskNode::SetSendDownPerSteps(int64_t value) {
PADDLE_ENFORCE_GE(
value,
1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but received %ld", value));
send_down_per_steps_ = value;
}
} // namespace distributed
} // namespace paddle