Commit de2e6515 authored by yuguo960516yuguo
Browse files

2.4.1-dtk-23.04

parent ad08b8ce
Pipeline #228 failed with stages
in 0 seconds
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using phi::SelectedRows;
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
// Returns the current wall-clock time in microseconds since the Unix epoch.
inline double GetCurrentUS() {
  struct timeval tv;
  gettimeofday(&tv, nullptr);
  return 1e+6 * tv.tv_sec + tv.tv_usec;
}
// Default-construct; all members rely on their in-class initializers.
Communicator::Communicator() = default;
// Initializes brpc/gflags runtime flags from a space-separated flag string.
// When `gflags` is empty, a set of default PS communication flags is used.
// A dummy program name ("exe default") is prepended because
// ParseCommandLineFlags expects an argv[0] entry.
void Communicator::InitGFlag(const std::string &gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-bthread_concurrency=40");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  auto it = flags.begin();
  flags.insert(it, "exe default");
  // Build an argv-style array. Use std::vector instead of the previous
  // variable-length array, which is a non-standard extension in C++.
  std::vector<char *> flags_ptr(flags.size());
  for (size_t i = 0; i < flags.size(); ++i) {
    flags_ptr[i] = const_cast<char *>(flags[i].c_str());
  }
  int params_cnt = static_cast<int>(flags.size());
  char **params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// Guards one-time initialization of the global communicator singleton.
std::once_flag Communicator::init_flag_;
// Process-wide communicator singleton; null until initialized.
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
// Binds this communicator to the PS worker client owned by the FleetWrapper
// singleton. `dist_desc` and `host_sign_list` are currently unused; the
// worker pointer is taken from the fleet instance instead.
void Communicator::InitBrpcClient(
    const std::string &dist_desc,
    const std::vector<std::string> &host_sign_list) {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();
  if (!_worker_ptr) {
    _worker_ptr = fleet->worker_ptr_;
  }
}
// Returns the client endpoint signatures registered in the PS environment,
// logging each one at verbosity level 2.
std::vector<uint64_t> Communicator::GetClientInfo() {
  std::vector<uint64_t> infos = _ps_env.GetClientInfo();
  for (const auto &info : infos) {
    VLOG(2) << "Communicator::GetClientInfo " << info;
  }
  return infos;
}
// Registers the given client endpoint signatures with the PS environment
// and returns the result of PSEnvironment::SetPsClients.
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
  const int node_num = static_cast<int>(host_sign_list.size());
  return _ps_env.SetPsClients(host_sign_list.data(), node_num);
}
// Pulls the dense table `table_id` from the pservers into the listed
// variables of `scope`. For GPU tensors the pull lands in a CPU staging
// tensor inside xpu_temp_scope_ first and is then copied onto the device.
// Blocks until the pull has completed.
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
                                int table_id,
                                Scope *scope) {  // pserver_scope_
  platform::RecordEvent record_event("Communicator->RpcRecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  std::vector<paddle::distributed::Region> regions;
  regions.reserve(varnames.size());
  // Build one Region per variable pointing at the buffer PullDense should
  // fill: the tensor itself on CPU, or a CPU staging tensor for GPU places.
  for (auto &t : varnames) {
    Variable *var = scope->Var(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(tensor->place());
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  // Synchronous pull: block until the dense values have arrived.
  auto status =
      _worker_ptr->PullDense(regions.data(), regions.size(), table_id);
  status.wait();
  // Second pass: for GPU tensors, copy the staged CPU values to the device.
  for (auto &t : varnames) {
    Variable *var = scope->FindVar(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    VLOG(3) << "Communicator::RecvNoBarrier Var " << t << " On gpu? "
            << platform::is_gpu_place(tensor->place());
    float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "Communicator::RpcRecvDense Var " << t << " table_id "
            << table_id << " Temp_data[0] " << temp_recv_data[0]
            << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      LoDTensor *temp_tensor =
          xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
      framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      VLOG(1) << "Communicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    }
  }
  return;
}
// Pushes the dense parameters listed in `varnames` (read from `scope`) to
// table `table_id`. GPU tensors are staged through a CPU copy inside
// xpu_temp_scope_ first. Blocks until the push completes.
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
                                     int table_id,
                                     const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto &t : varnames) {
    Variable *var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Stage the device tensor into a CPU tensor the RPC layer can read.
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
              << " Temp_data[0] " << temp_data[0] << " Temp_data[-1] "
              << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(place);
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(reg);
      // Log typo fixed: "talbe_id" -> "table_id".
      VLOG(1) << "rpc_send_dense_param Var " << t << " table_id " << table_id
              << " Temp_data[0] " << w[0] << " Temp_data[-1] "
              << w[tensor->numel() - 1];
    }
  }
  auto status =
      _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
  status.wait();
  VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
  return;
}
// Pushes the merged dense gradients described by `ctx` (read from `scope`,
// typically delta_scope_) to table `ctx.table_id`. All origin variables are
// flattened into a single contiguous buffer sized per-shard * server count.
// The closure decrements _async_call_num when every pserver has responded;
// this function still waits on the returned status before returning.
void Communicator::RpcSendDense(const CommContext &ctx,
                                const Scope &scope) {  // delta_scope_
  platform::RecordEvent record_event("Communicator->RpcSendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = ctx.origin_varnames;
  auto &table_id = ctx.table_id;
  auto dense_data = std::make_shared<std::vector<float>>();
  size_t request_call_num = _worker_ptr->GetServerNums();
  uint32_t num_per_shard =
      DenseDimPerShard(ctx.height_sections[0], request_call_num);
  dense_data->resize(num_per_shard *
                     request_call_num);  // accessor->update_dim() = 1
  float *data = dense_data->data();
  uint32_t pos = 0;
  // Concatenate every origin variable's gradient into the flat buffer.
  for (size_t i = 0; i < var_names.size(); ++i) {
    const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
    size_t count = static_cast<size_t>(tensor.numel());
    const float *g = tensor.data<float>();
    CHECK(pos + count <= dense_data->size())
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << count << "] size[" << dense_data->size() << "]";
    memcpy(data + pos, g, count * sizeof(float));
    pos += count;
  }
  ++_async_call_num;
  // NOTE(review): the closure appears to be released by the RPC completion
  // path rather than this function — confirm with DownpourBrpcClosure's
  // ownership contract.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushDenseRawGradient(
      table_id, data, dense_data->size(), closure);
  status.wait();
  return;
}
// Pushes all rows of the parameter tensor `varname` as sparse values
// (keys 0..N-1, N = first dim) into sparse table `table_id`. Blocks until
// the push completes.
void Communicator::RpcSendSparseParam(const std::string &varname,
                                      int table_id,
                                      const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(varname);
  auto *tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Every row index becomes a sparse key; each key's value is one row.
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  push_g_vec.reserve(sparse_num);
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  // Resolves the promise once every pserver has answered.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushSparseParam(table_id,
                                             sparse_push_keys.data(),
                                             (const float **)push_g_vec.data(),
                                             sparse_push_keys.size(),
                                             closure);
  status.wait();
  return;
}
// Pushes the sparse gradient stored as SelectedRows in `var_name` (read
// from `scope`) to sparse table `table_id`. The SelectedRows row ids become
// the push keys and each key's gradient row is pushed raw. The closure
// decrements _async_call_num once every pserver has responded.
void Communicator::RpcSendSparse(const std::string &var_name,
                                 int table_id,
                                 const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<uint64_t> sparse_push_keys;
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(var_name);
  auto *tensor = send_var->GetMutable<phi::SelectedRows>();
  auto dim = tensor->value().dims()[1];
  std::transform(tensor->rows().begin(),
                 tensor->rows().end(),
                 std::back_inserter(sparse_push_keys),
                 [&](int64_t id) { return static_cast<uint64_t>(id); });
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
  }
  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
  // if padding_idx == padding in datareader, the server will core.
  /*
  for (size_t i = 0; i < tensor->rows().size(); ++i) {
    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
    if (real_id != 0) {
      sparse_push_keys.push_back(real_id);
      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
    }
  }
  */
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status =
      _worker_ptr->PushSparseRawGradient(table_id,
                                         sparse_push_keys.data(),
                                         (const float **)push_g_vec.data(),
                                         sparse_push_keys.size(),
                                         closure);
  status.wait();
  return;
}
// Pulls the full sparse table `table_id` from the pservers into `varname`
// in `scope`. The destination tensor's first dimension decides how many
// rows (keys 0..N-1) are fetched; each row holds dims()[1] floats.
void Communicator::RpcRecvSparse(const std::string &varname,
                                 int table_id,
                                 Scope *scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  auto *recv_var = scope->Var(varname);
  auto *tensor = recv_var->GetMutable<framework::LoDTensor>();
  const auto row_width = tensor->dims()[1];
  const uint64_t row_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Request every row id in [0, row_num).
  std::vector<uint64_t> sparse_pull_keys(row_num);
  std::iota(sparse_pull_keys.begin(), sparse_pull_keys.end(), 0);
  std::vector<float *> pull_value_ptrs;
  pull_value_ptrs.reserve(row_num);
  for (size_t i = 0; i < sparse_pull_keys.size(); ++i) {
    pull_value_ptrs.push_back(tensor->data<float>() + i * row_width);
  }
  bool training = true;
  auto status = _worker_ptr->PullSparseParam(pull_value_ptrs.data(),
                                             table_id,
                                             sparse_pull_keys.data(),
                                             sparse_pull_keys.size(),
                                             training);
  status.wait();
}
// Pushes the initial dense parameters to the pservers. Only trainer 0
// performs the push so each table is initialized exactly once.
void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  if (trainer_id_ != 0) {
    return;
  }
  for (const auto &entry : recv_varname_to_ctx) {
    const auto &table_id = entry.first;
    const auto &varnames = entry.second;
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  }
}
// Pulls every dense table listed in `recv_varname_to_ctx` from the
// pservers into recv_scope_.
void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
  for (const auto &entry : recv_varname_to_ctx) {
    const auto &table_id = entry.first;
    const auto &varnames = entry.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param to table " << table_id
            << " from 0' trainer done";
  }
}
// Trainer 0 mirrors the local profiler state to the pservers: it sends a
// start RPC when profiling turns on and a stop RPC when it turns off.
void Communicator::RpcProfilerControl() {
  if (trainer_id_ != 0) {
    return;
  }
  const bool profiling = platform::IsProfileEnabled();
  if (profiling && !do_server_profiler_) {
    // Profiling just switched on: tell the servers to start.
    do_server_profiler_ = true;
    auto start_status = _worker_ptr->StartProfiler();
    start_status.wait();
  } else if (!profiling && do_server_profiler_) {
    // Profiling just switched off: tell the servers to stop.
    auto stop_status = _worker_ptr->StopProfiler();
    stop_status.wait();
    do_server_profiler_ = false;
  }
}
// Pushes the number of consumed batches to the global-step table
// `ctx.table_id`. A no-op when `batches` is zero. The value is staged in a
// one-element int64 tensor (STEP_COUNTER) inside `send_scope`, and the call
// blocks until every pserver has responded.
void Communicator::SendGlobalStep(const CommContext &ctx,
                                  int batches,
                                  Scope *send_scope) {
  if (batches == 0) {
    return;
  }
  platform::RecordEvent record_event("Communicator->SendGlobalStep",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = ctx.table_id;
  size_t request_call_num = _worker_ptr->GetServerNums();
  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);
  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
  // Resolves the promise after checking every pserver's response.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
  status.wait();
  return;
}
void AsyncCommunicator::RecvThread() {
if (!independent_recv_) return;
VLOG(3) << "Independent RecvThread Start and Wait";
while (running_) {
int grad_num = grad_num_.load();
if (grad_num > min_send_grad_num_before_recv_) {
RecvByCommunicator();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
VLOG(1) << "communicator stopped, independent recv thread exit";
}
// Pulls the latest dense parameters from the pservers unless the
// communicator has already been stopped.
void AsyncCommunicator::RecvByCommunicator() {
  if (!running_) {
    return;
  }
  RecvNoBarrier();
  VLOG(3) << "run recv graph end";
}
// Pulls all dense tables from the pservers into recv_scope_ without any
// barrier, then (for GPU tensors) copies the freshly pulled CPU staging
// tensors in xpu_temp_scope_ back onto the device.
void AsyncCommunicator::RecvNoBarrier() {
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
  }
  for (auto &iter : recv_varname_to_ctx_) {
    // Take a const reference; the previous code copied the whole vector of
    // variable names on every iteration.
    const auto &var_names = iter.second;
    for (const auto &t : var_names) {
      Variable *var = recv_scope_->FindVar(t);
      LoDTensor *tensor = var->GetMutable<LoDTensor>();
      VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
              << platform::is_gpu_place(tensor->place());
      if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
        LoDTensor *temp_tensor =
            xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
        framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
#endif
      }
    }
  }
}
// Drains the send queues once: for every send context, pops up to
// max_merge_var_num_ queued gradient copies per variable, merges them into
// send_scope_, and pushes the result (global step / sparse / dense) to the
// pservers. One thread-pool task is spawned per context and all tasks are
// awaited before returning.
void AsyncCommunicator::SendByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // Capturing ctx by reference is safe here: the map outlives the tasks
    // and every task is waited on below.
    auto send_recv_task = [this, &ctx] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      auto &check_queue = send_varname_to_queue_[varnames[0]];
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      int merged_var_num = 0;
      int wait_times = 0;
      // Collect up to max_merge_var_num_ rounds of gradients; give up after
      // send_wait_times_ consecutive empty polls (10ms apart).
      while (merged_var_num < max_merge_var_num_) {
        if (check_queue->Size() == 0) {
          VLOG(4) << "wait_times -> " << wait_times;
          if (wait_times >= send_wait_times_) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
          wait_times++;
          continue;
        } else {
          wait_times = 0;
          for (size_t i = 0; i < var_nums; i++) {
            auto &var_name = varnames[i];
            auto &var_queue = send_varname_to_queue_[var_name];
            vars[i].push_back(var_queue->Pop());
          }
          merged_var_num++;
        }
      }
      if (merged_var_num == 0) return;
      // Merge the collected copies of each variable into send_scope_.
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        if (var_name == STEP_COUNTER) {
          MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
        } else {
          MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
        }
      }
      if (ctx.is_tensor_table) {
        SendGlobalStep(ctx, merged_var_num, send_scope_.get());
      } else if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
        // Without an independent recv thread, refresh the dense parameters
        // right after pushing so the next step sees the update.
        if (!independent_recv_ &&
            recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) {
          auto recv_varnames = recv_varname_to_ctx_.at(table_id);
          RpcRecvDense(recv_varnames, table_id, recv_scope_);
        }
      }
      if (independent_recv_) {
        grad_num_.fetch_add(1, std::memory_order_relaxed);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Counts an externally performed dense push toward the independent receive
// thread's pull trigger.
void AsyncCommunicator::PushDensePostProcessing() {
  if (independent_recv_) {
    grad_num_.fetch_add(1, std::memory_order_relaxed);
  }
}
// Send-side worker loop: parks until the first Send() clears waiting_,
// then keeps merging/pushing gradients and mirroring profiler state until
// the communicator is stopped.
void AsyncCommunicator::MainThread() {
  VLOG(3) << "AsyncCommunicator MainThread start and wait";
  for (; waiting_ && running_;) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  for (; running_;) {
    SendByCommunicator();
    RpcProfilerControl();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Synchronously pulls sparse rows for every id in `inputs` from table
// `table_id` and writes them into the matching `outputs` tensors (fea_dim
// floats per id, laid out consecutively). Ids equal to `padding_id` are not
// pulled; their slots are filled with zeros. On RPC failure the error is
// logged and the thread sleeps before returning.
void AsyncCommunicator::PullSparseToTensorSync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    bool is_training,
    std::vector<const LoDTensor *> *inputs,
    std::vector<LoDTensor *> *outputs) {
  std::vector<uint64_t> fea_keys;
  std::vector<float *> pull_result_ptr;
  fea_keys.reserve(MAX_FEASIGN_NUM / 100);
  pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<float> init_value(fea_dim, 0);
  framework::LoDTensor *output = nullptr;
  float *output_data = nullptr;
  // Starts at SIZE_MAX so the first pre-increment selects output index 0.
  size_t output_index = -1;
  size_t output_len = 0;
  for (size_t index = 0; index < inputs->size(); ++index) {
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
      // Advance to the next output tensor once the current one is full.
      if (!output || output_len == size_t(output->numel())) {
        ++output_index;
        CHECK(output_index < outputs->size());  // NOLINT
        output = outputs->at(output_index);
        output->set_lod(tensor->lod());
        output_data = output->mutable_data<float>(place);
        output_len = 0;
        CHECK(output->numel() % fea_dim == 0);  // NOLINT
        CHECK(output_data != nullptr);          // NOLINT
      }
      uint64_t real_id = static_cast<uint64_t>(ids[i]);
      if (real_id == padding_id) {
        // Padding ids get a zero vector instead of a server lookup.
        memcpy(output_data + output_len,
               init_value.data(),
               sizeof(float) * fea_dim);
        continue;
      }
      fea_keys.push_back(real_id);
      pull_result_ptr.push_back(output_data + output_len);
    }
  }
  auto status = _worker_ptr->PullSparse(pull_result_ptr.data(),
                                        table_id,
                                        fea_keys.data(),
                                        fea_keys.size(),
                                        is_training);
  status.wait();
  auto ret = status.get();
  if (ret != 0) {
    LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
    sleep(sleep_seconds_before_fail_exit_);
  }
}
// Builds sparse push requests from the gradient tensors in `outputs` keyed
// by the ids in `inputs`, and pushes them to table `table_id`. Each pushed
// value is [slot, grad...] (fea_dim + 1 floats); show/clk columns are
// currently commented out. Ids equal to `padding_id` are skipped. The push
// is issued without waiting on the returned status.
void AsyncCommunicator::PushSparseFromTensorAsync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    std::vector<const framework::LoDTensor *> *inputs,
    const framework::LoDTensor *shows,
    const framework::LoDTensor *clks,
    std::vector<framework::LoDTensor *> *outputs) {
  // Determine the batch size; LoD tensors use lod()[0].size()-1, dense use
  // the first dimension. Mismatched inputs disable gradient scaling below.
  int batch_size = -1;
  bool batch_size_consist = true;
  for (auto *input : *inputs) {
    int cur_batch_size =
        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
    if (batch_size == -1) {
      batch_size = cur_batch_size;
    } else if (batch_size != cur_batch_size) {
      // CHECK(batch_size == cur_batch_size);  // NOLINT
      batch_size_consist = false;
      break;
    }
  }
  CHECK(batch_size > 0);  // NOLINT
  int show_size =
      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
  CHECK(show_size == batch_size || show_size == 1);
  int clk_size =
      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
  CHECK(clk_size == batch_size || clk_size == 1);
  CHECK(outputs->size() == inputs->size());
  std::vector<uint64_t> push_keys;
  push_keys.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<std::vector<float>> push_values;
  push_values.reserve(MAX_FEASIGN_NUM / 100);
  size_t output_len = 0;
  size_t input_idx = 0;
  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim << " batch_size: " << batch_size
          << " batch_size_consist: " << batch_size_consist;
  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
  // const long int* show_tensor = shows->data<int64_t>();
  // const long int* clk_tensor = clks->data<int64_t>();
  for (size_t index = 0; index < inputs->size(); ++index) {
    framework::LoDTensor *g_tensor = outputs->at(index);
    float *g = g_tensor->data<float>();
    if (batch_size_consist) {  // TODO(zhaocaibei123): add config
                               // scale_sparse_gradient_with_batch_size_
      // Scale all but the first two columns by batch size (cvm_grad layout).
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
      g_mat.rightCols(fea_dim - 2) *=
          batch_size;  // hard code here, because of cvm_grad op
    }
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    output_len = 0;
    if (tensor->lod().size() > 0) {
      // LoD input: walk each sequence's span of ids.
      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
        for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
             ++j, output_len += fea_dim) {
          uint64_t real_id = static_cast<uint64_t>(ids[j]);
          if (real_id == padding_id) {
            continue;
          }
          push_keys.emplace_back(real_id);
          push_values.emplace_back(fea_dim + 1);
          // slot show clk grad... consistent with CtrCommonPushValue defined in
          // ctr_accessor.h
          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
          // push_values.back()[1] =
          //    (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
          // push_values.back()[2] =
          //    (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
          float *data = push_values.back().data() + 1;  // hard code here
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
          ++input_idx;
        }
      }
    } else {
      // Dense input: one id per element.
      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
        uint64_t real_id = static_cast<uint64_t>(ids[i]);
        if (real_id == padding_id) {
          continue;
        }
        push_keys.emplace_back(real_id);
        push_values.emplace_back(fea_dim + 1);
        // slot show clk grad... consistent with CtrCommonPushValue defined in
        // ctr_accessor.h
        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
        // push_values.back()[1] =
        //    (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
        // push_values.back()[2] =
        //    (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
        float *data = push_values.back().data() + 1;
        memcpy(data, g + output_len, sizeof(float) * fea_dim);
        ++input_idx;
      }
    }
    CHECK(static_cast<int64_t>(output_len) == g_tensor->numel());
  }
  std::vector<float *> push_g_vec(input_idx, nullptr);
  for (auto i = 0u; i < push_keys.size(); ++i) {
    push_g_vec[i] = push_values.at(i).data();
  }
  PADDLE_ENFORCE_EQ(
      this->Check(table_id),
      true,
      platform::errors::InvalidArgument(
          "can not find table: %s, please check your config", table_id));
  // NOTE(review): the returned status is not waited on — the push appears to
  // be fire-and-forget. Confirm this is intentional and that the future's
  // lifetime does not require a wait().
  auto status = _worker_ptr->PushSparse(table_id,
                                        push_keys.data(),
                                        (const float **)push_g_vec.data(),
                                        push_keys.size());
}
// Half-async worker loop: parks until the first Send() clears waiting_,
// then repeats the send / barrier-send / recv / barrier-recv / wake-up
// cycle until the communicator is stopped.
void HalfAsyncCommunicator::MainThread() {
  VLOG(3) << "HalfAsyncCommunicator MainThread start and wait";
  for (; waiting_ && running_;) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  for (; running_;) {
    SendByCommunicator();
    BarrierSend();
    RecvByCommunicator();
    BarrierRecv();
    BarrierWeakUp();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
send_scope_.reset(new Scope());
xpu_temp_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
auto &ctx = iter.second;
auto &varnames = ctx.origin_varnames;
for (auto &var_name : varnames) {
send_varname_to_queue_[var_name] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
send_queue_size_);
}
}
send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
// Stops the worker loops and joins the send/recv threads before the
// communicator's members are destroyed.
AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
  if (recv_thread_) recv_thread_->join();
}
void AsyncCommunicator::Start() {
VLOG(1) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
VLOG(1) << "start send thread and recv thread";
waiting_ = true;
running_ = true;
// flushing_ = false;
BarrierTriggerReset(max_merge_var_num_);
// start send and recv thread
main_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
if (independent_recv_) {
recv_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
}
}
}
// Stops the communicator: clears running_ so the worker loops exit, then
// joins and releases the recv thread followed by the main (send) thread.
void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop begin";
  // Signal both loops to exit before joining either thread.
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    // _worker_ptr->FinalizeWorker();
    VLOG(1) << "client finalize_worker done";
    if (recv_thread_) {
      VLOG(1) << "stop recv thread";
      recv_thread_->join();
      recv_thread_.reset(nullptr);
    }
    if (main_thread_) {
      VLOG(1) << "stop main thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}
// Returns whether `var_tables` (which must contain exactly one name) refers
// to a registered send context. As a side effect, a STEP_COUNTER request
// enqueues a one-element int64 tensor holding 1 into its send queue.
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
  PADDLE_ENFORCE_EQ(
      var_tables.size(),
      1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
  const auto &table_name = var_tables[0];
  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
    return false;
  }
  if (table_name == STEP_COUNTER) {
    VLOG(3) << "send step_counter into queue";
    auto counter_var = std::make_shared<Variable>();
    auto *counter_tensor = counter_var->GetMutable<framework::LoDTensor>();
    counter_tensor->Resize(phi::make_ddim({1}));
    auto *counter_data =
        counter_tensor->mutable_data<int64_t>(platform::CPUPlace());
    counter_data[0] = 1;
    send_varname_to_queue_[table_name]->Push(counter_var);
  }
  return true;
}
// Returns true when some registered send context targets `table_id`.
bool AsyncCommunicator::Check(const int table_id) {
  for (const auto &iter : send_varname_to_ctx_) {
    if (iter.second.table_id == table_id) {
      return true;
    }
  }
  return false;
}
// Copies each named gradient variable out of `scope` and enqueues the copy
// into the matching send queue for later merging and pushing.
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const framework::Scope &scope) {
  waiting_ = false;
  for (const auto &name : var_names) {
    auto *src_var = scope.FindVar(name);
    auto grad_copy = std::make_shared<Variable>();
    framework::CopyVariable(*src_var, grad_copy.get());
    send_varname_to_queue_[name]->Push(grad_copy);
  }
}
// Drains every pending entry from all send queues.
void HalfAsyncCommunicator::Clean() {
  for (auto &iter : send_varname_to_queue_) {
    auto &var_name = iter.first;
    auto &queue = iter.second;
    while (queue->Size() > 0) {
      queue->Pop();
    }
    VLOG(3) << "clean var: " << var_name << " done";
  }
}
// Decrements the number of barrier triggers still expected.
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  --barrier_trigger_;
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}
// Resets the barrier trigger counter to `initial_val`.
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_ = initial_val;
  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}
// Called by trainer threads to rendezvous with the communicator: bumps the
// barrier counter and blocks until BarrierWeakUp() resets it to zero. When
// the communicator is not running the barrier is released immediately.
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;
  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }
  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    // Woken by BarrierWeakUp(), which zeroes the counter and notifies.
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}
// Blocks until the number of Barrier() calls reaches the (non-zero) trigger
// threshold, polling every 10ms, then returns the accumulated count.
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    const int counter = barrier_counter_.load();
    const int trigger = barrier_trigger_.load();
    if (trigger != 0 && counter >= trigger) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  return barrier_counter_.load();
}
// Half-async send: waits until every expected trainer has reached the
// barrier (BatchesCounter), then for each send context pops exactly
// `batches` queued gradients per variable, merges them into send_scope_,
// and pushes sparse or dense results. All thread-pool tasks are awaited.
void HalfAsyncCommunicator::SendByCommunicator() {
  int batches = BatchesCounter();
  VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches;
  if (batches <= 0) return;
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // ctx is captured by reference; safe because every task is waited on
    // below, before the loop's map iteration ends.
    auto send_recv_task = [this, &ctx, batches] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        auto &var_queue = send_varname_to_queue_[var_name];
        for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop());
        MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
      }
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Releases all trainers blocked in Barrier() by zeroing the counter and
// notifying the condition variable.
void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_ = 0;
  barrier_cond_.notify_all();
}
// Synchronizes all trainers on barrier table 0 after sending, unless the
// communicator has been stopped.
void SyncCommunicator::BarrierSend() {
  if (!running_) {
    return;
  }
  BarrierWithTable(0);
  VLOG(4) << "BarrierSend with SyncCommunicator";
}
// Synchronizes all trainers on barrier table 1 after receiving, unless the
// communicator has been stopped.
void SyncCommunicator::BarrierRecv() {
  if (!running_) {
    return;
  }
  BarrierWithTable(1);
  VLOG(4) << "BarrierRecv with SyncCommunicator";
}
// Geo-mode send: records which sparse ids of `var_names[0]` were touched in
// this batch. The ids come from the variable's SelectedRows rows, are
// deduplicated per splited variable (shard chosen by id % splited_var_nums),
// and are enqueued on each shard's channel for the background send tasks.
void GeoCommunicator::Send(
    const std::vector<std::string> &var_names,
    const framework::Scope &scope) {  // last op in program
  platform::RecordEvent record_event(
      "GeoCommunicator->Send", platform::TracerEventType::Communication, 1);
  waiting_ = false;
  auto before_send = GetCurrentUS();
  auto table_name = var_names[0];
  size_t splited_var_nums =
      send_varname_to_ctx_[table_name].splited_varnames.size();
  // One id set per splited variable (per endpoint shard).
  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
  for (size_t j = 0; j < splited_var_nums; j++) {
    ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
        send_varname_to_ctx_[table_name].splited_varnames[j],
        std::unordered_set<int64_t>()));
  }
  auto *var = scope.FindVar(table_name);
  PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(),
                    true,
                    platform::errors::InvalidArgument(
                        "Only need to send Sparse Grad in Geo mode."));
  auto &rows = var->Get<phi::SelectedRows>().rows();
  // insert ids which has not been record
  // VLOG(0) << "fl-ps > table_name: " << table_name << " splited_var_nums: " <<
  // splited_var_nums << " rows size: " << rows.size();
  for (size_t j = 0; j < rows.size(); j++) {  // batch_size == rows.size()
    auto ep_idx = rows[j] % splited_var_nums;
    ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
        .insert(rows[j]);
    // VLOG(0) << " id: " << rows[j] << " ";
  }
  // Publish each shard's deduplicated id batch onto its channel.
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Put(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(2) << "run send op finish. use time " << (after_send - before_send);
}
// Geo-mode initialization: stores contexts and scopes, counts parallel
// tasks (one per dense context, one per sparse splited variable), and
// creates a bounded channel of sparse-id batches per splited variable.
// Sparse contexts with non-empty remote_sparse_ids are erased from
// send_varname_to_ctx_ here.
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RecvCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  // NOTE: the parameters are const references, so std::move degrades to a
  // plain copy here.
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(
      recv_varname_to_ctx);  // dense_map - key: table_id, value: params
  recv_scope_ = std::move(recv_scope);
  for (auto it = send_varname_to_ctx_.begin();
       it != send_varname_to_ctx_.end();) {
    auto &ctx = it->second;
    if (!ctx.is_sparse) {
      // Dense context: a single background task handles it.
      parallel_task_nums_ += 1;
      it++;
      continue;
    }
    auto &varnames = ctx.origin_varnames;
    if (varnames.empty()) {
      VLOG(0) << "ERROR! sparse variables num can not be zero";
    }
    auto &varname = varnames[0];  // embedding_0.w_0@GRAD
    auto &ids = ctx.remote_sparse_ids;
    // NOTE(review): contexts with remote sparse ids are erased here —
    // presumably handled by a separate path; confirm before changing.
    if (!ids.empty()) {
      it = send_varname_to_ctx_.erase(it);
      continue;
    } else {
      it++;
    }
    for (auto &splited_var : ctx.splited_varnames) {  // embedding_0.w_0.block0
      parallel_task_nums_ += 1;
      sparse_id_queues_.insert(
          std::pair<std::string,
                    paddle::framework::Channel<
                        std::shared_ptr<std::vector<int64_t>>>>(
              splited_var,
              paddle::framework::MakeChannel<
                  std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
    }
  }
  send_threadpool_ = std::make_unique<ThreadPool>(thread_pool_size_);
  delta_scope_ = std::make_shared<Scope>();
  old_scope_ = std::make_shared<Scope>();
  pserver_scope_ = std::make_shared<Scope>();
  return;
}
// Initialize parameters for Geo mode: dense tables are pushed/pulled in
// parallel on the send thread pool, then sparse tables are initialized
// sequentially.
//
// Fix: the previous lambda captured `table_id`/`varnames` — references that
// are local to one loop iteration — by reference; using them from a task
// after the iteration ended is undefined behavior. The id is now captured
// by value and the varname list by pointer (the map node outlives the
// tasks, which are all joined before the function returns).
void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());

  for (auto &iter : recv_varname_to_ctx_) {
    auto table_id = iter.first;
    auto *varnames = &iter.second;  // stable: map nodes outlive the tasks
    auto recv_task = [this, table_id, varnames] {
      InitDense(*varnames, table_id);
    };
    if (send_threadpool_ == nullptr) {
      VLOG(0) << "ERROR! send_threadpool_ is nullptr";
    }
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }

  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) {
      continue;
    }
    auto &varname = ctx.origin_varnames[0];
    auto &table_id = ctx.table_id;
    // Strip the "@GRAD" suffix to recover the parameter name.
    auto param = varname.substr(0, varname.size() - 5);
    VLOG(0) << "InitSparse: " << param << ", " << table_id;
    InitSparse(param, table_id);
  }
}
// Initialize one dense table: trainer 0 pushes its initial parameter values
// to the pserver while every other trainer pulls them; afterwards each
// trainer snapshots the parameters into old_scope_ and pserver_scope_ as
// the baseline for future delta computation.
void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
                                int table_id) {
  VLOG(1) << "init dense table " << table_id << " begin";
  const bool is_first_trainer = (trainer_id_ == 0);
  if (is_first_trainer) {
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  } else {
    BarrierWithTable(1);
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param from table " << table_id
            << " from 0' trainer done";
  }

  // Snapshot every variable into the bookkeeping scopes.
  for (auto &name : varnames) {
    auto *global_var = recv_scope_->FindVar(name);
    global_var->GetMutable<framework::LoDTensor>();
    // copy to old_scope
    auto *old_var = old_scope_->Var(name);
    old_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, old_var);  // src, dst
    // init pserver_scope_
    auto *pserver_var = pserver_scope_->Var(name);
    pserver_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, pserver_var);
  }
  VLOG(1) << "init dense table " << table_id << " done";
}
// Push this table's dense delta to the pserver (Geo mode).
// For every origin grad var, with param = GradToParam(grad):
//   delta = (latest - old) / trainers;  old += delta
// Deltas are staged in delta_scope_ and sent in a single RPC at the end.
void GeoCommunicator::SendDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->SendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = send_ctx.origin_varnames;
  auto &table_id = send_ctx.table_id;
  for (auto &varname : var_names) {
    auto param_name = GradToParam(varname);
    // latest: currently trained params; timestamp: snapshot at last send.
    auto *var_latest = recv_scope_->FindVar(param_name);
    auto *var_timestamp = old_scope_->FindVar(param_name);
    PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    auto &t_latest = var_latest->Get<framework::LoDTensor>();
    auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
    phi::CPUContext cpu_ctx;
    // Stage the delta tensor (keyed by the grad name) in delta_scope_.
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = latest - timestamp
    blas.VSUB(t_latest.numel(),
              t_latest.data<float>(),
              t_timestamp->data<float>(),
              t_delta->data<float>());
    // Average the delta across trainers before sending.
    float coefficient = 1.0 / static_cast<float>(trainers_);
    blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
    // timestamp += delta: advance the snapshot by exactly what is sent.
    blas.VADD(t_latest.numel(),
              t_timestamp->data<float>(),
              t_delta->data<float>(),
              t_timestamp->data<float>());
  }
  RpcSendDense(send_ctx, *delta_scope_);
  VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
  return;
}
// Pull the freshest dense parameters of this table from the pserver and
// fold them into the local copy:
//   delta = pserver - old;  latest += delta;  old = pserver
// The pull lands in pserver_scope_; deltas are staged in delta_scope_.
void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = send_ctx.table_id;
  auto &varnames = recv_varname_to_ctx_.at(table_id);
  // 1. recv from pserver
  RpcRecvDense(varnames, table_id, pserver_scope_.get());

  // 2.1 pserver - old => delta; 2.2 latest + delta => latest; 2.3 old =>
  // pserver
  phi::CPUContext cpu_ctx;
  for (auto &varname : varnames) {
    auto *var_latest = recv_scope_->FindVar(varname);
    auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
    auto *var_old = old_scope_->FindVar(varname);
    auto t_old = var_old->GetMutable<framework::LoDTensor>();
    auto *var_pserver = pserver_scope_->FindVar(varname);
    // Bind by const reference: the previous `auto t_pserver = ...->Get<...>()`
    // copied the tensor object returned by reference for no benefit.
    const auto &t_pserver = var_pserver->Get<framework::LoDTensor>();
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = pserver - old
    blas.VSUB(t_latest->numel(),
              t_pserver.data<float>(),
              t_old->data<float>(),
              t_delta->data<float>());
    // latest += delta
    blas.VADD(t_latest->numel(),
              t_latest->data<float>(),
              t_delta->data<float>(),
              t_latest->data<float>());
    // old = pserver
    blas.VCOPY(
        t_latest->numel(), t_pserver.data<float>(), t_old->data<float>());
  }
  VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
}
// Initialize one sparse table: trainer 0 seeds the pserver with its local
// parameter while every other trainer pulls it; afterwards the parameter is
// snapshotted into old_scope_ as the baseline for delta computation.
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
  if (trainer_id_ != 0) {
    BarrierWithTable(1);
    RpcRecvSparse(var_name, table_id, recv_scope_);
    VLOG(1) << "pull sparse param to table " << table_id
            << " from 0' trainer done";
  } else {
    RpcSendSparseParam(var_name, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push sparse param to table " << table_id
            << " from 0' trainer done";
  }
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done.";
  auto *global_var = recv_scope_->FindVar(var_name);
  auto *snapshot_var = old_scope_->Var(var_name);
  framework::CopyVariable(*global_var, snapshot_var);  // src, dst
}
// Drain up to max_merge_var_num_ id batches from `send_varname`'s queue and
// deduplicate them into one vector. Gives up after send_wait_times_
// consecutive empty polls (spaced 10 ms apart).
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  platform::RecordEvent record_event("GeoCommunicator->MergeSparseIds",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num <
         static_cast<size_t>(max_merge_var_num_)) {  // -> geo_step: 100
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      // A batch is available: pop it and fold its ids into the set.
      wait_times = 0;
      std::shared_ptr<std::vector<int64_t>> popped = nullptr;
      sparse_id_queues_.at(send_varname)->Get(popped);
      for (const auto id : *popped) {
        sparse_ids.insert(id);
      }
      ++merge_num;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      // Queue is empty: back off briefly, or stop after enough retries.
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      ++wait_times;
      continue;
    }
  }
  std::vector<int64_t> merged(sparse_ids.begin(), sparse_ids.end());
  return merged;
}
// Push the Geo delta of sparse shard `varname` (rows listed in sparse_ids)
// to pserver `ep_idx`. For each row:
//   delta = (latest - old) / trainers;  old += delta
// The per-row delta pointers are handed to PushSparseRawGradientPartial,
// which completes asynchronously through the self-deleting brpc closure.
void GeoCommunicator::SendSparse(const std::string &varname,
                                 std::vector<int64_t> &sparse_ids,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->SendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  if (sparse_ids.size() == 0) {
    return;
  }
  std::string param_name = SplitedGradToParam(varname);
  VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
          << ", ep_idx: " << ep_idx;
  auto *var_latest = recv_scope_->FindVar(param_name);
  auto *var_old = old_scope_->FindVar(param_name);
  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  PADDLE_ENFORCE_EQ(var_old->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));

  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest.dims()[1];  // embedding width
  phi::CPUContext cpu_ctx;

  // Stage the deltas as a SelectedRows (one row per touched id).
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<phi::SelectedRows>();
  auto *var_t_value = t_delta->mutable_value();
  var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
  auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());

  t_delta->set_rows(sparse_ids);
  t_delta->set_height(t_latest.dims()[0]);

  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);
  std::vector<float *> push_g_vec;
  push_g_vec.reserve(sparse_ids.size());  // avoid reallocations per row
  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    // delta = latest - old, averaged across trainers
    blas.VSUB(dims1,
              t_latest.data<float>() + sparse_ids[j] * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1);
    blas.SCAL(dims1, coefficient, t_value + j * dims1);
    // old += delta: advance the snapshot by exactly what is being sent
    blas.VADD(dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1);
    push_g_vec.push_back(t_value + j * dims1);

    VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
            << sparse_ids[j] << " value[0] " << push_g_vec[j][0]
            << " value[-1] " << push_g_vec[j][dims1 - 1];
  }

  ++_async_call_num;
  // The closure records the RPC status and balances _async_call_num; brpc
  // closures of this kind delete themselves on completion.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) {
    int ret = 0;
    auto *closure = static_cast<DownpourBrpcClosure *>(done);
    if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) {
      ret = -1;
    }
    closure->set_promise_value(ret);
    --_async_call_num;
  });
  // Named casts instead of the previous C-style casts: the id buffer is
  // reinterpreted int64_t* -> uint64_t*, and the value array only gains
  // const at the second level.
  auto status = _worker_ptr->PushSparseRawGradientPartial(
      table_id,
      reinterpret_cast<const uint64_t *>(sparse_ids.data()),
      const_cast<const float **>(push_g_vec.data()),
      sparse_ids.size(),
      closure,
      ep_idx);
  status.wait();

  VLOG(1) << "Finish Send Sparse " << varname
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id;
}
// Pull the Geo updates of sparse shard `varname` from pserver `ep_idx` and
// fold them into the local parameter:
//   delta = pserver - old;  latest += delta;  old = pserver
void GeoCommunicator::RecvSparse(const std::string &varname,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  // 1. recv from pserver
  // values is row-major: keys.size() rows of dims1 floats each.
  std::vector<uint64_t> keys;
  std::vector<float> values;
  auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
  status.wait();

  std::string param = SplitedGradToParam(varname);
  VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", "
          << table_id << "; ids Size: " << keys.size()
          << "; values size: " << values.size();

  auto *var_latest = recv_scope_->FindVar(param);
  auto *var_old = old_scope_->FindVar(param);

  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();

  auto dims1 = t_latest->dims()[1];  // embedding width
  auto numel = keys.size() * dims1;

  std::vector<float> v_delta;
  v_delta.resize(numel);

  phi::CPUContext cpu_ctx;
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);

  // 2. apply each received row at offset key * dims1.
  for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
    VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
            << "value[0] " << values[j * dims1] << " value[-1] "
            << values[j * dims1 + dims1 - 1];
    float *latest_data = t_latest->data<float>() + keys[j] * dims1;
    float *old_data = t_old->data<float>() + keys[j] * dims1;
    // pserver - old => delta
    blas.VSUB(
        dims1, values.data() + j * dims1, old_data, v_delta.data() + j * dims1);
    // latest + delta => latest
    blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
    // pserver => old
    blas.VCOPY(dims1, values.data() + j * dims1, old_data);
  }
  VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id;
}
// Background driver of Geo mode. Spins until the first Send() clears
// waiting_, then loops while running_: each round fans out one
// merge+send+recv task per sparse shard (per pserver endpoint) and one
// send+recv task per dense context onto the thread pool, and joins them all
// before starting the next round.
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";

  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }

  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(parallel_task_nums_);

    for (auto &iter : send_varname_to_ctx_) {
      auto &ctx = iter.second;
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;

      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        int pserver_num = static_cast<int>(ctx.epmap.size());
        // One task per pserver shard of this sparse table.
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
          // NOTE: table_id/ep_idx are captured by value; ctx (a map element)
          // stays valid until all tasks are joined below.
          auto send_recv_task = [this, table_id, ep_idx, &ctx] {
            auto splited_varname =
                ctx.splited_varnames[ep_idx];  // embedding_0.w_0.block0
                                               // embedding_1.w_0.block0
            // Merge the queued ids, push their deltas, then pull updates.
            auto sparse_ids = MergeSparseIds(splited_varname);
            SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
            RecvSparse(splited_varname, table_id, ep_idx);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        // Dense context: one combined send+recv task for the whole table.
        auto send_recv_task = [this, &ctx] {
          SendDense(ctx);
          RecvDense(ctx);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    // Join the whole round before scheduling the next one.
    for (auto &task : tasks) {
      task.wait();
    }
  }
}
// Bind the FL communicator to the worker client created by
// FleetWrapper::InitWorker, and point the coordinator client at the
// pserver environment/host list.
//
// Fix: the pserver count was previously stored in an int16_t, which would
// silently truncate host_sign_list.size(); use int.
void FLCommunicator::InitBrpcClient(
    const std::string &dist_desc,
    const std::vector<std::string> &host_sign_list) {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();
  if (_worker_ptr == nullptr) {
    VLOG(0) << "fl-ps > FLCommunicator::InitBrpcClient get _worker_ptr";
    _worker_ptr =
        fleet->worker_ptr_;  // FleetWrapper::InitWorker must be executed
                             // before, but no need for Coordinator
  }
  if (coordinator_client_ptr_ == nullptr) {
    coordinator_client_ptr_.reset(new CoordinatorClient);
  }
  int servers = static_cast<int>(host_sign_list.size());
  coordinator_client_ptr_->_env = &ps_env_;
  coordinator_client_ptr_->_env->SetPsServers(&host_sign_list, servers);
}
// Initialize the coordinator client with all trainer endpoints. Logs an
// error and does nothing if InitBrpcClient has not created the client yet.
void FLCommunicator::StartCoordinatorClient(
    const std::vector<std::string> &trainer_endpoints) {
  if (coordinator_client_ptr_ != nullptr) {
    coordinator_client_ptr_->Initialize(trainer_endpoints);
    VLOG(0) << "fl-ps > StartCoordinatorClient finish!";
    return;
  }
  LOG(ERROR) << "coordinator_client_ptr_ is null";
}
// Start the coordinator-side client service.
//
// Fix: the previous version logged a null coordinator_client_ptr_ and then
// dereferenced it anyway (null-pointer dereference); now it returns early.
void FLCommunicator::StartCoordinatorServer() {
  if (coordinator_client_ptr_ == nullptr) {
    LOG(ERROR) << "coordinator_client_ptr_ is null";
    return;
  }
  int ret = coordinator_client_ptr_->StartClientService();
  if (ret != 0) {
    LOG(ERROR) << "coordinator_client_ptr_ StartClientService failed";
  }
  VLOG(0) << "fl-ps > StartCoordinatorServer finished!";
}
// Delegate to the coordinator client, which tracks the connected FL
// clients (map of client id -> info string).
std::unordered_map<uint32_t, std::string> FLCommunicator::QueryFLClientsInfo() {
  return coordinator_client_ptr_->QueryFLClientsInfo();
}
// Hand the per-client FL strategies over to the coordinator client, which
// stores them until RpcSendFLStrategy broadcasts them.
void FLCommunicator::SaveFLStrategy(
    const std::unordered_map<uint32_t, std::string> &fl_strategy) {
  coordinator_client_ptr_->SaveFLStrategy(fl_strategy);
}
// Background loop: keep broadcasting FL strategies while the communicator
// is running. RpcSendFLStrategy waits for a strategy to become ready, so
// this loop does not busy-spin.
void FLCommunicator::SendThreadAsync() {
  while (is_running_) {
    RpcSendFLStrategy();
  }
}
// One broadcast round: wait until a strategy is ready, push it to every
// known FL client, then reset the ready flag for the next round.
void FLCommunicator::RpcSendFLStrategy() {
  const std::set<uint32_t> client_ids =
      coordinator_client_ptr_->GetFLClientIds();
  coordinator_client_ptr_->WaitForFLStrategyReady();
  for (const auto &id : client_ids) {
    coordinator_client_ptr_->SendFLStrategy(id);
  }
  coordinator_client_ptr_->ResetFLStrategyFlag();
  VLOG(0) << "fl-ps > RpcSendFLStrategy finished!";
}
// Bring up the whole coordinator: register our endpoint, initialize the
// client with every trainer endpoint, start the rpc service, and launch
// the background strategy-broadcast thread.
void FLCommunicator::StartCoordinator(
    const std::string &self_endpoint,
    const std::vector<std::string> &trainer_endpoints) {
  coordinator_client_ptr_->SetEndpoint(self_endpoint);
  StartCoordinatorClient(trainer_endpoints);
  StartCoordinatorServer();
  // Runs until is_running_ is cleared elsewhere.
  async_send_thread_.reset(
      new std::thread(&FLCommunicator::SendThreadAsync, this));
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <ThreadPool.h>
#include <stdint.h>

#include <atomic>
#include <condition_variable>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
// Forward declarations, kept minimal so this header does not pull in the
// full client definitions.
namespace paddle {
namespace distributed {
class PSClient;
struct CommContext;
}  // namespace distributed
}  // namespace paddle
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
// A bounded, thread-safe FIFO used to hand gradient snapshots from compute
// threads to communication threads.
//
// Blocking behavior: Push blocks while the queue is full; Pop blocks while
// it is empty. Two condition variables track the two kinds of waiters
// (empty_cond_ for readers waiting for data, full_cond_ for writers
// waiting for space), and Notify wakes at most one waiter of each kind
// whose wait condition no longer holds.
template <typename T>
class BlockingQueue {
 public:
  // capacity bounds the number of buffered elements; must be positive.
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_,
                      0,
                      platform::errors::InvalidArgument(
                          "The capacity must be greater than 0."));
  }

  // Copy-push: blocks until there is room. Always returns true.
  bool Push(const T &elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.push_back(elem);
    Notify();
    return true;
  }

  // Block (with `lock` held on mutex_) until the queue has free space.
  // While full, nudge a blocked reader — the queue certainly has data.
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      if (empty_waiters_ != 0) {
        empty_cond_.notify_one();
      }
      full_waiters_++;
      full_cond_.wait(lock);
      full_waiters_--;
    }
    return true;
  }

  // Block (with `lock` held on mutex_) until the queue has an element.
  // While empty, nudge a blocked writer — the queue certainly has space.
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      if (full_waiters_ != 0) {
        full_cond_.notify_one();
      }
      empty_waiters_++;
      empty_cond_.wait(lock);
      empty_waiters_--;
    }
    return true;
  }

  // The *Unlocked helpers assume mutex_ is already held by the caller.
  bool EmptyUnlocked() { return queue_.empty(); }

  bool FullUnlocked() { return queue_.size() >= capacity_; }

  // Wake one waiter of each kind whose wait condition no longer holds;
  // called after every successful push/pop, with mutex_ held.
  void Notify() {
    if (empty_waiters_ != 0 && (!EmptyUnlocked())) {
      empty_cond_.notify_one();
    }
    if (full_waiters_ != 0 && (!FullUnlocked())) {
      full_cond_.notify_one();
    }
  }

  // Move-push: blocks until there is room. Always returns true.
  bool Push(T &&elem) {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForWrite(lock);
    queue_.emplace_back(std::move(elem));
    Notify();
    return true;
  }

  // Pop the front element, blocking until one is available.
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    WaitForRead(lock);
    T rc(std::move(queue_.front()));
    queue_.pop_front();
    Notify();
    return rc;
  }

  // Maximum number of buffered elements (fixed at construction).
  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  // Current number of buffered elements.
  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  // Number of threads currently blocked in WaitForRead / WaitForWrite;
  // guarded by mutex_.
  int empty_waiters_ = 0;
  int full_waiters_ = 0;
  std::condition_variable empty_cond_;  // signaled when data may be available
  std::condition_variable full_cond_;   // signaled when space may be available
  const size_t capacity_;
  std::deque<T> queue_;
  mutable std::mutex mutex_;
};
// Flat Eigen vector view over a tensor's data; used by MergeVars below.
template <typename T,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Merge several snapshots of the same variable into scope[var_name].
// Dense LoDTensors are summed element-wise; SelectedRows are merged by
// row. When merge_add is false, the result is averaged over vars.size()
// instead of summed.
//
// Preconditions: vars is non-empty and all entries share var0's type (and,
// for LoDTensors, its dims). Throws InvalidArgument otherwise.
//
// Fix: the SelectedRows branch used to declare a local functor named
// `merge_add`, shadowing the bool parameter of the same name; the functor
// is now named merge_add_functor.
template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
                      Scope *scope,
                      bool merge_add = true) {
  PADDLE_ENFORCE_NE(
      vars.empty(),
      true,
      platform::errors::InvalidArgument("vector vars are empty."));
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);

  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // init output tensor
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // check the input dims
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
          var_t.dims(),
          dims,
          platform::errors::InvalidArgument("vars should have the same dims."));
    }

    // set output tensor to 0.
    phi::CPUContext cpu_ctx;
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    // sum all vars to out
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    if (!merge_add) {
      // Average instead of sum.
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
  } else if (var0->IsType<phi::SelectedRows>()) {
    auto &slr0 = var0->Get<phi::SelectedRows>();
    auto *out_slr = out_var->GetMutable<phi::SelectedRows>();
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
    std::vector<const phi::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<phi::SelectedRows>());
    }
    phi::CPUContext dev_ctx;
    if (merge_add) {
      paddle::operators::math::scatter::MergeAdd<phi::CPUContext, T>
          merge_add_functor;
      merge_add_functor(dev_ctx, inputs, out_slr);
    } else {
      paddle::operators::math::scatter::MergeAverage<phi::CPUContext, T>
          merge_average;
      merge_average(dev_ctx, inputs, out_slr);
    }

    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
// Base class of all parameter-server communicators. Owns the brpc worker
// client (_worker_ptr), the send/recv contexts, and the process-wide
// singleton plumbing (InitInstance / GetInstance). Concrete subclasses
// implement the actual send/recv strategies (async, half-async, sync, geo).
//
// Fix: trainers_ and recv_scope_ were previously left uninitialized; they
// now default to 0 / nullptr so reads before initialization are
// well-defined.
class Communicator {
 public:
  Communicator();

  // Build from a string->string env map; when non-empty it must contain
  // "barrier_table_id", "trainer_id" and "trainers".
  explicit Communicator(const std::map<std::string, std::string> &envs_) {
    VLOG(3) << "Communicator Init Envs";
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
      VLOG(3) << iter.first << ": " << iter.second;
    }
    if (!envs.empty()) {
      barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
      trainer_id_ = std::stoi(envs.at("trainer_id"));
      trainers_ = std::stoi(envs.at("trainers"));
    }
  }

  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);

  virtual std::vector<uint64_t> GetClientInfo();
  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                            int table_id,
                            Scope *scope);
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
                                 int table_id,
                                 const Scope &scope);
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
  virtual void RpcSendSparse(const std::string &var_name,
                             int table_id,
                             const Scope &scope);
  // 5. send sparse param
  virtual void RpcSendSparseParam(const std::string &varname,
                                  int table_id,
                                  const Scope &scope);
  // 6. recv sparse param
  virtual void RpcRecvSparse(const std::string &varname,
                             int table_id,
                             Scope *scope);
  // 7. send global step
  virtual void SendGlobalStep(const CommContext &ctx,
                              int batches,
                              Scope *send_scope);

  // FL-coordinator hooks: no-ops here, overridden by FLCommunicator.
  virtual std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return {};
  }
  virtual void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string> &fl_strategy) {}
  virtual void StartCoordinator(
      const std::string &self_endpoint,
      const std::vector<std::string> &trainer_endpoints) {}

  virtual ~Communicator() {}
  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

  // note: only for pull dense param first before training
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

  virtual void Start() = 0;
  virtual void Stop() = 0;
  virtual bool IsRunning() { return running_; }
  virtual void Clean() {}
  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;

  virtual void RecvNoBarrier() {}
  virtual void Barrier() {}

  // Block until every trainer has entered `barrier_type` on the barrier
  // table; enforces a zero status from the pserver.
  virtual void BarrierWithTable(uint32_t barrier_type) {
    auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
    rets.wait();
    int status = rets.get();
    PADDLE_ENFORCE_EQ(status,
                      0,
                      platform::errors::InvalidArgument(
                          "The ret status must be 0 when barrier with table"));
  }

  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
    _worker_ptr->CreateClient2ClientConnection(
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

  virtual void BarrierTriggerDecrement() {}
  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;
  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  static Communicator *GetInstance() { return communicator_.get(); }
  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  // Thread-safe one-shot construction of the process-wide communicator of
  // concrete type T.
  template <typename T>
  static Communicator *InitInstance(
      const RpcCtxMap &send_ctx,
      const RecvCtxMap &recv_ctx,
      const std::string &dist_desc,
      const std::vector<std::string> &host_sign_list,
      Scope *recv_scope,
      const std::map<std::string, std::string> &envs) {
    std::call_once(init_flag_,
                   &Communicator::InitWithRpcCtx<T>,
                   send_ctx,
                   recv_ctx,
                   dist_desc,
                   host_sign_list,
                   recv_scope,
                   std::ref(envs));
    return communicator_.get();
  }

  // called by InitInstance.
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
    VLOG(0) << "Communicator type is: " << typeid(T).name();
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  PSClient *GetPsClient() { return _worker_ptr.get(); }

  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;
  std::unordered_map<std::string, std::string> envs;

  // Dense dims assigned to each shard (ceil-style split).
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  void InitGFlag(const std::string &gflags);
  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_ = 0;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;
  Scope *recv_scope_ = nullptr;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
: Communicator(envs) {}
~AsyncCommunicator();
void InitEnvs() {
independent_recv_ = static_cast<bool>(
std::stoi(envs.at("communicator_independent_recv_thread")));
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
}
void Start() override;
void Stop() override;
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
virtual void MainThread();
virtual void RecvThread();
virtual bool Check(const int table_id);
virtual bool Check(const std::vector<std::string> &var_tables);
void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) override;
virtual void SendByCommunicator();
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
virtual int BatchesCounter() { return 1; }
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
virtual void BarrierWeakUp() {}
void PushDensePostProcessing();
void PullSparseToTensorSync(
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const framework::LoDTensor *> *inputs, // NOLINT
std::vector<framework::LoDTensor *> *outputs); // NOLINT
void PushSparseFromTensorAsync(
const uint64_t table_id,
int fea_dim,
uint64_t padding_id,
platform::Place place,
std::vector<const framework::LoDTensor *> *inputs,
const framework::LoDTensor *shows,
const framework::LoDTensor *clicks,
std::vector<framework::LoDTensor *> *outputs);
protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
int min_send_grad_num_before_recv_;
int thread_pool_size_;
int max_merge_var_num_;
int send_wait_times_;
int send_queue_size_;
bool need_global_step_ = false;
bool independent_recv_ = true;
int parallel_task_nums_ = 0;
int32_t sleep_seconds_before_fail_exit_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::unique_ptr<std::thread> recv_thread_{nullptr};
std::unique_ptr<Scope> send_scope_; // an independent scope
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
};
class HalfAsyncCommunicator : public AsyncCommunicator {
public:
HalfAsyncCommunicator() {}
explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitEnvs() {
// enfore to recv after send
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(1) << "HalfAsyncCommunicator Initialized";
}
void MainThread() override;
void SendByCommunicator() override;
void Clean() override;
void Barrier() override;
void BarrierTriggerDecrement() override;
void BarrierTriggerReset(int initial_val) override;
int BatchesCounter();
void BarrierWeakUp();
protected:
// mutex for Wait for barrier
std::mutex barrier_mutex_;
std::condition_variable barrier_cond_;
std::atomic<int64_t> barrier_trigger_{0};
std::atomic<int64_t> barrier_counter_{0};
};
// Fully synchronous mode: same env setup as half-async, plus explicit
// send/recv barriers (BarrierSend/BarrierRecv) against the pservers.
class SyncCommunicator : public HalfAsyncCommunicator {
 public:
  SyncCommunicator() : HalfAsyncCommunicator() {}
  explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
      : HalfAsyncCommunicator(envs) {}
  // Parse communicator knobs from `envs`; every key below must be present
  // (`envs.at` throws otherwise).
  void InitEnvs() {
    // enforce recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
    VLOG(1) << "SyncCommunicator Initialized";
  }
  void BarrierSend();
  void BarrierRecv();
 private:
  std::vector<std::string> pserver_endpoints_{};
};
// Geo mode: keeps snapshot scopes (delta/old/pserver) and exchanges
// parameter deltas; all communication is driven from MainThread, so the
// per-variable SendByCommunicator/RecvByCommunicator hooks are no-ops.
class GeoCommunicator : public AsyncCommunicator {
 public:
  GeoCommunicator() : AsyncCommunicator() {}
  explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}
  void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
  void InitSparse(const std::string &var_name, int table_id);
  void SendDense(const CommContext &send_ctx);
  void RecvDense(const CommContext &send_ctx);
  // Implemented in the .cc; drains queued sparse ids for `varname` from
  // sparse_id_queues_ -- confirm dedup semantics there.
  std::vector<int64_t> MergeSparseIds(const std::string &varname);
  void SendSparse(const std::string &varname,
                  std::vector<int64_t> &sparse_ids,  // NOLINT
                  int table_id,
                  int ep_idx);
  void RecvSparse(const std::string &varname, int table_id, int ep_idx);
  void MainThread() override;
  // Parse communicator knobs from `envs`; every key below must be present
  // (`envs.at` throws otherwise).
  virtual void InitEnvs() {
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    // id_queue's size
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_queue_size_ = max_merge_var_num_;
    VLOG(1) << "GeoCommunicator Initialized";
  }
  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;
  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;
  // no-ops: Geo mode drives communication from MainThread instead
  void SendByCommunicator() { return; }
  void RecvByCommunicator() override { return; }
  // Strips the trailing 5 characters (presumably the "@GRAD" suffix) to
  // recover the parameter name -- TODO confirm against callers.
  inline std::string GradToParam(const std::string var_name) {
    std::string param_name = var_name.substr(0, var_name.size() - 5);
    return param_name;
  }
  // Returns the text before the first ".block" in `delta_name`.
  inline std::string SplitedGradToParam(const std::string delta_name) {
    // delta_name: emb.delta0
    auto pos = delta_name.find(".block");
    std::string param_name = delta_name.substr(0, pos);
    return param_name;
  }
 public:
  // parameter for delta calc and send
  std::shared_ptr<Scope> delta_scope_;
  // parameter for storage the pserver param after last recv
  std::shared_ptr<Scope> old_scope_;
  // parameter on pserver
  std::shared_ptr<Scope> pserver_scope_;
  // per-variable queues of sparse-id batches, filled by Send() and drained
  // by MergeSparseIds()
  std::unordered_map<
      std::string,
      paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
      sparse_id_queues_;
};
// Federated-learning communicator: extends GeoCommunicator with a
// coordinator client/server pair used to collect FL client info and push
// per-client strategies over brpc.
class FLCommunicator : public GeoCommunicator {
 public:
  FLCommunicator() : GeoCommunicator() {}
  ~FLCommunicator() {
    is_running_ = false;
    // The async send thread is presumably started later (SendThreadAsync);
    // the previous unconditional join() dereferenced a null unique_ptr when
    // the communicator was destroyed before the thread was ever created.
    if (async_send_thread_ && async_send_thread_->joinable()) {
      async_send_thread_->join();
    }
  }
  explicit FLCommunicator(const std::map<std::string, std::string> &envs)
      : GeoCommunicator(envs) {}
  // FL mode reads no env knobs.
  void InitEnvs() override {}
  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);
  // Intentionally empty: FL mode does not use the rpc-context maps.
  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) {}
  void StartCoordinatorClient(
      const std::vector<std::string> &trainer_endpoints);
  void StartCoordinatorServer();
  void StartCoordinator(
      const std::string &self_endpoint,
      const std::vector<std::string> &trainer_endpoints) override;
  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo();
  void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string> &fl_strategy);
  void SendThreadAsync();
  void RpcSendFLStrategy();

 private:
  int thread_pool_size_ = 1;
  // Cleared by the destructor so the async send loop can exit.
  bool is_running_ = true;
  PaddlePSEnvironment ps_env_;
  std::shared_ptr<CoordinatorClient> coordinator_client_ptr_{nullptr};
  std::unique_ptr<std::thread> async_send_thread_{nullptr};
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace paddle {
namespace distributed {
// Metadata describing how one variable is communicated between trainer and
// parameter server: how it is sliced across endpoints, whether it is
// sparse/distributed, which table/program it belongs to, etc.
struct CommContext {
  CommContext() = default;

  // `name`: variable name; `names`/`emap`/`sections`: per-slice names,
  // endpoints, and row counts; `origin_names`: pre-split variable names.
  // Remaining flags mirror the member documentation below.
  CommContext(const std::string &name,
              const std::vector<std::string> &names,
              const std::vector<std::string> &emap,
              const std::vector<int64_t> &sections,
              const std::vector<std::string> &origin_names,
              int trainer_id,
              bool merge_add = true,
              bool is_sparse = true,
              bool is_distributed = false,
              int table_id = -1,
              bool is_tensor_table = false,
              bool is_datanorm_table = false,
              int64_t program_id = -1,
              const std::vector<int32_t> &remote_sparse_ids = {})
      : var_name(name),
        splited_varnames(names),
        epmap(emap),
        height_sections(sections),
        origin_varnames(origin_names),
        trainer_id(trainer_id),
        merge_add(merge_add),
        is_sparse(is_sparse),
        is_distributed(is_distributed),
        table_id(table_id),
        program_id(program_id),
        is_tensor_table(is_tensor_table),
        is_datanorm_table(is_datanorm_table),
        remote_sparse_ids(remote_sparse_ids) {}

  // The previous hand-written copy constructor copied every member exactly
  // as the compiler-generated one does; default it (Rule of Zero).
  CommContext(const CommContext &ctx) = default;

  // Human-readable dump of every field, for logging/debugging.
  std::string print() const {
    std::stringstream ss;
    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
    ss << " table_id: " << table_id;
    std::for_each(
        remote_sparse_ids.begin(), remote_sparse_ids.end(), [&](const int &i) {
          ss << "remote_sparse_id: " << i << " ";
        });
    for (size_t i = 0; i < splited_varnames.size(); i++) {
      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
         << " section: " << height_sections[i] << " ";
    }
    ss << "origin varnames: ";
    for (size_t i = 0; i < origin_varnames.size(); i++) {
      ss << origin_varnames[i] << " ";
    }
    ss << " aggregation->add: " << merge_add;
    ss << " is_sparse: " << is_sparse;
    ss << " is_distributed: " << is_distributed << "\n";
    ss << " table_id: " << table_id << "\n";
    ss << " program_id: " << program_id << "\n";
    ss << " is_tensor_table: " << is_tensor_table << "\n";
    ss << " is_datanorm_table: " << is_datanorm_table << "\n";
    return ss.str();
  }

  std::string var_name;
  std::vector<std::string> splited_varnames;  // per-slice variable names
  std::vector<std::string> epmap;             // endpoint serving each slice
  std::vector<int64_t> height_sections;       // row count of each slice
  std::vector<std::string> origin_varnames;
  // Scalar members now carry in-class initializers (matching the main
  // constructor's defaults) so a default-constructed CommContext no longer
  // holds indeterminate values.
  int trainer_id = 0;
  bool merge_add = true;  // aggregate merged gradients by summation
  bool is_sparse = true;
  bool is_distributed = false;
  int table_id = -1;
  int64_t program_id = -1;
  bool is_tensor_table = false;
  bool is_datanorm_table = false;
  std::vector<int32_t> remote_sparse_ids;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
// Port range bounds for coordinator services. NOTE(review): neither
// constant is referenced in the visible code -- confirm before removal.
static const int MIN_PORT = 8500;
static const int MAX_PORT = 65535;
namespace paddle {
namespace distributed {
// Maximum number of federated-learning clients the coordinator supports.
DEFINE_uint64(total_fl_client_size, 100, "supported total fl client size");
// Seconds the coordinator waits for all clients' info before giving up.
DEFINE_uint32(coordinator_wait_all_clients_max_time, 60, "uint32: s");
// brpc entry point for flClient -> coordinator messages: dispatches the
// request to the handler registered for its cmd_id and reports failures in
// the response's error fields.
void CoordinatorService::FLService(
    ::google::protobuf::RpcController* controller,
    const CoordinatorReqMessage* request,
    CoordinatorResMessage* response,
    ::google::protobuf::Closure* done) {
  brpc::ClosureGuard done_guard(done);
  response->set_err_code(0);
  response->set_err_msg("");
  auto* cntl = static_cast<brpc::Controller*>(controller);
  const int32_t msg_type = request->cmd_id();
  const uint32_t from_client_id = request->client_id();
  VLOG(0) << "fl-ps > recv from client id: " << from_client_id
          << ", msg_type: " << msg_type;
  // TODO(ziyoujiyi): find is not thread safe, beacuse of RB_Tree traversal
  const auto handler_itr = _service_handle_map.find(msg_type);
  if (handler_itr == _service_handle_map.end()) {
    LOG(ERROR) << "fl-ps > unknown flClient2Coordinator msg type: " << msg_type;
    return;
  }
  // Invoke the registered handler (e.g. SaveFLClientInfo).
  if (handler_itr->second(*request, response, cntl) != 0) {
    response->set_err_code(-1);
    response->set_err_msg("fl-ps > handle flClient2Coordinator msg failed");
  }
}
// Creates brpc channels from the coordinator to every pserver and every fl
// client, then records the client count and installs default (empty) FL
// strategies. Returns 0 on success, -1 on any connection/setup failure.
int32_t CoordinatorClient::Initialize(
    const std::vector<std::string>& trainer_endpoints) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = paddle::distributed::FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms =
      paddle::distributed::FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  std::string server_ip_port;
  // Fetch the pserver list from the environment and connect to each one.
  if (_env == nullptr) {
    LOG(ERROR) << "_env is null in CoordinatorClient::Initialize()";
    return -1;
  }
  std::vector<PSHost> pserver_list = _env->GetPsServers();
  _pserver_channels.resize(pserver_list.size());
  for (size_t i = 0; i < pserver_list.size(); ++i) {
    server_ip_port.assign(pserver_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(pserver_list[i].port));
    for (size_t j = 0; j < _pserver_channels[i].size(); ++j) {
      _pserver_channels[i][j].reset(new brpc::Channel());
      if (_pserver_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "CoordinatorClient connect to PServer:" << server_ip_port
                   << " Failed! Try again.";
        // Retry once with a numeric endpoint in case name resolution failed.
        std::string int_ip_port =
            GetIntTypeEndpoint(pserver_list[i].ip, pserver_list[i].port);
        if (_pserver_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "CoordinatorClient connect to PServer:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
  }
  // Build the fl-client host list from "ip:port" strings and connect.
  std::vector<PSHost> fl_client_list;
  fl_client_list.resize(trainer_endpoints.size());
  if (fl_client_list.empty()) {
    LOG(ERROR) << ">>> fl clients addr info lost";
    return -1;
  }
  for (size_t i = 0; i < trainer_endpoints.size(); i++) {
    std::vector<std::string> addr =
        paddle::string::Split(trainer_endpoints[i], ':');
    fl_client_list[i].ip = addr[0];
    fl_client_list[i].port = std::stol(addr[1]);
    fl_client_list[i].rank = i;  // TO CHECK
  }
  std::string fl_client_ip_port;
  for (size_t i = 0; i < fl_client_list.size(); ++i) {
    fl_client_ip_port.assign(fl_client_list[i].ip);
    fl_client_ip_port.append(":");
    fl_client_ip_port.append(std::to_string(fl_client_list[i].port));
    uint32_t rank = fl_client_list[i].rank;
    VLOG(0) << "fl-ps > coordinator connect to fl_client: " << rank;
    _fl_client_channels[rank].reset(new brpc::Channel());
    if (_fl_client_channels[rank]->Init(
            fl_client_ip_port.c_str(), "", &options) != 0) {
      LOG(ERROR) << "CoordinatorClient connect to FLClient:"
                 << fl_client_ip_port << " Failed! Try again.";
      // Same numeric-endpoint fallback as for pservers above.
      std::string int_ip_port =
          GetIntTypeEndpoint(fl_client_list[i].ip, fl_client_list[i].port);
      if (_fl_client_channels[rank]->Init(int_ip_port.c_str(), "", &options) !=
          0) {
        LOG(ERROR) << "CoordinatorClient connect to PSClient:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
  }
  SetTotalFLClientsNum(fl_client_list.size());
  SetDefaultFLStrategy();
  return 0;
}
int32_t CoordinatorClient::StartClientService() {
_service.Initialize();
_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options;
options.num_threads = 1;
if (_endpoint.empty()) {
LOG(ERROR) << "fl-ps > coordinator server endpoint not set";
return -1;
}
auto addr = paddle::string::Split(_endpoint, ':');
std::string ip = addr[0];
std::string port = addr[1];
std::string rank = addr[2];
std::string ip_port = ip + ":" + port;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "fl-ps > StartClientService failed";
return -1;
}
uint32_t port_ = std::stol(port);
int32_t rank_ = std::stoi(rank);
_env->RegisteCoordinatorClient(ip, port_, rank_);
VLOG(0) << "fl-ps > coordinator service addr: " << ip << ", " << port << ", "
<< _coordinator_id;
return 0;
}
// Pushes the saved FL strategy for `client_id` to that client over brpc and
// blocks until the rpc completes.
void CoordinatorClient::SendFLStrategy(const uint32_t& client_id) {
  // Resolve the channel before allocating the closure: in the original code
  // an early return on a null channel leaked the heap-allocated closure
  // (it is only freed via the rpc completion path).
  brpc::Channel* rpc_channel = _fl_client_channels[client_id].get();
  if (rpc_channel == nullptr) {
    LOG(ERROR) << "fl-ps > _fl_client_channels is null";
    return;
  }
  size_t request_call_num = 1;
  FlClientBrpcClosure* closure =
      new FlClientBrpcClosure(request_call_num, [](void* done) {
        auto* closure = reinterpret_cast<FlClientBrpcClosure*>(done);
        int ret = 0;
        if (closure->check_response(0, PUSH_FL_STRATEGY) != 0) {
          LOG(ERROR) << "fl-ps > SendFLStrategy failed";
          ret = -1;
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(PUSH_FL_STRATEGY);
  closure->request(0)->set_client_id(client_id);
  std::string fl_strategy = _fl_strategy_mp[client_id];
  closure->request(0)->set_str_params(fl_strategy);
  PsService_Stub rpc_stub(rpc_channel);  // DownpourPsClientService
  rpc_stub.FLService(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  // Block until the closure's promise is fulfilled by the rpc callback.
  fut.wait();
  VLOG(0) << "fl-ps > SendFLStrategy to client: " << client_id << " finished";
  return;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
// Flags defined in brpc_ps_client.cc / coordinator_client.cc.
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(pserver_connect_timeout_ms);
DECLARE_uint64(total_fl_client_size);
DECLARE_uint32(coordinator_wait_all_clients_max_time);
// Handler signature for flClient -> coordinator messages; returns 0 on
// success (see CoordinatorService::FLService dispatch).
using CoordinatorServiceFunc =
    std::function<int32_t(const CoordinatorReqMessage& request,
                          CoordinatorResMessage* response,
                          brpc::Controller* cntl)>;
// Per-client training metrics reported to the coordinator.
class ClientReportedInfo {
 public:
  ClientReportedInfo() {}
  ~ClientReportedInfo() {}
  // All members now carry initializers; previously client_id and
  // iteration_idx were left indeterminate on construction.
  uint32_t client_id = 0;      // reporting FL client id
  uint32_t iteration_idx = 0;  // client-side training iteration index
  double auc = 0.0;            // reported evaluation metric
};
// Accumulates per-client info pushed by fl clients during one round and
// hands the collected map to the coordinator once all expected clients have
// reported (or a timeout elapses).
class CoordinatorServiceHandle {
 public:
  CoordinatorServiceHandle() {}
  virtual ~CoordinatorServiceHandle() {}
  // Stores one client's reported payload (empty payloads count as
  // heartbeats) and wakes QueryFLClientsInfo() when the number of reports
  // reaches last_round_total_fl_clients_num.
  void SaveFLClientInfo(const CoordinatorReqMessage& request) {
    auto client_id = request.client_id();
    const std::string& str_params = request.str_params();
    // each client is allowed to send empty message to maintain heartbeat(i.e.
    // use staleness msg)
    std::unique_lock<std::mutex> lck(_mtx);
    if (str_params.size() != 0) {
      _client_info_mp[client_id] = str_params;
    } else {
      LOG(INFO) << "fl-ps > content in request from " << client_id
                << " is null";
    }
    fl_client_ids.insert(client_id);
    _fl_clients_count++;
    // TODO(ziyoujiyi): how to process when a client loss connection?
    if (_fl_clients_count.load() == last_round_total_fl_clients_num) {
      _is_all_clients_info_collected = true;
      _cv.notify_one();
    }
    lck.unlock();
    VLOG(0) << "last_round_total_fl_clients_num: "
            << last_round_total_fl_clients_num
            << ", has recved fl client num: " << _fl_clients_count.load();
    return;
  }
  // Blocks until every expected client has reported or
  // FLAGS_coordinator_wait_all_clients_max_time seconds have passed, then
  // resets the round state and returns the collected info map.
  // NOTE(review): the wait predicate polls and sleeps while the mutex is
  // held, which stalls SaveFLClientInfo() for the duration -- confirm this
  // serialization is intended.
  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    platform::Timer timeline;
    double query_wait_time = 0.0;
    timeline.Start();
    auto f = [&]() -> bool {
      while (query_wait_time <
             paddle::distributed::
                 FLAGS_coordinator_wait_all_clients_max_time) {  // in case that
                                                                 // some
                                                                 // clients down
        if (_is_all_clients_info_collected == true) {
          // LOG(INFO) << "fl-ps > _is_all_clients_info_collected";
          return true;
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(1000));
        timeline.Pause();
        query_wait_time += timeline.ElapsedSec();
      }
      // LOG(WARNNING) << "fl-ps > query_wait_time exceed!";
      return true;
    };
    std::unique_lock<std::mutex> lck(_mtx);
    _cv.wait(lck, f);
    lck.unlock();
    _is_all_clients_info_collected = false;
    _fl_clients_count.store(0);
    return _client_info_mp;
  }

 public:
  std::unordered_map<uint32_t, std::string> _client_info_mp;
  std::set<uint32_t> fl_client_ids;
  // Number of reports expected this round (set via SetTotalFLClientsNum).
  uint32_t last_round_total_fl_clients_num = 0;
  bool _is_all_clients_info_collected = false;

 private:
  std::mutex _mtx;
  std::condition_variable _cv;
  std::atomic<uint32_t> _fl_clients_count{0};
};
// brpc service exposed by the coordinator; routes flClient2Coordinator
// messages to handlers registered in _service_handle_map and delegates the
// collected state to a CoordinatorServiceHandle.
class CoordinatorService : public PsService {
 public:
  CoordinatorService() {
    _coordinator_service_handle = std::make_shared<CoordinatorServiceHandle>();
  }
  virtual ~CoordinatorService() {}
  // Registers the message handlers (currently only PUSH_FL_CLIENT_INFO_SYNC).
  virtual void Initialize() {
    _service_handle_map[PUSH_FL_CLIENT_INFO_SYNC] =
        std::bind(&CoordinatorService::SaveFLClientInfo,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
  }
  // rpc entry point; dispatches by request->cmd_id() (defined in the .cc).
  virtual void FLService(::google::protobuf::RpcController* controller,
                         const CoordinatorReqMessage* request,
                         CoordinatorResMessage* response,
                         ::google::protobuf::Closure* done);
  // Handler for PUSH_FL_CLIENT_INFO_SYNC; always reports success.
  int32_t SaveFLClientInfo(const CoordinatorReqMessage& request,
                           CoordinatorResMessage* response,
                           brpc::Controller* cntl) {
    _coordinator_service_handle->SaveFLClientInfo(request);
    return 0;
  }
  void SetTotalFLClientsNum(uint32_t all_fl_clients_num) {
    if (_coordinator_service_handle.get() != nullptr) {
      _coordinator_service_handle->last_round_total_fl_clients_num =
          all_fl_clients_num;
    } else {
      LOG(ERROR) << "fl-ps > _coordinator_service_handle is null in "
                    "CoordinatorService";
    }
    return;
  }
  std::set<uint32_t> GetFLClientIds() {
    return _coordinator_service_handle->fl_client_ids;
  }
  // Blocks until the handle has collected this round's client info.
  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return _coordinator_service_handle->QueryFLClientsInfo();
  }

 private:
  std::shared_ptr<CoordinatorServiceHandle> _coordinator_service_handle;
  std::unordered_map<int32_t, CoordinatorServiceFunc> _service_handle_map;
  std::mutex _mtx;
};
// Coordinator-side client: maintains channels to pservers and fl clients,
// hosts the coordinator brpc service, and distributes FL strategies.
class CoordinatorClient : public BrpcPsClient {
 public:
  CoordinatorClient() : _coordinator_id(0) {}
  virtual ~CoordinatorClient() {}
  // Connects to all pservers and fl clients; 0 on success (see .cc).
  int32_t Initialize(const std::vector<std::string>& trainer_endpoints);
  void SetTotalFLClientsNum(uint32_t all_fl_clients_num) {
    _service.SetTotalFLClientsNum(all_fl_clients_num);
    this->_total_clients_num = all_fl_clients_num;
    return;
  }
  int32_t StartClientService();
  // Stores per-client strategies, then wakes any thread blocked in
  // WaitForFLStrategyReady().
  void SaveFLStrategy(
      const std::unordered_map<uint32_t, std::string>& fl_strategy) {
    for (auto it = fl_strategy.begin(); it != fl_strategy.end(); it++) {
      uint32_t client_id = it->first;
      _fl_strategy_mp[client_id] = it->second;
    }
    std::unique_lock<std::mutex> lck(_mtx);
    _is_fl_strategy_ready = true;
    _cv.notify_all();
    return;
  }
  void WaitForFLStrategyReady() {
    std::unique_lock<std::mutex> lck(_mtx);
    _cv.wait(lck, [=]() { return _is_fl_strategy_ready; });
  }
  void SendFLStrategy(const uint32_t& client_id);
  void ResetFLStrategyFlag() { _is_fl_strategy_ready = false; }
  // Installs an empty strategy for every known client id.
  void SetDefaultFLStrategy() {
    for (size_t i = 0; i < _total_clients_num; i++) {
      _fl_strategy_mp[i] = "";
    }
    return;
  }
  std::set<uint32_t> GetFLClientIds() { return _service.GetFLClientIds(); }
  std::unordered_map<uint32_t, std::string> QueryFLClientsInfo() {
    return _service.QueryFLClientsInfo();
  }
  void SetEndpoint(const std::string& endpoint) {
    // std::move on a const reference was a silent copy; plain assignment
    // says what actually happens.
    _endpoint = endpoint;
  }

 public:
  size_t _coordinator_id;
  // Initialized so SetDefaultFLStrategy() is safe even before
  // SetTotalFLClientsNum() (previously indeterminate).
  uint32_t _total_clients_num = 0;
  std::string _endpoint;
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
      _pserver_channels;  // coordinator2pserver
  std::unordered_map<uint32_t, std::shared_ptr<brpc::Channel>>
      _fl_client_channels;  // coordinator2psclient
  brpc::Server _server;
  CoordinatorService _service;
  std::unordered_map<uint32_t, std::string> _fl_strategy_mp;
  bool _is_fl_strategy_ready = false;
  std::mutex _mtx;
  std::condition_variable _cv;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/env.h"
namespace paddle {
namespace distributed {} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
namespace paddle {
namespace distributed {
// Identity of one parameter-server host (ip/port/rank), packable into a
// single uint64 or an "ip:port:rank" string.
struct PSHost {
  std::string ip;
  uint32_t port;
  uint32_t rank;
  PSHost() = default;
  PSHost(const std::string ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // |---ip---|---port---|--rank--|
  // |-32bit--|--20bit---|--12bit-|
  uint64_t SerializeToUint64() {
    uint64_t label = static_cast<uint64_t>(inet_addr(ip.c_str())) << 32;
    label += (port << 12);
    label += rank;
    return label;
  }

  // Inverse of SerializeToUint64 (same bit layout).
  void ParseFromUint64(uint64_t host_label) {
    static uint64_t rank_label_mask = (1L << 12) - 1;
    static uint64_t port_label_mask = (1L << 20) - 1;
    rank = host_label & rank_label_mask;
    port = (host_label >> 12) & port_label_mask;
    uint32_t ip_addr = static_cast<uint32_t>(host_label >> 32);
    ip = inet_ntoa(*(in_addr *)&ip_addr);  // NOLINT
  }

  // Human-readable dump including the packed uint64 form.
  std::string ToString() {
    std::stringstream s;
    s << "host: " << ip << " port: " << port << " rank: " << rank
      << " uint64: " << SerializeToUint64();
    return s.str();
  }

  // for open source parameter server
  std::string SerializeToString() {
    return ip + ":" + std::to_string(port) + ":" + std::to_string(rank);
  }

  // Parses an "ip:port:rank" endpoint produced by SerializeToString.
  void ParseFromString(std::string endpoint) {
    std::vector<std::string> fields;
    StringSplit(endpoint, ':', &fields);
    ip = fields[0];
    port = std::stoi(fields[1]);
    rank = std::stoi(fields[2]);
  }

  // Splits `str` on `sep`; a trailing empty piece is dropped, and an empty
  // input yields no pieces unless ignore_null is false.
  void StringSplit(const std::string &str,
                   char sep,
                   std::vector<std::string> *pieces,
                   bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t begin = 0;
    for (size_t end = str.find(sep); end != std::string::npos;
         end = str.find(sep, begin)) {
      pieces->push_back(str.substr(begin, end - begin));
      begin = end + 1;
    }
    if (!str.substr(begin).empty()) {
      pieces->push_back(str.substr(begin));
    }
  }
};
// Base class describing the parameter-server topology: lists of servers,
// clients and coordinators plus registration helpers. Most setters are
// no-op defaults overridden by PaddlePSEnvironment.
class PSEnvironment {
 public:
  explicit PSEnvironment() {}  // NOLINT
  virtual ~PSEnvironment() {}

  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
    return 0;
  }
  virtual int32_t SetPsServers(
      const std::vector<std::string> *host_endpoint_list, int node_num) {
    return 0;
  }

  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
    return 0;
  }
  virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
    return 0;
  }

  virtual uint64_t GetLocalHostSign() { return 0; }
  virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
  virtual int32_t RegistePsServer(const std::string &ip,
                                  uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
  }

  virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
  virtual int32_t RegistePsClient(const std::string &ip,
                                  uint32_t port,
                                  int32_t rank) {
    return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
  }

  virtual std::vector<PSHost> GetCoordinators() const {
    return _coordinator_list;
  }
  virtual int32_t RegisteCoordinatorClient(const std::string &ip,
                                           uint32_t port,
                                           int32_t rank) {
    return RegistePsHost(
        ip, port, rank, _coordinator_list, _coordinator_sign_set);
  }

  // Packed uint64 identities of all registered clients.
  virtual std::vector<uint64_t> GetClientInfo() {
    std::vector<uint64_t> client_info;
    for (auto &i : _ps_client_list) {
      client_info.push_back(i.SerializeToUint64());
    }
    return client_info;
  }

  // "ip:port:rank" strings of all registered clients (empty list when
  // use_string_endpoint is false).
  virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
    if (use_string_endpoint) {
      std::vector<std::string> client_info;
      for (auto &i : _ps_client_list) {
        client_info.push_back(i.SerializeToString());
      }
      return client_info;
    }
    return {};
  }

  virtual void SetTrainers(int trainers) { trainers_ = trainers; }
  virtual int GetTrainers() { return trainers_; }

 protected:
  // Register one host into `host_list`, deduplicated via `sign_set`.
  // NOTE(review): the dedup key inserted here is the rank, not the host
  // signature, despite the set's name -- confirm intended.  // NOLINT
  virtual int32_t RegistePsHost(
      const std::string &ip,
      uint32_t port,
      int32_t rank,
      std::vector<PSHost> &host_list,               // NOLINT
      std::unordered_set<uint64_t> &sign_set) {     // NOLINT
    PSHost host;
    host.ip = ip;
    host.port = port;
    host.rank = rank;
    if (sign_set.count(rank) == 0) {
      host_list.push_back(host);
      sign_set.insert(rank);
    }
    return 0;
  }

  int trainers_ = 0;
  std::vector<PSHost> _ps_client_list;
  std::unordered_set<uint64_t> _ps_client_sign_set;  // for unique filter
  std::vector<PSHost> _ps_server_list;
  std::unordered_set<uint64_t> _ps_server_sign_set;  // for unique filter
  std::vector<PSHost> _coordinator_list;
  std::unordered_set<uint64_t> _coordinator_sign_set;
};
// Concrete PS environment: populates the server/client/coordinator lists
// from packed-uint64 or "ip:port:rank" host descriptors, keeping each list
// sorted by rank.
class PaddlePSEnvironment : public PSEnvironment {
 public:
  explicit PaddlePSEnvironment() {}  // NOLINT
  virtual ~PaddlePSEnvironment() {}

  // Rebuilds the server list from packed uint64 signs (zeros are skipped).
  virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
    _ps_server_list.clear();
    _ps_server_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list[i] > 0) {
        PSHost host;
        host.ParseFromUint64(host_sign_list[i]);
        _ps_server_list.push_back(host);
        _ps_server_sign_set.insert(host.SerializeToUint64());
      }
    }
    std::sort(
        _ps_server_list.begin(),
        _ps_server_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

  // Rebuilds the server list from "ip:port:rank" strings (empties skipped).
  // NOTE(review): this overload inserts host.rank into the sign set while
  // the uint64 overload inserts the full serialized sign -- confirm the
  // inconsistency is intended.
  virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
                               int node_num) {
    _ps_server_list.clear();
    _ps_server_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list->at(i) != "") {
        PSHost host;
        host.ParseFromString(host_sign_list->at(i));
        _ps_server_list.push_back(host);
        _ps_server_sign_set.insert(host.rank);
      }
    }
    std::sort(
        _ps_server_list.begin(),
        _ps_server_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

  // Rebuilds the client list from packed uint64 signs (zeros are skipped).
  virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
    _ps_client_list.clear();
    _ps_client_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list[i] > 0) {
        PSHost host;
        host.ParseFromUint64(host_sign_list[i]);
        _ps_client_list.push_back(host);
        _ps_client_sign_set.insert(host.SerializeToUint64());
      }
    }
    std::sort(
        _ps_client_list.begin(),
        _ps_client_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    return 0;
  }

  // Rebuilds the client list from "ip:port:rank" strings (empties skipped).
  virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
                               int node_num) {
    _ps_client_list.clear();
    _ps_client_sign_set.clear();
    for (int i = 0; i < node_num; ++i) {
      if (host_sign_list->at(i) != "") {
        PSHost host;
        host.ParseFromString(host_sign_list->at(i));
        _ps_client_list.push_back(host);
        _ps_client_sign_set.insert(host.rank);
      }
    }
    std::sort(
        _ps_client_list.begin(),
        _ps_client_list.end(),
        [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
    VLOG(1) << "env.set_ps_clients done\n";
    return 0;
  }

  // Rebuilds the coordinator list from "ip:port:rank" strings.
  virtual void SetCoordinators(const std::vector<std::string> *host_sign_list,
                               size_t node_num) {
    _coordinator_list.clear();
    _coordinator_sign_set.clear();
    for (size_t i = 0; i < node_num; ++i) {
      if (host_sign_list->at(i) != "") {
        PSHost host;
        host.ParseFromString(host_sign_list->at(i));
        _coordinator_list.push_back(host);
        _coordinator_sign_set.insert(host.rank);
        VLOG(0) << "fl-ps > coordinator info in env: " << host.ToString();
      }
    }
    return;
  }

  // Packed identity of the first registered client, or 0 if none.
  virtual uint64_t GetLocalHostSign() {
    if (_ps_client_list.size() > 0) {
      return _ps_client_list[0].SerializeToUint64();
    } else {
      return 0;
    }
  }
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Routes a graph PS request either to the in-process graph service (when the
// target channel is our own local channel) or over brpc to a remote server.
void GraphPsService_Stub::service(
    ::google::protobuf::RpcController *controller,
    const ::paddle::distributed::PsRequestMessage *request,
    ::paddle::distributed::PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  const bool serve_locally =
      graph_service != NULL && local_channel == channel();
  if (!serve_locally) {
    // Remote path: delegate to the generated brpc stub.
    PsService_Stub::service(controller, request, response, done);
    return;
  }
  // Local path: run the handler on the task pool instead of going over rpc.
  task_pool->enqueue([this, controller, request, response, done]() -> int {
    this->graph_service->service(controller, request, response, done);
    return 0;
  });
}
// Maps a node id to the server owning its shard: shards are grouped in
// contiguous runs of ceil(shard_num / server_size) per server.
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
  const int shard_num = get_shard_num();
  const int shards_per_server = (shard_num + server_size - 1) / server_size;
  return id % shard_num / shards_per_server;
}
// Fetch, for every node in `node_ids`, the values of `feature_names` from
// graph `idx_` of `table_id`. Results are written into res[feat_idx][query_idx].
// Ids are bucketed by owning server so each involved server receives exactly
// one PS_GRAPH_GET_NODE_FEAT request. The future resolves to 0 unless every
// sub-request failed (-1).
// NOTE(review): `res` appears to need pre-sizing to
// [feature_names.size()][node_ids.size()] by the caller -- this function only
// assigns into it; confirm against callers. The completion lambda captures
// `res` and `feature_names` by reference, so they must outlive the future.
std::future<int32_t> GraphBrpcClient::get_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  // server2request: server index -> slot in request2server, or -1 when that
  // server owns none of the requested ids.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Per-request buckets: ids sent to each server, plus each id's original
  // position in `node_ids` so replies can be scattered back in caller order.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback, invoked once all sub-requests have answered.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
            ++fail_num;
          } else {
            // Copy the response attachment into a flat buffer, then decode it
            // feature-major: for each feature, one (size_t length, raw bytes)
            // record per requested node, in bucket order.
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
                size_t feat_len = *(size_t *)(buffer);
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          // Overall failure only when every sub-request failed.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one request per involved server. Wire params: graph idx (int),
  // packed int64 node ids, '\t'-joined feature names.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
int type_id,
int idx_) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
closure->request(server_index)->add_params((char *)&idx_, sizeof(int));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index),
closure);
}
return fut;
}
// Register `node_id_list` on the servers that shard them. When
// `is_weighted_list` is non-empty, a per-node "is weighted" flag is shipped
// alongside the ids (missing tail entries default to false). The future
// resolves to 0 if at least one server accepted the request, -1 if all failed.
std::future<int32_t> GraphBrpcClient::add_graph_node(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> &node_id_list,
    std::vector<bool> &is_weighted_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  // index_mapping: server index -> bucket index (-1 when unused);
  // server_index_arr: bucket index -> server index.
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
              0) {
            ++fail_num;
          }
        }
        // Fail only when every sub-request failed.
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    if (add_weight) {
      // std::vector<bool> is bit-packed and exposes no contiguous bool
      // storage, so unpack the flags into a plain bool array before shipping
      // raw bytes. FIX: the original used a variable-length array
      // (bool weighted[n + 1]), which is a compiler extension, not standard
      // C++; use a heap array instead. add_params copies the bytes, so the
      // local buffer's lifetime is sufficient.
      const size_t weight_num = is_weighted_bucket[request_idx].size();
      std::unique_ptr<bool[]> weighted(new bool[weight_num + 1]);
      for (size_t j = 0; j < weight_num; j++)
        weighted[j] = is_weighted_bucket[request_idx][j];
      closure->request(request_idx)
          ->add_params((char *)weighted.get(), sizeof(bool) * weight_num);
    }
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_id_list[query_idx]);
if (index_mapping[server_index] == -1) {
index_mapping[server_index] = request_bucket.size();
server_index_arr.push_back(server_index);
request_bucket.push_back(std::vector<int64_t>());
}
request_bucket[index_mapping[server_index]].push_back(
node_id_list[query_idx]);
}
size_t request_call_num = request_bucket.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [&, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
++fail_num;
}
}
ret = fail_num == request_call_num ? -1 : 0;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = server_index_arr[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx),
closure->request(request_idx),
closure->response(request_idx),
closure);
}
return fut;
}
// Sample up to `sample_size` neighbors for each id in `node_ids` from graph
// `idx_` of `table_id`. Neighbor ids go to res[i]; when `need_weight` is true
// the matching edge weights go to res_weight[i].
// server_index == -1 (default): ids are bucketed by their owning servers and
// one request is sent per server; otherwise a single
// PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER request is sent to that server.
// Response layout (both paths): size_t node_num | int actual_sizes[node_num] |
// per-node packed records of (int64 id [, float weight]).
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> node_ids,
    int sample_size,
    std::vector<std::vector<int64_t>> &res,
    std::vector<std::vector<float>> &res_weight,
    bool need_weight,
    int server_index) {
  if (server_index != -1) {
    // Single-server path: every id is answered by `server_index`.
    res.resize(node_ids.size());
    if (need_weight) {
      res_weight.resize(node_ids.size());
    }
    DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
      int ret = 0;
      auto *closure = (DownpourBrpcClosure *)done;
      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
          0) {
        ret = -1;
      } else {
        // Copy the attachment out, then decode node by node; actual_sizes
        // gives each node's record length in bytes.
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        size_t bytes_size = io_buffer_itr.bytes_left();
        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
        char *buffer = buffer_wrapper.get();
        io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
        size_t node_num = *(size_t *)buffer;
        int *actual_sizes = (int *)(buffer + sizeof(size_t));
        char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
        int offset = 0;
        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
          int actual_size = actual_sizes[node_idx];
          int start = 0;
          while (start < actual_size) {
            res[node_idx].emplace_back(
                *(int64_t *)(node_buffer + offset + start));
            start += GraphNode::id_size;
            if (need_weight) {
              res_weight[node_idx].emplace_back(
                  *(float *)(node_buffer + offset + start));
              start += GraphNode::weight_size;
            }
          }
          offset += actual_size;
        }
      }
      closure->set_promise_value(ret);
    });
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure->add_promise(promise);
    std::future<int> fut = promise->get_future();
    ;
    // Wire params: idx, packed ids, sample_size, need_weight.
    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
    closure->request(0)->set_table_id(table_id);
    closure->request(0)->set_client_id(_client_id);
    closure->request(0)->add_params((char *)&idx_, sizeof(int));
    closure->request(0)->add_params((char *)node_ids.data(),
                                    sizeof(int64_t) * node_ids.size());
    closure->request(0)->add_params((char *)&sample_size, sizeof(int));
    closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
    ;
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(
        closure->cntl(0), closure->request(0), closure->response(0), closure);
    return fut;
  }
  // Multi-server path: bucket ids by owning server (same scheme as
  // get_node_feat) and scatter replies back via query_idx_buckets.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  res.clear();
  res_weight.clear();
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
    // One (initially empty) result slot per queried node.
    res.push_back({});
    if (need_weight) {
      res_weight.push_back({});
    }
  }
  size_t request_call_num = request2server.size();
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // NOTE(review): the lambda captures `res`/`res_weight` by reference, so the
  // caller must keep them alive until the returned future resolves.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
              0) {
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
            size_t node_num = *(size_t *)buffer;
            int *actual_sizes = (int *)(buffer + sizeof(size_t));
            char *node_buffer =
                buffer + sizeof(size_t) + sizeof(int) * node_num;
            int offset = 0;
            for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
              // Replies arrive in bucket order; map back to caller order.
              int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
              int actual_size = actual_sizes[node_idx];
              int start = 0;
              while (start < actual_size) {
                res[query_idx].emplace_back(
                    *(int64_t *)(node_buffer + offset + start));
                start += GraphNode::id_size;
                if (need_weight) {
                  res_weight[query_idx].emplace_back(
                      *(float *)(node_buffer + offset + start));
                  start += GraphNode::weight_size;
                }
              }
              offset += actual_size;
            }
          }
          // Overall failure only when every sub-request failed.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Ask one server for `sample_size` random node ids of (type_id, idx_) from
// `table_id`; the ids are appended to `ids`. The future resolves to 0 on
// success, -1 on RPC failure.
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int sample_size,
    std::vector<int64_t> &ids) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // RAII buffer; also matches the other response decoders in this file.
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      // FIX: the response bytes were never copied out of the IOBuf before
      // decoding, so the ids below were read from uninitialized memory.
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      // The attachment is a flat array of int64 node ids.
      size_t index = 0;
      while (index < bytes_size) {
        ids.push_back(*(int64_t *)(buffer + index));
        index += GraphNode::id_size;
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Wire params: type_id, idx, sample_size -- all raw ints.
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&sample_size, sizeof(int));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pull `size` FeatureNodes (every `step`-th entry starting at `start`) of
// (type_id, idx_) from one server and decode them into `res`. The future
// resolves to 0 on success, -1 on RPC failure.
std::future<int32_t> GraphBrpcClient::pull_graph_list(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int start,
    int size,
    int step,
    std::vector<FeatureNode> &res) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // FIX: use an RAII buffer instead of raw new[]/delete[] -- the original
      // leaked if decoding threw, and every other decoder in this file already
      // uses a unique_ptr-wrapped buffer.
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      // The attachment is a sequence of serialized FeatureNodes; each node
      // reports its own encoded length via get_size(false).
      size_t index = 0;
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        index += node.get_size(false);
        res.push_back(node);
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Wire params: type_id, idx, start, size, step -- all raw ints.
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&start, sizeof(int));
  closure->request(0)->add_params((char *)&size, sizeof(int));
  closure->request(0)->add_params((char *)&step, sizeof(int));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Write feature values for the given nodes: features[feat_idx][query_idx] is
// stored under feature_names[feat_idx] for node_ids[query_idx]. Ids are
// bucketed by owning server; one PS_GRAPH_SET_NODE_FEAT request per server.
// The future resolves to 0 unless every sub-request failed.
std::future<int32_t> GraphBrpcClient::set_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  // server2request: server index -> slot in request2server (-1 if unused).
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Per-request buckets of ids, their original query positions, and the
  // feature values re-grouped as [request][feature][node-in-bucket].
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    // Lazily size the per-feature vectors the first time this bucket is hit.
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          // NOTE(review): this all-failed check sits inside the loop (the
          // sibling RPCs compute it after the loop); the net effect is the
          // same -- ret becomes -1 only when every sub-request failed.
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    // Serialize the values feature-major as (size_t length, raw bytes)
    // records -- the mirror of the layout get_node_feat decodes.
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
        set_feature.append((char *)&feat_len, sizeof(size_t));
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Initialize the underlying brpc PS client, cache the pserver count, and
// disable the in-process fast path until set_local_channel /
// set_local_graph_service are called explicitly. Always returns 0.
int32_t GraphBrpcClient::Initialize() {
  // set_shard_num(_config.shard_num());
  BrpcPsClient::Initialize();
  server_size = GetServerNums();
  // No local shortcut until the caller wires one up.
  local_channel = NULL;
  graph_service = NULL;
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
// A PsService_Stub that can short-circuit RPCs: when a GraphBrpcService is
// supplied and the stub's target channel equals `local_channel`, requests are
// executed in-process on `task_pool` instead of going over the network
// (see GraphPsService_Stub::service in the .cc).
class GraphPsService_Stub : public PsService_Stub {
 public:
  // `channel` is the RPC target; `local_channel`/`service` enable the local
  // fast path; `thread_num` sizes the pool that runs local calls.
  GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
                      ::google::protobuf::RpcChannel* local_channel = NULL,
                      GraphBrpcService* service = NULL,
                      int thread_num = 1)
      : PsService_Stub(channel) {
    this->local_channel = local_channel;
    this->graph_service = service;
    task_pool.reset(new ::ThreadPool(thread_num));
  }
  virtual ~GraphPsService_Stub() {}

  // implements PsService ------------------------------------------
  GraphBrpcService* graph_service;  // non-owning; NULL disables the local path
  std::shared_ptr<::ThreadPool> task_pool;  // runs in-process calls
  ::google::protobuf::RpcChannel* local_channel;  // co-located server channel
  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
  // Dispatches locally when graph_service != NULL and channel() matches
  // local_channel; otherwise falls back to PsService_Stub::service.
  void service(::google::protobuf::RpcController* controller,
               const ::paddle::distributed::PsRequestMessage* request,
               ::paddle::distributed::PsResponseMessage* response,
               ::google::protobuf::Closure* done);
};
// brpc-based PS client specialized for graph tables. Node ids are sharded
// across servers (see get_server_index_by_id); all graph operations are
// asynchronous and return a future resolving to 0 on success, -1 on failure.
class GraphBrpcClient : public BrpcPsClient {
 public:
  GraphBrpcClient() {}
  virtual ~GraphBrpcClient() {}
  // given a batch of nodes, sample graph_neighbors for each of them;
  // server_index == -1 scatters the ids to their owning servers.
  virtual std::future<int32_t> batch_sample_neighbors(
      uint32_t table_id,
      int idx,
      std::vector<int64_t> node_ids,
      int sample_size,
      std::vector<std::vector<int64_t>>& res,
      std::vector<std::vector<float>>& res_weight,
      bool need_weight,
      int server_index = -1);
  // Pull `size` FeatureNodes (every `step`-th, starting at `start`) from one
  // server.
  virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                               int type_id,
                                               int idx,
                                               int server_index,
                                               int start,
                                               int size,
                                               int step,
                                               std::vector<FeatureNode>& res);
  // Ask one server for `sample_size` random node ids.
  virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                   int type_id,
                                                   int idx,
                                                   int server_index,
                                                   int sample_size,
                                                   std::vector<int64_t>& ids);
  // Fetch feature values; res is indexed [feature][node].
  virtual std::future<int32_t> get_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      std::vector<std::vector<std::string>>& res);
  // Write feature values; features is indexed [feature][node].
  virtual std::future<int32_t> set_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      const std::vector<std::vector<std::string>>& features);
  // Drop all nodes of (type_id, idx) on every server.
  virtual std::future<int32_t> clear_nodes(uint32_t table_id,
                                           int type_id,
                                           int idx);
  // Add nodes (optionally flagged as weighted) to their owning servers.
  virtual std::future<int32_t> add_graph_node(
      uint32_t table_id,
      int idx,
      std::vector<int64_t>& node_id_list,
      std::vector<bool>& is_weighted_list);
  virtual std::future<int32_t> remove_graph_node(
      uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
  virtual int32_t Initialize();
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Map a node id to the index of the server whose shard range contains it.
  int get_server_index_by_id(int64_t id);
  // Enable the in-process fast path used by GraphPsService_Stub.
  void set_local_channel(int index) {
    this->local_channel = GetCmdChannel(index);
  }
  void set_local_graph_service(GraphBrpcService* graph_service) {
    this->graph_service = graph_service;
  }
  // Build a stub that prefers the local service when `channel` is local.
  GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
                                     int thread_num = 1) {
    return GraphPsService_Stub(
        channel, local_channel, graph_service, thread_num);
  }

 private:
  int shard_num;
  size_t server_size;  // number of pservers, cached in Initialize()
  ::google::protobuf::RpcChannel* local_channel;  // NULL until set_local_channel
  GraphBrpcService* graph_service;  // NULL until set_local_graph_service
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
// Guard used at the top of service handlers: when the requested table id is
// not registered on this server, reply with code -1 and an explanatory
// message, and abort the handler with -1.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Create and wire up the PS service named in the server config: instantiate
// it from the class registry, configure and initialize it, then attach it to
// the brpc server. Returns 0 on success, -1 on any failure (logged).
int32_t GraphBrpcServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service owns the instance; brpc is explicitly told not to delete it.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Return the (non-owning) peer-to-peer channel to another pserver, as set up
// by build_peer2peer_connection(). No bounds check is performed.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  return _pserver_channels[server_index].get();
}
// Start the brpc server listening on ip:port and register the endpoint with
// the environment. NOTE(review): returns 0 on both success and failure -- the
// failure path only logs, so callers cannot distinguish by return value;
// confirm whether that is intended.
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;
  // Worker count: max(number of trainers, hardware threads), so every trainer
  // can be served concurrently.
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
  _environment->RegistePsServer(ip, port, _rank);
  return 0;
}
// Open direct brpc channels from this server to every pserver (including
// itself) so graph requests can be forwarded peer-to-peer. Returns 0 on
// success, -1 if any channel cannot be established.
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
  auto _env = Environment();
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;
  std::vector<PSHost> server_list = _env->GetPsServers();
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      // Retry once using the numeric (int-typed) endpoint form before failing.
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}
// Handler for PS_GRAPH_CLEAR: drop every node of (type_id, idx) from the
// graph table. params(0) = type_id (int), params(1) = idx (int).
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // FIX: guard against a missing table and a short request before the raw
  // params() reads, matching the sibling handlers (add_graph_node,
  // remove_graph_node); previously a NULL table or missing params crashed.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}
// Handler for PS_GRAPH_ADD_GRAPH_NODE.
// params(0) = graph idx (int), params(1) = packed 64-bit node ids,
// params(2) (optional) = packed per-node "is weighted" bools.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "add_graph_node request requires at least 2 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  // NOTE(review): the client packs ids as int64_t; they are reinterpreted
  // here as uint64_t (same width, bit pattern preserved).
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    size_t weight_list_size = request.params(2).size() / sizeof(bool);
    bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
    is_weighted_list = std::vector<bool>(is_weighted_buffer,
                                         is_weighted_buffer + weight_list_size);
  }
  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  // Deletes the batch of node ids packed in params(1) from slot params(0).
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  const auto &id_blob = request.params(1);
  size_t id_count = id_blob.size() / sizeof(uint64_t);
  uint64_t *id_begin = (uint64_t *)(id_blob.c_str());
  std::vector<uint64_t> node_ids(id_begin, id_begin + id_count);
  ((GraphTable *)table)->remove_graph_node(idx_, node_ids);
  return 0;
}
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
int32_t GraphBrpcService::Initialize() {
  // Registers one handler per PS command id; service() dispatches on cmd_id
  // through this table. Finishes by wiring shard info into every table.
  _is_initialize_shard_info = false;
  auto &handlers = _service_handler_map;
  // Generic server-management commands.
  handlers[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  handlers[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  handlers[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
  handlers[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  handlers[PS_BARRIER] = &GraphBrpcService::Barrier;
  handlers[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  handlers[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
  // Graph-specific commands.
  handlers[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  handlers[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  handlers[PS_GRAPH_SAMPLE_NODES] = &GraphBrpcService::graph_random_sample_nodes;
  handlers[PS_GRAPH_GET_NODE_FEAT] = &GraphBrpcService::graph_get_node_feat;
  handlers[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  handlers[PS_GRAPH_ADD_GRAPH_NODE] = &GraphBrpcService::add_graph_node;
  handlers[PS_GRAPH_REMOVE_GRAPH_NODE] = &GraphBrpcService::remove_graph_node;
  handlers[PS_GRAPH_SET_NODE_FEAT] = &GraphBrpcService::graph_set_node_feat;
  handlers[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;
  InitializeShardInfo();
  return 0;
}
int32_t GraphBrpcService::InitializeShardInfo() {
  // Lazily propagates (rank, server count) to every table, exactly once.
  // Uses double-checked locking: the unlocked fast path skips the mutex
  // after the first successful initialization.
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Fix: iterate by reference — the original `for (auto itr : table_map)`
    // copied each map entry (including its shared_ptr, a refcount bump)
    // on every iteration for no benefit.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, server_size);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  // Entry point of every rpc: validates the request, then dispatches to the
  // handler registered for request->cmd_id() in Initialize().
  brpc::ClosureGuard done_guard(done);  // ensures `done` runs on every path
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    // Fix: corrected typo in the client-visible error text
    // ("tabel_id" -> "table_id").
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }
  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    // Handlers may set a specific message; only fill in the generic one
    // when they did not.
    response->set_err_code(service_ret);
    if (!response->has_err_msg()) {
      response->set_err_msg("server internal error");
    }
  }
}
int32_t GraphBrpcService::Barrier(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Blocks the calling trainer in the table's barrier; params(0) selects
  // the barrier type.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    // Fix: corrected spelling in the error text ("requeired" -> "required").
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is required at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
  table->Barrier(trainer_id, barrier_type);
  return 0;
}
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  // Serializes the table's stat pair (two int64 counters) into
  // response.data via a BinaryArchive.
  CHECK_TABLE_EXIST(table, request, response)
  const std::pair<int64_t, int64_t> stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  // Loads one table from disk; params(0) is the path, params(1) the
  // load parameter string.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    // Fix: corrected the error text — the values live in `params`, not
    // `datas`, and "requeired" was a typo.
    set_response_code(
        response,
        -1,
        "PsRequestMessage.params is required at least 2 for path & load_param");
    return -1;
  }
  if (table->Load(request.params(0), request.params(1)) != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  // Applies LoadOneTable to every registered table, stopping at the first
  // failure.
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    int rc = LoadOneTable(entry.second.get(), request, response, cntl);
    if (rc != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Shuts the server down asynchronously: Stop() joins the brpc server, so
  // it must run on a detached thread — calling it inline would deadlock the
  // rpc worker that is handling this very request.
  GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    // Fix: corrected log spelling ("Stoped" -> "Stopped").
    LOG(INFO) << "Server Stopped";
  });
  // Wake anything waiting on the server's condition variable before exiting.
  p_server->export_cv()->notify_all();
  t_stop.detach();
  return 0;
}
// Stops the platform profiler and dumps results to a per-rank file.
// NOTE(review): the format string uses %s with `_rank`; paddle's Sprintf is
// tinyformat-based where %s formats any streamable type — confirm, otherwise
// this should be %d.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// Enables CPU profiling for this server process; paired with StopProfiler.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  // Streams a serialized slice of the node list back through the rpc
  // attachment. params: type_id, idx, start, size, step — all packed ints.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(
        response, -1, "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *reinterpret_cast<const int *>(request.params(0).c_str());
  int idx = *reinterpret_cast<const int *>(request.params(1).c_str());
  int start = *reinterpret_cast<const int *>(request.params(2).c_str());
  int size = *reinterpret_cast<const int *>(request.params(3).c_str());
  int step = *reinterpret_cast<const int *>(request.params(4).c_str());
  std::unique_ptr<char[]> buffer;
  // Fix: initialize so a pull that writes nothing cannot leave actual_size
  // indeterminate before the append below.
  int actual_size = 0;
  static_cast<GraphTable *>(table)->pull_graph_list(
      type_id, idx, start, size, buffer, actual_size, false, step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
int32_t GraphBrpcService::graph_random_sample_neighbors(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // Samples up to `sample_size` neighbors for every node id in params(1).
  // Response layout: node_num, then per-node actual sizes, then the
  // serialized neighbor buffers back-to-back.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fix: the code checks for 4 params but the message claimed 3.
    set_response_code(
        response,
        -1,
        "graph_random_sample_neighbors request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  std::vector<std::shared_ptr<char>> buffers(node_num);
  std::vector<int> actual_sizes(node_num, 0);
  ((GraphTable *)table)
      ->random_sample_neighbors(
          idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // Samples up to `size` random nodes from slot (type_id, idx) and returns
  // them serialized in the rpc attachment.
  // Fix: added the CHECK_TABLE_EXIST and params_size guard every sibling
  // handler has; previously a malformed request read past params.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  int type_id = *reinterpret_cast<const int *>(request.params(0).c_str());
  int idx_ = *reinterpret_cast<const int *>(request.params(1).c_str());
  size_t size = *reinterpret_cast<const uint64_t *>(request.params(2).c_str());
  std::unique_ptr<char[]> buffer;
  int actual_size = 0;
  if (static_cast<GraphTable *>(table)->random_sample_nodes(
          type_id, idx_, size, buffer, actual_size) == 0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  // On failure nothing is appended (the original's append(NULL, 0) was an
  // empty no-op as well).
  return 0;
}
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  // Fetches the requested features (params(2): tab-separated names) for the
  // node ids packed in params(1), and streams them back as
  // [len][bytes] pairs, feature-major then node-major.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  const auto &id_bytes = request.params(1);
  size_t node_num = id_bytes.size() / sizeof(uint64_t);
  uint64_t *id_begin = (uint64_t *)(id_bytes.c_str());
  std::vector<uint64_t> node_ids(id_begin, id_begin + node_num);
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);
  auto &out = cntl->response_attachment();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      std::string &value = feature[feat_idx][node_idx];
      size_t feat_len = value.size();
      out.append(&feat_len, sizeof(size_t));
      out.append(value.data(), feat_len);
    }
  }
  return 0;
}
// Scatter/gather neighbor sampling: node ids are bucketed by the server that
// owns them, remote buckets are fanned out as PS_GRAPH_SAMPLE_NEIGHBORS rpcs,
// the local bucket (if any) is sampled in-process, and the merged result is
// streamed back in the original query order.
// Response layout matches graph_random_sample_neighbors: node_num, then one
// int size per node, then the neighbor buffers.
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(response,
                      -1,
                      "sample_neighbors_across_multi_servers request requires "
                      "at least 4 arguments");
    return 0;
  }
  // params: idx, packed uint64 node ids, sample_size, need_weight.
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  // request2server[i] = server rank of the i-th outgoing request;
  // server2request[rank] = index of that server's request, or -1.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  std::vector<uint64_t> local_id;
  std::vector<int> local_query_idx;
  size_t rank = GetRank();
  // First pass: discover which servers own at least one queried id.
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  // Move this server's own bucket to the back so the remote requests occupy
  // a contiguous prefix [0, remote_call_num).
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
  std::vector<std::shared_ptr<char>> local_buffers;
  std::vector<int> local_actual_sizes;
  // seq[i] = bucket index of the i-th query; used to merge in query order.
  std::vector<size_t> seq;
  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  // Second pass: bucket every id by its owning server's request slot.
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  // If the last bucket is local, it is served in-process, not via rpc.
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  // local_fut gates the merge callback until the local sampling below has
  // finished filling local_buffers/local_actual_sizes.
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  // Merge callback: runs once all remote responses arrived (or immediately
  // when remote_call_num == 0). Captures the buckets by value since the
  // rpcs outlive this stack frame.
  std::function<void(void *)> func = [&,
                                      node_id_buckets,
                                      query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    // Validate each remote response and position an iterator past its
    // leading node count.
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    // Emit per-query sizes in original query order (failed servers -> 0).
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));
    local_index = 0;
    // Emit the neighbor payloads in the same order.
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fan out one PS_GRAPH_SAMPLE_NEIGHBORS rpc per remote bucket.
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(uint64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    PsService_Stub rpc_stub(
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
    // GraphPsService_Stub rpc_stub =
    //     getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  // Serve the local bucket in-process while the rpcs are in flight.
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
        ->random_sample_neighbors(idx_,
                                  node_id_buckets.back().data(),
                                  sample_size,
                                  local_buffers,
                                  local_actual_sizes,
                                  need_weight);
  }
  local_promise.get()->set_value(0);
  if (remote_call_num == 0) func(closure);
  fut.get();
  return 0;
}
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  // Writes features for a batch of nodes. params: idx, packed uint64 ids,
  // tab-separated feature names, then the values serialized as [len][bytes]
  // pairs in feature-major / node-major order (mirrors graph_get_node_feat).
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fix: the code checks for 4 params but the message claimed 3.
    set_response_code(
        response,
        -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(uint64_t);
  uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));
  const char *buffer = request.params(3).c_str();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      // Each value is length-prefixed with a size_t.
      size_t feat_len = *(size_t *)(buffer);
      buffer += sizeof(size_t);
      features[feat_idx][node_idx] = std::string(buffer, feat_len);
      buffer += feat_len;
    }
  }
  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// Graph parameter server built on brpc. Hosts a GraphBrpcService and keeps
// peer-to-peer channels to every other pserver for cross-server sampling.
class GraphBrpcServer : public PSServer {
 public:
  GraphBrpcServer() {}
  virtual ~GraphBrpcServer() {}
  // Raw access to the underlying rpc service object.
  PsBaseService *get_service() { return _service.get(); }
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Connects this server to every pserver so requests can be forwarded.
  virtual int32_t build_peer2peer_connection(int rank);
  // Channel to the pserver with the given index (built above).
  virtual brpc::Channel *GetCmdChannel(size_t server_index);
  // Idempotent shutdown: stops and joins the brpc server once.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return 0;
    stoped_ = true;
    // cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  int32_t Port();
  // Condition variable used by StopServer to wake waiters.
  std::condition_variable *export_cv() { return &cv_; }

 private:
  virtual int32_t Initialize();
  mutable std::mutex mutex_;          // guards stoped_
  std::condition_variable cv_;
  bool stoped_ = false;
  int rank;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer pserver, indexed by server rank.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class GraphBrpcService;
// Member-function-pointer type of one rpc command handler; Initialize()
// fills _service_handler_map with these, keyed by PsCmdID.
typedef int32_t (GraphBrpcService::*serviceFunc)(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl);
// brpc service of the graph parameter server. service() validates each
// request and dispatches on cmd_id to one of the handlers below.
class GraphBrpcService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 protected:
  // cmd_id -> handler table populated by Initialize().
  std::unordered_map<int32_t, serviceFunc> _service_handler_map;
  // Lazily propagates (rank, server count) to every table, exactly once.
  int32_t InitializeShardInfo();
  // --- graph-specific handlers ---
  int32_t pull_graph_list(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,
                          brpc::Controller *cntl);
  int32_t graph_random_sample_neighbors(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl);
  int32_t graph_random_sample_nodes(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl);
  int32_t graph_get_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t graph_set_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t clear_nodes(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,
                      brpc::Controller *cntl);
  int32_t add_graph_node(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t remove_graph_node(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
  // --- generic server-management handlers ---
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,
                  brpc::Controller *cntl);
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  // Scatter/gather sampling across all servers (see the .cc for layout).
  int32_t sample_neighbors_across_multi_servers(Table *table,
                                                const PsRequestMessage &request,
                                                PsResponseMessage &response,
                                                brpc::Controller *cntl);
  int32_t use_neighbors_sample_cache(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl);
  int32_t load_graph_split_config(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl);

 private:
  bool _is_initialize_shard_info;      // set once by InitializeShardInfo()
  std::mutex _initialize_shard_mutex;  // guards the flag above
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
  const int sample_nodes_ranges = 23;
  size_t server_size;                  // number of pservers in the cluster
  std::shared_ptr<::ThreadPool> task_pool;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
// Gflags controlling heter-ps group size and switch rpc timeout.
DEFINE_int32(heter_world_size, 100, "group size");  // group max size
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
// Process-wide singletons; creation is serialized by mtx_.
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_;
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
// Reads the scalar "microbatch_id" variable from `scope` and returns it as
// an int. Copies device memory to host first when the tensor lives on GPU.
// Returns -1 if the value could not be read (e.g. non-CPU place without
// CUDA support compiled in).
int GetMicroId(const platform::DeviceContext& ctx,
               const framework::Scope* scope) {
  framework::Variable* var = scope->FindVar("microbatch_id");
  // Fix: corrected spelling in the enforce message ("shoulde" -> "should").
  PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(),
                    true,
                    platform::errors::InvalidArgument(
                        "the type of micro id should be LoDTensor."));
  auto micro_id = -1;
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  if (platform::is_cpu_place(tensor->place())) {
    auto data = reinterpret_cast<const float*>(tensor->data());
    micro_id = static_cast<int>(data[0]);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Stage the tensor through a host buffer before reading element 0.
    std::vector<char> temp;
    temp.resize(tensor->numel() * framework::DataTypeSize(tensor->dtype()));
    char* temp_ptr = temp.data();
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    memory::Copy(platform::CPUPlace(),
                 temp_ptr,
                 tensor->place(),
                 tensor->data(),
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
    micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
  }
  return micro_id;
}
// Synchronously stops all heter workers (blocks until the rpc completes).
void HeterClient::Stop() {
  auto status = StopHeterWorker();
  status.wait();
}
// Broadcasts PS_STOP_SERVER to every worker; table id -1 means "no table".
std::future<int32_t> HeterClient::StopHeterWorker() {
  return SendCmd(-1, PS_STOP_SERVER, {});
}
// Broadcasts PS_START_PROFILER to every worker.
std::future<int32_t> HeterClient::StartProfiler() {
  return SendCmd(-1, PS_START_PROFILER, {});
}
// Broadcasts PS_STOP_PROFILER to every worker.
std::future<int32_t> HeterClient::StopProfiler() {
  return SendCmd(-1, PS_STOP_PROFILER, {});
}
void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = FLAGS_pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
previous_xpu_channels_.resize(previous_xpu_list_.size());
for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
previous_xpu_channels_[i].reset(new brpc::Channel());
if (previous_xpu_channels_[i]->Init(
previous_xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
}
void HeterClient::SendAndRecvAsync(
    const platform::DeviceContext& ctx,
    const framework::Scope& scope,
    const std::string& message_name,
    const std::vector<std::string>& send_var_name,
    const std::vector<std::string>& recv_var_name,
    const std::string& mode) {
  // Serializes the named variables and issues an async SendAndRecvVariable
  // rpc; the target channel is picked from the minibatch id and `mode`
  // ("forward" -> next stage, "backward" -> previous stage).
  platform::RecordEvent record_event("HeterClient->SendAndRecvAsync",
                                     platform::TracerEventType::Communication,
                                     1);
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::vector<std::string> send_var_name_val = send_var_name;
  const std::vector<std::string> recv_var_name_val = recv_var_name;
  VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
  brpc::Channel* channel = nullptr;
  distributed::MultiVarMsg request;
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PADDLE_ENFORCE_NE(
        closure->cntl.Failed(),
        true,
        platform::errors::Unimplemented(
            "HeterClient::SendAndRecv meets brpc error, error message is %s",
            closure->cntl.ErrorText()));
    VLOG(4) << "call heter_worker success";
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                              send_var_name_val,
                                              recv_var_name_val,
                                              *p_ctx,
                                              p_scope,
                                              &request,
                                              &request_io_buffer);
  int micro_id = GetMicroId(ctx, p_scope);  // global
  auto minibatch_id = micro_id / 10;
  VLOG(4) << "micro_id: " << micro_id;
  // select channel according to micro id
  if (mode == "forward") {
    int num = minibatch_id % xpu_channels_.size();
    channel = xpu_channels_[num].get();
  } else if (mode == "backward") {
    int num = minibatch_id % previous_xpu_channels_.size();
    channel = previous_xpu_channels_[num].get();
  } else if (mode == "send_to_switch") {
    VLOG(4) << "calling switch service";
    // Switch path is not implemented yet (see the commented-out prototype in
    // the repository history).
    VLOG(4) << "calling switch service done";
    // Fix: the closure was allocated above but never handed to brpc on this
    // path, so it leaked on every call — free it before returning.
    delete closure;
    return;
  }
  // Fix: an unknown mode previously fell through with channel == nullptr and
  // crashed inside the stub; fail loudly instead.
  if (channel == nullptr) {
    LOG(ERROR) << "HeterClient::SendAndRecvAsync unknown mode: " << mode;
    delete closure;
    return;
  }
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendAndRecvVariable(
      &closure->cntl, &request, &closure->response, closure);
}
std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size();
paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(trainer_id_);
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(
closure->cntl(i), closure->request(i), closure->response(i), closure);
}
return fut;
}
int HeterClient::Send(const platform::DeviceContext& ctx,
                      const framework::Scope& scope,
                      const std::string& message_name,
                      const std::vector<std::string>& send_var_names) {
  // Serializes the named variables from `scope` (note: scope is const) and
  // synchronously sends them to the switch. Returns 0 on success, -1 when no
  // channel is available.
  const framework::Scope* p_scope = &scope;
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      PADDLE_ENFORCE_NE(
          closure->cntl.Failed(),
          true,
          platform::errors::Unimplemented(
              "HeterClient::SendToSwitch meets brpc error, error message is %s",
              closure->cntl.ErrorText()));
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req send_var_names(<string>)
  for (auto& send_var_name : send_var_names) {
    request.add_send_var_names(send_var_name);
  }
  // 3. set req var_messages(<VarMessage>)
  for (auto& send_var_name : send_var_names) {
    auto* send_var_msg = request.add_var_messages();
    send_var_msg->set_varname(send_var_name);
    framework::Variable* var = p_scope->FindVar(send_var_name);
    butil::IOBuf temp_iobuf;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    request_io_buffer.append(temp_iobuf);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      // Fix: the original logged this and then dereferenced
      // xpu_channels_[0] anyway (undefined behavior). Bail out instead,
      // freeing the never-dispatched closure.
      LOG(ERROR) << "xpu_channels_ is null";
      delete closure;
      return -1;
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  VLOG(4) << "waiting SendToSwitch response result......";
  fut.wait();
  VLOG(4) << "Send done";
  return 0;
}
int HeterClient::Send(int group_id,
                      const std::vector<std::string>& var_names,
                      const std::vector<int64_t>& vars_size,
                      void* data_ptr,
                      int64_t data_size) {
  // Ships a raw byte buffer (`data_ptr`, `data_size`) to the switch for the
  // given `group_id`, tagging it with `var_names` / per-var `vars_size` so
  // the switch can shard it. Blocks until the RPC completes. Returns 0.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      // Fixed: removed the stray printf-style "%s" placeholder — the error
      // text is streamed, not formatted.
      LOG(ERROR) << "Send meets brpc error, err msg is "
                 << closure->cntl.ErrorText();
    }
  });
  distributed::MultiVarMsg request;
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  std::string message_name = "send and save";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto& send_var_name : var_names) {
    request.add_send_var_names(send_var_name);
  }
  for (auto var_len : vars_size) {
    request.add_vars_len(var_len);
  }
  // The payload itself travels in the brpc request attachment.
  auto& request_buffer = closure->cntl.request_attachment();
  request_buffer.append(reinterpret_cast<void*>(data_ptr), data_size);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  fut.wait();
  delete closure;
  return 0;
}
int HeterClient::Recv(const platform::DeviceContext& ctx,
                      framework::Scope& recv_scope,  // NOLINT
                      const std::string& message_name,
                      const std::vector<std::string>& recv_var_names) {
  // Requests `recv_var_names` from the switch, blocks for the reply, then
  // deserializes the returned payload into `recv_scope`. Returns 0.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    VLOG(4) << "Recv service call done";
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      // Fixed: removed the stray printf-style "%s" placeholder — the error
      // text is streamed, not formatted.
      VLOG(4) << "HeterClient::RecvFromSwitch meets "
                 "brpc error, error message is "
              << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req recv_var_names(<string>)
  for (auto& recv_var_name : recv_var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    recv_switch_channels_.push_back(xpu_channels_[1]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto& res_io_buffer = closure->cntl.response_attachment();
  VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
  // Fixed: the closure was leaked. Once the response has been consumed it is
  // safe to free (the sibling Recv(group_id, ...) overload already does this).
  delete closure;
  VLOG(4) << "Recv done";
  return 0;
}
int HeterClient::Recv(int group_id,
                      const std::vector<std::string>& var_names,
                      void* data_ptr,
                      int64_t data_size) {
  // Queries the switch for `var_names` of `group_id` and copies at most
  // `data_size` bytes of the returned attachment into `data_ptr`.
  // Blocks until the RPC completes. Returns 0.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      // Fixed: removed the stray printf-style "%s" placeholder — the error
      // text is streamed, not formatted.
      LOG(ERROR) << "Recv meets brpc error, err msg is "
                 << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  std::string message_name = "query and recv";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto& recv_var_name : var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    // NOTE(review): the sibling Recv() overload falls back to
    // xpu_channels_[1], and the size check above implies index 1 was
    // intended — confirm whether [0] here is deliberate.
    recv_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  auto& res_io_buffer = closure->cntl.response_attachment();
  butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
  io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr), data_size);
  delete closure;
  VLOG(4) << "Recv done";
  return 0;
}
} // namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc;
// Completion closure for asynchronous brpc calls. When the RPC finishes,
// brpc invokes Run(), which forwards `this` to the user-supplied callback;
// the callback inspects `cntl` for errors and fulfills the attached
// promises. NOTE: Run() does not delete the closure — the code that
// allocated it owns (and must free) it.
class OnHeterRpcDone : public google::protobuf::Closure {
 public:
  explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
  virtual ~OnHeterRpcDone() {}
  // Called by brpc on RPC completion (success or failure).
  void Run() { handler_(this); }
  // Registers a promise the callback fulfills via set_promise_value().
  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {  // NOLINT
    _promises.push_back(promise);
  }
  // Fulfills every registered promise with `value`, waking waiting futures.
  void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }
  int CheckResponse() { return 0; }
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
  HeterRpcCallbackFunc handler_;
  // Request/response payloads and the brpc controller for a single call;
  // which response field is used depends on the RPC being issued.
  MultiVariableMessage request;
  MultiVariableMessage response;
  PsResponseMessage ps_response;
  brpc::Controller cntl;
  // PsRequestMessage *request(size_t i) { return &_requests[i]; }
  // PsResponseMessage *response(size_t i) { return &_responses[i]; }
  // std::vector<PsRequestMessage> _requests;
  // std::vector<PsResponseMessage> _responses;
  // std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// brpc-based client used by trainers / heter workers to talk to heter
// workers (xpu endpoints) and switch nodes. Exposes blocking Send/Recv
// helpers plus two singletons: the trainer-side client (GetInstance) and
// the switch-side client (GetSwitchInstance).
class HeterClient {
 public:
  virtual ~HeterClient() {}
  // Builds a brpc channel to every endpoint in `node_list`. `peer_role`
  // selects which channel vector is filled (switch peers vs worker peers).
  // A channel that fails to init with the raw endpoint is retried once with
  // the integral ip:port form.
  void InitClientChannels(bool need_encrypt,
                          const std::vector<std::string>& node_list,
                          int32_t peer_role) {
    brpc::ChannelOptions options;
    options.protocol = "baidu_std";
    options.connection_type = "single";
    options.timeout_ms = FLAGS_pserver_timeout_ms;
    std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr;
    if (peer_role == PEER_ROLE_IS_SWITCH) {
#ifdef PADDLE_WITH_ARM_BRPC
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
      options.connection_type = "";
      VLOG(4) << "ssl enabled in arm";
#else
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
#endif
      client_channels = &peer_switch_channels_;
    } else if (peer_role == PEER_ROLE_IS_WORKER) {
      client_channels = &peer_worker_channels_;
    } else {
      LOG(ERROR) << "init switch client failed, peer_role not valid";
      // Fixed: bail out instead of dereferencing the null channel pointer
      // below.
      return;
    }
    (*client_channels).resize(node_list.size());
    for (size_t i = 0; i < node_list.size(); ++i) {
      (*client_channels)[i].reset(new brpc::Channel());
      if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
          0) {
        VLOG(0) << "client channel init failed! try again";
        auto ip_port = paddle::string::Split(node_list[i], ':');
        std::string ip = ip_port[0];
        int port = std::stoi(ip_port[1]);
        std::string int_ip_port = GetIntTypeEndpoint(ip, port);
        if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "client channel init failed! peer ip_port = "
                     << int_ip_port;
        }
      }
    }
    VLOG(4) << "InitClientChannels success";
  }
  void CreateClient2XpuConnection();
  // Runs one remote step: serializes `send_var_name` vars from `scope`,
  // invokes the remote service, and receives `recv_var_name` vars back.
  void SendAndRecvAsync(const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
                        const std::string& message_name,
                        const std::vector<std::string>& send_var_name,
                        const std::vector<std::string>& recv_var_name,
                        const std::string& mode = "forward");
  // Raw-buffer send/recv (shard path on the switch).
  int Send(int group_id,
           const std::vector<std::string>& var_names,
           const std::vector<int64_t>& vars_len,
           void* data_ptr,
           int64_t data_size);
  // Scope-based send/recv (variables serialized through MultiVarMsg).
  int Send(const platform::DeviceContext& ctx,
           const framework::Scope& scope,
           const std::string& message_name,
           const std::vector<std::string>& send_var_names);
  int Recv(int group_id,
           const std::vector<std::string>& var_names,
           void* data_ptr,
           int64_t data_size);
  int Recv(const platform::DeviceContext& ctx,
           framework::Scope& recv_scope,  // NOLINT
           const std::string& message_name,
           const std::vector<std::string>& recv_var_names);
  // HeterClient singleton. NOTE(review): unlike GetSwitchInstance this is
  // not guarded by mtx_ — confirm it is only called from a single thread.
  static std::shared_ptr<HeterClient> GetInstance(
      const std::vector<std::string>& endpoints,
      const std::vector<std::string>& previous_endpoints,
      const int& trainer_id) {
    if (NULL == s_instance_) {
      s_instance_.reset(new HeterClient());
      s_instance_->SetXpuList(endpoints);
      s_instance_->SetPreviousXpuList(previous_endpoints);
      s_instance_->SetTrainerID(trainer_id);
      s_instance_->CreateClient2XpuConnection();
    }
    return s_instance_;
  }
  // switch client singleton
  static std::shared_ptr<HeterClient> GetSwitchInstance(
      const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (peer_endpoints.empty()) {
      VLOG(4) << "init switch client failed, null peer_endpoints";
    } else {
      // Fixed: only read peer_endpoints[0] when the vector is non-empty;
      // the original indexed it unconditionally (UB on empty input).
      VLOG(4) << "peer role is: " << peer_role
              << ", addr is: " << peer_endpoints[0];
    }
    if (switch_s_instance_ == nullptr) {
      switch_s_instance_.reset(new HeterClient());
      switch_s_instance_->SetPeerSwitchList(peer_endpoints);
      switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role);
    }
    return switch_s_instance_;
  }
  void SetPeerSwitchList(const std::vector<std::string>& peer_endpoints) {
    peer_switch_list_ = peer_endpoints;
  }
  void SetPeerWorkerList(const std::vector<std::string>& worker_endpoints) {
    peer_worker_list_ = worker_endpoints;
  }
  void Stop();
  std::future<int32_t> SendCmd(uint32_t table_id,
                               int cmd_id,
                               const std::vector<std::string>& params);
  std::future<int32_t> StartProfiler();
  std::future<int32_t> StopProfiler();
  std::future<int32_t> StopHeterWorker();
  std::vector<std::string>& GetXpuList() { return xpu_list_; }
  void SetXpuList(const std::vector<std::string>& xpu_list) {
    xpu_list_ = xpu_list;
  }
  void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
    previous_xpu_list_ = xpu_list;
  }
  void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }

 public:
  std::vector<std::string> send_switch_list_;
  std::vector<std::string> recv_switch_list_;
  std::vector<std::string> peer_switch_list_;
  std::vector<std::string> peer_worker_list_;
  std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;

 private:
  HeterClient() {}
  // Copy disabled (declared, never defined).
  HeterClient& operator=(const HeterClient&);
  HeterClient(const HeterClient&);
  static std::shared_ptr<HeterClient> s_instance_;
  static std::mutex mtx_;
  static std::shared_ptr<HeterClient> switch_s_instance_;
  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
  // DISABLE_COPY_AND_ASSIGN(HeterClient);
  std::vector<std::string> xpu_list_;
  std::vector<std::string> previous_xpu_list_;
  int trainer_id_;
};
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
std::mutex HeterServer::mtx_;
// Registers `func` as the handler invoked for RPCs whose message_name
// matches `message_name`; delegates to the owned service object.
void HeterServer::RegisterServiceHandler(std::string message_name,
                                         HeterServiceHandler func) {
  service_.RegisterServiceHandler(message_name, func);
}
void HeterServer::StartHeterService(bool neeed_encrypt) {
  // Registers the heter service on the main brpc server, starts listening
  // on endpoint_, publishes readiness, then blocks until Stop() flips
  // stoped_.
  server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
  }
  if (server_.Start(endpoint_.c_str(), &options) != 0) {
    VLOG(0) << "HeterServer start fail. Try again.";
    auto ip_port = paddle::string::Split(endpoint_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // Fixed: retry with the freshly computed integral ip:port form; the
    // original retried with the same `endpoint_` that had just failed, so
    // int_ip_port was computed but never used.
    if (server_.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "HeterServer start failed, ip_port= " << int_ip_port;
    }
  } else {
    VLOG(0) << "heter server start success! listen on " << endpoint_;
  }
  {
    // Publish readiness to WaitServerReady() waiters.
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Block the serving thread until Stop() is requested.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
void HeterServer::StartHeterInterService(bool neeed_encrypt) {
  // Same lifecycle as StartHeterService(), but for the switch-to-switch
  // (inter) server listening on endpoint_inter_.
  server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
  }
  if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
    VLOG(4) << "switch inter server start fail. Try again.";
    auto ip_port = paddle::string::Split(endpoint_inter_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // Fixed: retry with the freshly computed integral ip:port form; the
    // original retried with the same `endpoint_inter_` that had just failed.
    if (server_inter_.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "switch inter server start failed, ip_port= "
                 << int_ip_port;
    }
  } else {
    VLOG(4) << "switch inter server server start success! listen on "
            << endpoint_inter_;
  }
  {
    // Publish readiness to WaitServerReady() waiters.
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Block the serving thread until Stop() is requested.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
// Forwards the expected upstream-client fan-in count to the service.
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
// Blocks the calling thread until StartHeterService() /
// StartHeterInterService() has marked the server ready (ready_ == 1).
void HeterServer::WaitServerReady() {
  std::unique_lock<std::mutex> lock(this->mutex_ready_);
  // [this] instead of [=]: implicit this-capture through [=] is deprecated
  // in C++20, and nothing else needs capturing.
  condition_ready_.wait(lock, [this] { return this->ready_ == 1; });
}
int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id();
if (group_id >= FLAGS_heter_world_size) {
LOG(ERROR) << "group id exceed maxmium";
}
auto& local_shard = _local_shards[group_id];
auto& request_io_buffer = cntl->request_attachment();
butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
for (int idx = 0; idx < request->send_var_names_size(); idx++) {
const auto& var_name = request->send_var_names(idx);
const auto& var_size = request->vars_len(idx);
WaitForVarsConsumed(group_id, var_name);
std::unique_lock<std::mutex> lk(scope_mutex_);
auto& value = local_shard[var_name];
value.resize(var_size);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
var_size);
vars_ready_flag[group_id][var_name] = 1;
VLOG(4) << "saved var_name: " << var_name << "is saved ready!";
}
VLOG(4) << "SaveInSwitchWithShard success";
return 0;
}
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  // Streams previously-saved payloads for the request's group into the
  // response attachment, then marks each variable consumed so producers
  // blocked in SaveInSwitchWithShard() may overwrite it. Returns 0.
  VLOG(4) << "entering QueryInSwitchWithShard";
  int32_t group_id = request->group_id();
  VLOG(4) << "group id: " << group_id;
  // NOTE(review): group_id is not range-checked here; an out-of-range id
  // indexes _local_shards out of bounds — confirm callers guarantee range.
  auto& local_shard = _local_shards[group_id];
  auto& response_io_buffer = cntl->response_attachment();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto msg_name = request->message_name();
  response->set_message_name(msg_name);
  for (auto& req_var_name : req_var_names) {
    VLOG(4) << "req var name: " << req_var_name;
    response->add_send_var_names(req_var_name);
    // Block until a producer has published this variable (flag == 1).
    WaitForVarsProduced(group_id, req_var_name);
    std::unique_lock<std::mutex> lk(scope_mutex_);
    // NOTE(review): find() result is used unchecked; this relies on
    // WaitForVarsProduced() guaranteeing the key exists in the shard —
    // confirm.
    auto itr = local_shard.find(req_var_name);
    auto& value = itr.value();
    response_io_buffer.append(value.data(), value.size());
    value.resize(0);  // release the payload memory
    vars_ready_flag[group_id][req_var_name] = 0;
    VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
  }
  VLOG(4) << "heter server QueryInSwitchWithShard done";
  return 0;
}
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
    const MultiVarMsg* request,
    PsResponseMessage* response,
    brpc::Controller* cntl) {
  // Deserializes the request's variables into the switch-local scope and
  // marks each variable produced for QueryInSwitchWithScope() consumers.
  // Returns 0.
  VLOG(4) << "entering SaveInSwitchWithScope";
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto message_name = request->message_name();
  VLOG(4) << "message_name in heter server: " << message_name;
  auto send_var_nums = request->send_var_names_size();
  std::vector<std::string> send_var_names(send_var_nums);
  for (int idx = 0; idx < send_var_nums; idx++) {
    send_var_names[idx] = request->var_messages(idx).varname();
  }
  // Fixed: WaitForVarsConsumed() locks scope_mutex_ internally, so it must
  // run BEFORE this thread acquires scope_mutex_; the original acquired the
  // lock first and then called it, self-deadlocking on the non-recursive
  // mutex.
  for (const auto& var_name : send_var_names) {
    WaitForVarsConsumed(0, var_name);
  }
  std::unique_lock<std::mutex> lk(scope_mutex_);
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
  }
  for (const auto& var_name : send_var_names) {
    // Diagnostic only: the deserialize below creates missing vars as needed.
    auto* var_exist_ptr = local_scope->FindVar(var_name);
    if (!var_exist_ptr) {
      VLOG(4) << "not find var: " << var_name << " in local_scope";
    }
  }
  auto& request_io_buffer = cntl->request_attachment();
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      *request, &request_io_buffer, cpu_dev_ctx, local_scope);
  lk.unlock();
  for (const auto& var_name : send_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][var_name] = 1;  // mark produced
  }
  VLOG(4) << "SaveInSwitchWithScope success";
  return 0;
}
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  // Serializes the requested variables from the switch-local scope into the
  // response, then marks each variable consumed so producers blocked in
  // SaveInSwitchWithScope() may overwrite it. Returns 0.
  VLOG(4) << "entering QueryInSwitchWithScope";
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(INFO) << "local_scope is null";
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  // get req message_name & req_var_names
  auto msg_name = request->message_name();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto& response_io_buffer = cntl->response_attachment();
  // 1. fill message_name(string)
  response->set_message_name(msg_name);
  // 2. fill var_names(string)
  for (auto& req_var_name : req_var_names) {
    response->add_send_var_names(req_var_name);
  }
  // 3. fill var_messages(VarMessage)
  for (auto& req_var_name : req_var_names) {
    WaitForVarsProduced(0, req_var_name);
    framework::Variable* var_ptr;
    var_ptr = local_scope->FindVar(req_var_name);
    if (!var_ptr) {
      // Fixed: the original logged and then dereferenced the null pointer
      // below; skip serialization for a missing variable instead.
      LOG(INFO) << "local_scope not find var: " << req_var_name;
      continue;
    }
    auto* send_var_msg = response->add_var_messages();
    send_var_msg->set_varname(req_var_name);
    butil::IOBuf temp_iobuf;
    if (var_ptr->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    } else if (var_ptr->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    }
    response_io_buffer.append(temp_iobuf);
  }
  for (auto& req_var_name : req_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][req_var_name] = 0;  // mark consumed
  }
  VLOG(4) << "heter server QueryInSwitchWithScope done";
  return 0;
}
} // end namespace distributed
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
using serviceHandler =
std::function<int32_t(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl)>;
using HeterServiceHandler =
std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
using HeterRpcCallbackFunc = std::function<void(void*)>;
// Abstract base for RPC request handlers. Holds non-owning pointers to the
// device context and root scope the handler operates on; both must outlive
// the handler.
class ServiceHandlerBase {
 public:
  ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}
  virtual ~ServiceHandlerBase() {}
  // Non-owning setters; the caller keeps ownership.
  void SetScope(const framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  // Processes one request and fills `response`; returns 0 on success.
  virtual int Handle(const MultiVarMsg* request,
                     MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_;
  const framework::Scope* scope_;
};
using SharedMiniScope =
std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
using SharedMicroScope = std::shared_ptr<std::unordered_map<
int,
std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
using SharedTaskQueue = std::shared_ptr<
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
// Resizable byte buffer used as the value type of switch-side shards.
// Thin wrapper over std::vector<char>.
class ValueInSwitch {
 public:
  ValueInSwitch() = default;
  ~ValueInSwitch() = default;
  // Mutable pointer to the underlying bytes (valid while size() > 0).
  char* data() { return _data.data(); }
  // Fixed: size() is now const so it can be called through const refs.
  size_t size() const { return _data.size(); }
  void resize(size_t size) { _data.resize(size); }
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;
};
// Handles SendAndRecv-style RPCs on the heter-worker / switch side. Two
// storage modes coexist: scope-based (variables deserialized into paddle
// scopes) and shard-based (raw byte payloads in _local_shards). Producer /
// consumer hand-off between the Save* and Query* methods is coordinated
// through vars_ready_flag, guarded by scope_mutex_.
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    // One shard per group; group count is fixed at construction time.
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }
  virtual ~SendAndRecvVariableHandler() {}
  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }
  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    for (auto& scope_pair : (*micro_scopes_)) {
      // auto mini_idx = scope_pair.first;
      auto& micro_scopes = scope_pair.second;
      // Reads the micro-scope count from the first entry only; assumes every
      // minibatch holds the same number of micro scopes.
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }
  // Returns the number of task queues (one per minibatch).
  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }
  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  // Busy-waits until `var_name` of `group_id` has been consumed (flag == 0).
  // NOTE(review): spins, taking scope_mutex_ on every iteration — callers
  // must NOT hold scope_mutex_ when calling, or the thread deadlocks; the
  // timeout logic below is commented out, so this can spin forever.
  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  // Busy-waits until `var_name` of `group_id` has been produced (flag == 1).
  // Same spinning / locking caveats as WaitForVarsConsumed().
  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  int QueryInSwitchWithShard(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);
  int QueryInSwitchWithScope(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);
  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
  // Deserializes the request into a fresh local scope, routes the
  // micro-batch into the right mini/micro scope (creating them on first
  // use), enqueues the task for the trainer, and serializes the requested
  // variables back into the response.
  int Handle(const MultiVarMsg* request,
             MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle",
                                       platform::TracerEventType::Communication,
                                       1);
    FLAGS_eager_delete_tensor_gb = -1;
    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    // Shadows the member local_scope_ptr on purpose: this scope is
    // per-request scratch space.
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);
    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);
    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<framework::LoDTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    // micro_id encodes both indices as minibatch * 10 + microbatch.
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;
    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();
      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
          1,
          platform::errors::InvalidArgument(
              "minibatch index should in current trainer"));
    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }
    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
    // Deserialize a second time, into the persistent micro scope on the
    // handler's configured device context.
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));
    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                                response_var_names,
                                                empty_var_names,
                                                *dev_ctx_,
                                                &local_scope,
                                                response,
                                                &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  // Scope used by the switch-side Save/Query WithScope paths.
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
  // group_id -> (var_name -> 1 when produced / 0 when consumed); guarded by
  // scope_mutex_.
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  // Per-group byte-payload shards used by the WithShard paths.
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};
  int num_microbatch_;
  int num_minibatch_;
  std::mutex scope_mutex_;
  bool is_first_stage_ = false;
  bool is_last_stage_ = false;
  SharedTaskQueue task_queue_;
};
// HeterService: brpc service that dispatches parameter-server control
// commands (stop server, start/stop profiler) and proxies tensor
// send/recv traffic between trainers, switch nodes and heter workers.
class HeterService : public PsService {
 public:
  HeterService() {
    // Register control-command handlers, keyed by PsCmdID from ps.proto.
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }
  virtual ~HeterService() {}
  // Generic control entry point: dispatches to the handler registered for
  // request->cmd_id() and propagates its status into `response`.
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      // Fix: previously the message was built but never attached to the
      // response, so an unknown cmd_id looked like success to the caller.
      response->set_err_code(-1);
      response->set_err_msg(err_msg);
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }
  // Tensor send/recv: looks up the handler registered under the request's
  // message_name (see RegisterServiceHandler) and invokes it.
  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller,
      const MultiVarMsg* request,
      MultiVarMsg* response,
      ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr,
        handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
    // We don't want to call done->Run() here, release the guard.
    // done_guard.release();
  }
  // Switch-side receive: answers a query out of the local shard store.
  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request,
                              MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithScope failed!";
    }
  }
  // Forwards a request (attachment included) to peer switch node 0 and
  // relays the peer's response attachment back to the original caller.
  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // Completion closure for the proxied asynchronous SendS2S call.
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(),
            true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // Fix: wait for the asynchronous RPC to complete *before* moving its
    // response attachment; the old order (append, then wait) raced with the
    // in-flight call and could forward an empty/partial attachment.
    fut.wait();
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }
  // Switch-side store: persists incoming vars into the local shard store.
  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request,
               PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithScope failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }
  // Relays a request from the switch node to peer worker 0.
  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request,
                    PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    // NOTE(review): `done` is both held by done_guard and handed to the stub
    // as the RPC completion closure — confirm it is not Run() twice.
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }
  // Register a named send/recv handler used by SendAndRecvVariable.
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }
  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }
  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
  // Force the service into the exiting state regardless of worker fan-in.
  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }
  bool IsExit() { return is_exit_; }

 private:
  // Control-command handlers invoked synchronously from service().
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }
  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }
  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    // Exit only after every expected worker (fan_in_) has asked to stop.
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};
// Singleton wrapper around the brpc servers hosting HeterService.
// Owns the main and inter-node brpc::Server instances and forwards
// scope/queue wiring to the shared SendAndRecvVariableHandler.
class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}
  // Stop both servers; idempotent — guarded by stoped_ under mutex_.
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_ == true) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }
  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }
  bool IsExit() { return service_.IsExit(); }
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);
  void StartHeterService(bool need_encrypt = false);
  void StartHeterInterService(bool need_encrypt = false);
  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }
  // Replace the handler's switch-mode scope with a fresh one.
  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }
  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }
  void SetFanin(const int& fan_in);
  // The handler must be set before the scope/queue setters below are used.
  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }
  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }
  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }
  int GetThreadNum() { return request_handler_->GetThreadNum(); }
  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }
  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }
  // Blocks until StartHeterService signals readiness (defined elsewhere).
  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;
  static std::mutex mtx_;
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  // NOTE(review): starts true, so Stop() before start is a no-op —
  // presumably flipped to false when the service starts; confirm.
  bool stoped_ = true;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;
  int ready_;
};
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// Register every concrete PSClient transport with the class registry so
// PSClientFactory::Create can instantiate them by class name.
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env,
size_t client_id) {
_env = &env;
_config = config;
_dense_pull_regions = regions;
_client_id = client_id;
_config.mutable_worker_param()
->mutable_downpour_worker_param()
->mutable_downpour_table_param()
->CopyFrom(_config.server_param()
.downpour_server_param()
.downpour_table_param());
const auto &work_param = _config.worker_param().downpour_worker_param();
for (int i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->Configure(work_param.downpour_table_param(i).accessor());
accessor->Initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return Initialize();
}
// Create the PSClient subclass named in the service config.
// Returns nullptr (and logs) when any required config section is missing or
// the class is not registered. Caller owns the returned pointer.
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  if (!config.downpour_server_param().service_param().has_client_class()) {
    LOG(ERROR) << "miss client_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }
  const auto &service_param = config.downpour_server_param().service_param();
  PSClient *client =
      CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
  if (client == nullptr) {
    LOG(ERROR) << "client is not registered, server_name:"
               << service_param.client_class();
    return nullptr;
  }
  // Table manager must be up before the client starts touching tables.
  TableManager::Instance().Initialize();
  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
  return client;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
// Callback invoked when a PS RPC completes; receives the closure as void*.
// (modernized: `using` alias instead of `typedef`)
using PSClientCallBack = std::function<void(void *)>;
// Completion closure shared by PS client RPCs: fans the completion status
// out to any number of registered promises, and keeps cost timers alive
// until the RPC finishes.
class PSClientClosure : public google::protobuf::Closure {
 public:
  // Takes the callback by value and moves it in (avoids one copy).
  explicit PSClientClosure(PSClientCallBack callback)
      : _callback(std::move(callback)) {}
  virtual ~PSClientClosure() {}
  // Publish `value` to every promise added via add_promise().
  virtual void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }
  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {  // NOLINT
    _promises.push_back(promise);
  }
  void add_timer(std::shared_ptr<CostTimer> &timer) {  // NOLINT
    _timers.push_back(timer);
  }

 protected:
  PSClientCallBack _callback;
  std::vector<std::shared_ptr<CostTimer>> _timers;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
// Abstract parameter-server client interface. Concrete transports (brpc,
// local, graph, coordinator) are registered via REGISTER_PSCORE_CLASS and
// instantiated by PSClientFactory by class name.
class PSClient {
 public:
  PSClient() {}
  virtual ~PSClient() {}
  PSClient(PSClient &&) = delete;
  PSClient(const PSClient &) = delete;
  // Records env/regions/client id, mirrors server table config into the
  // worker config, builds per-table accessors, then calls Initialize().
  virtual int32_t Configure(
      const PSParameter &config,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>
          &regions,
      PSEnvironment &_env,  // NOLINT
      size_t client_id) final;
  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry) = 0;
  // Trigger data eviction (shrink) on a table.
  virtual std::future<int32_t> Shrink(uint32_t table_id,
                                      const std::string threshold) = 0;
  // Load data for every table.
  virtual std::future<int32_t> Load(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Load data for the specified table only.
  virtual std::future<int32_t> Load(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;
  // Save every table; the value accessor may apply different save
  // conditions depending on `mode`.
  virtual std::future<int32_t> Save(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Save the specified table; the value accessor may apply different save
  // conditions depending on `mode`.
  virtual std::future<int32_t> Save(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;
  // Clear table data.
  virtual std::future<int32_t> Clear() = 0;
  virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
  // Pull the dense part of the parameters and fill it block-by-block into
  // the local network parameters.
  // start/num can be used to pull only part of the parameters.
  // The keys/values buffers must not be reused before the future completes.
  // The client unpacks values by block and hands them to multiple senders;
  // each sender aggregates requests for the same block into shared buffers.
  // The server extracts the configured dimension from each parameter block;
  // returned data is unpacked into the accumulated buffers.
  virtual std::future<int32_t> PullDense(Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;  // reserved
  // firstly push dense param for parameter server
  // this is necessary because dense weight initialized in trainer on cold
  // start
  virtual std::future<int32_t> PushDenseParam(const Region *regions,
                                              size_t region_num,
                                              size_t table_id) = 0;
  virtual std::future<int32_t> PushDense(const Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;
  // Pull with `keys`; results are written into `values`.
  // There are `num` keys and `num` values; each value occupies select_size
  // bytes.
  // The keys/values buffers must not be reused before the future completes.
  // Keys from multiple threads are merged, grouped and scattered to the
  // servers; once results return, the buffers are walked to fill `values`.
  // `is_training` distinguishes train vs. predict requests; the server
  // treats feature admission differently for each.
  virtual std::future<int32_t> PullSparse(float **select_values,
                                          size_t table_id,
                                          const uint64_t *keys,
                                          size_t num,
                                          bool is_training) = 0;
  virtual std::future<int32_t> PullSparseParam(float **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num,
                                               bool is_training) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
  // Ensure all pending/batched requests are actually sent out.
  virtual std::future<int32_t> Flush() = 0;
  // Gracefully stop the server.
  virtual std::future<int32_t> StopServer() = 0;
  // Server-side profiler control.
  virtual std::future<int32_t> StartProfiler() = 0;
  virtual std::future<int32_t> StopProfiler() = 0;
  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) = 0;
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float> *values,
                                            std::vector<uint64_t> *keys,
                                            int pserver_idx) = 0;
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t *total_send_data,
                                              void *done) = 0;
  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path) = 0;
  virtual void FinalizeWorker() = 0;
  // Client-to-client message sending.
  virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                    int to_client_id,
                                                    const std::string &msg) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  // client2client message handling: std::function<int32_t(int, int, const
  // std::string &)> -> ret(msg_type, from_client_id, msg)
  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
  virtual int RegisteClient2ClientMsgHandler(int msg_type,
                                             MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
  virtual int HandleClient2ClientMsg(int msg_type,
                                     int from_client_id,
                                     const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
      return -1;
    }
    return itr->second(msg_type, from_client_id, msg);
  }
  // Returns the accessor configured for `table_id`, or NULL if unknown.
  virtual ValueAccessor *GetTableAccessor(size_t table_id) {
    auto itr = _table_accessors.find(table_id);
    if (itr == _table_accessors.end()) {
      return NULL;
    }
    return itr->second.get();
  }
  virtual size_t GetServerNums() = 0;
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float *total_send_data,
                                                    size_t total_send_data_size,
                                                    void *done) = 0;
  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      size_t num,
      void *done) = 0;
  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      uint32_t num,
      void *done,
      int pserver_idx) = 0;
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t *keys,
                                               const float **update_values,
                                               size_t num,
                                               void *done) = 0;
  virtual std::future<int32_t> PushSparse(size_t table_id,
                                          const uint64_t *keys,
                                          const float **update_values,
                                          size_t num) = 0;
  // for save cache
  virtual std::future<int32_t> CacheShuffle(
      uint32_t table_id,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> SaveCache(uint32_t table_id,
                                         const std::string &path,
                                         const std::string &mode) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> GetCacheThreshold(
      uint32_t table_id,
      double &cache_threshold) {  // NOLINT
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> Revert() {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> CheckSavePrePatchDone() {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

 protected:
  virtual int32_t Initialize() = 0;
  PSParameter _config;
  std::map<uint64_t, std::vector<paddle::distributed::Region>>
      _dense_pull_regions;
  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
  std::unordered_map<int32_t, MsgHandlerFunc>
      _msg_handler_map;  // handles client2client messages

 public:
  size_t _client_id;
  PSEnvironment *_env;
};
// Holds one asynchronous request's payload together with its table id,
// cost timer and completion promise.
template <class T>
class AsyncRequestTask {
 public:
  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
  // Takes ownership of `data` by moving from it; the caller's object is
  // left in a moved-from state.
  AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
      : _table_id(table_id),
        _timer(timer),
        _promise(std::make_shared<std::promise<int32_t>>()) {
    _data = std::move(data);
  }
  // NOTE(review): despite the copy-constructor shape, this MOVES the
  // payload out of the source task (and shares its promise/timer) — the
  // source is left moved-from. Confirm callers expect move semantics here.
  AsyncRequestTask(AsyncRequestTask &data)  // NOLINT
      : _table_id(data.table_id()),
        _timer(data.timer()),
        _promise(data.promise()) {
    _data = std::move(data.data());
  }
  ~AsyncRequestTask() {}
  inline T &data() { return _data; }
  inline size_t table_id() { return _table_id; }
  inline std::shared_ptr<CostTimer> &timer() { return _timer; }
  inline std::future<int32_t> get_future() { return _promise->get_future(); }
  inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }

 private:
  T _data;
  size_t _table_id;
  std::shared_ptr<CostTimer> _timer;
  std::shared_ptr<std::promise<int32_t>> _promise;
};
// Declare the PSClient registry used by the REGISTER_PSCORE_CLASS calls.
REGISTER_PSCORE_REGISTERER(PSClient);
// Factory: creates the PSClient subclass named in the service config.
class PSClientFactory {
 public:
  static PSClient *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"

#include "glog/logging.h"

#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace paddle {
namespace distributed {
int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::Instance().Initialize();
for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
table->SetShard(0, 1);
table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
return 0;
}
// Trigger data eviction on `table_id`.
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  // TODO: not implemented for the local client — returns success without
  // forwarding the shrink to the underlying table.
  return done();
}
// Load every registered table from `epoch` storage using `mode`.
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
                                           const std::string& mode) {
  // Delegate to the single-table overload for each known table id.
  for (auto& table_pair : _table_map) {
    Load(table_pair.first, epoch, mode);
  }
  return done();
}
// Load a single table from `epoch` storage using `mode`.
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  GetTable(table_id)->Load(epoch, mode);
  return done();
}
// Save every registered table to `epoch` storage using `mode`.
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  // Delegate to the single-table overload for each known table id.
  for (auto& table_pair : _table_map) {
    Save(table_pair.first, epoch, mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = GetTable(table_id);
table_ptr->Flush();
table_ptr->Save(epoch, mode);
return done();
}
// Clear all table data.
::std::future<int32_t> PsLocalClient::Clear() {
  // TODO: not implemented for the local client — returns success.
  return done();
}
// Clear one table's data.
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
  // TODO: not implemented for the local client — returns success.
  return done();
}
// The local client does not batch requests, so there is nothing to flush.
::std::future<int32_t> PsLocalClient::Flush() {
  // no need
  return done();
}
// No remote server process exists for the local client; nothing to stop.
::std::future<int32_t> PsLocalClient::StopServer() {
  // no need
  return done();
}
// Pull the full dense parameter vector of `table_id` from the local table
// and scatter it byte-wise across the caller-provided regions.
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> region_buffer;
  region_buffer.resize(num_per_shard);
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  table_ptr->Pull(table_context);
  size_t region_idx = 0;       // current destination region
  size_t region_data_idx = 0;  // byte offset inside the current region
  size_t shard_data_size = num_per_shard;
  size_t shard_buffer_remain = shard_data_size * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain,
      region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));
  size_t index = 0;  // bytes already consumed from region_buffer
  // Scatter the pulled buffer into regions, splitting on region boundaries.
  // (cleanup: C-style casts replaced with reinterpret_cast)
  const auto* src = reinterpret_cast<const uint8_t*>(region_buffer.data());
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    if (region.size - region_data_idx >= shard_buffer_remain) {
      // Everything left fits into the current region.
      memcpy(region.data + region_data_idx, src + index, shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else if (region.size - region_data_idx == 0) {
      // Current region is full; advance to the next one.
      ++region_idx;
      region_data_idx = 0;
    } else {
      // Fill the remainder of the current region and advance.
      memcpy(region.data + region_data_idx,
             src + index,
             region.size - region_data_idx);
      shard_buffer_remain -= (region.size - region_data_idx);
      index += (region.size - region_data_idx);
      ++region_idx;
      region_data_idx = 0;
    }
  }
  return done();
}
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.push_context.is_param = true;
table_context.num = region_buffer.size();
table_ptr->Push(table_context);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done();
}
// Apply a raw dense gradient buffer directly to the local table.
// `callback` is an owned PSClientClosure handed over by the caller.
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* table_ptr = GetTable(table_id);
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
  // NOTE(review): the closure is deleted without set_promise_value(); any
  // future obtained from it is never fulfilled — confirm callers of the
  // local client do not wait on it.
  delete closure;
  return done();
}
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
PADDLE_ENFORCE_LE(
offset + data_num,
data_size,
platform::errors::PreconditionNotMet(
"invalid dense size, cur pos[%d] data_num[%d] size[%d]",
offset,
data_num,
data_size));
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.num = region_buffer.size();
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr->Push(table_context);
return done();
}
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//将key拆分到各shard请求,并记录原始对应value指针
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = keys;
table_context.pull_context.ptr_values = select_values;
table_context.use_ptr = true;
table_context.num = num;
// table_ptr->PullSparsePtr(select_values, keys, num);
table_ptr->Pull(table_context);
return done();
}
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
return done();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace paddle {
namespace distributed {
class Table;
// In-process implementation of PSClient that operates on local Table objects
// directly instead of going through brpc. Every operation completes
// synchronously, so all future-returning members hand back a future that is
// already resolved with 0; unsupported operations are success no-ops.
class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }

  // No client-to-client connections exist in the local deployment.
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
    return 0;
  }

  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold) override;
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
  virtual ::std::future<int32_t> StopServer() override;
  virtual void FinalizeWorker() override {}

  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);

  // Value-copy pull is not implemented locally; reports immediate success.
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys,
                                            size_t num,
                                            bool is_training) {
    return done();
  }
  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
    return done();
  }
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
  virtual ::std::future<int32_t> Flush();

  // Server-side profiling is a no-op for the in-process client.
  virtual std::future<int32_t> StartProfiler() { return done(); }
  virtual std::future<int32_t> StopProfiler() { return done(); }
  // Single process, single worker: nothing to synchronize.
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
    return done();
  }
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
    return done();
  }
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
    // The `done` parameter shadows the member helper; qualify with this->.
    return this->done();
  }

  // Recv table from server and save it in LoDTensor — nothing to transfer in
  // the local setup.
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
    return 0;
  }

  virtual ::std::future<int32_t> SendClient2ClientMsg(
      int msg_type, int to_client_id, const std::string& msg) override {
    return done();
  }

  // The local deployment always consists of exactly one server.
  virtual size_t GetServerNums() { return 1; }

  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;
  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;
  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
    return this->done();
  }
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
    return this->done();
  }

 private:
  virtual int32_t Initialize() override;

  // Builds a future that is already resolved with status 0 (success).
  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }

  // NOTE(review): the +1 looks like headroom for a sharding remainder —
  // confirm against the server-side shard layout before changing.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }

  // Looks up a table by id; logs and returns NULL when the id is unknown.
  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }

  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace paddle {
namespace distributed {
// In-process stand-in for PSServer used when client and server share one
// process: there is no RPC service to host, so every lifecycle hook is a
// no-op that reports success (0).
class PsLocalServer : public PSServer {
 public:
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  // Nothing to launch in local mode; 0 signals success.
  virtual uint64_t Start() { return 0; }
  // ip/port are ignored — no network listener is created.
  virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
  virtual int32_t Stop() { return 0; }
  // Accepts any configuration; the local server keeps no state of its own.
  virtual int32_t Configure(
      const PSParameter &config,
      PSEnvironment &env,
      size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
    return 0;
  }
 private:
  virtual int32_t Initialize() { return 0; }
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment