Commit de2e6515 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.4.1-dtk-23.04

parent ad08b8ce
Pipeline #228 failed with stages
in 0 seconds
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Wires the given nodes into a linear pipeline: every node registers the
// node before it as an upstream task and the node after it as a downstream
// task. A chain of zero or one nodes needs no links.
void LinkNodes(const std::vector<TaskNode*>& nodes) {
  const size_t count = nodes.size();
  if (count <= 1) return;
  for (size_t pos = 0; pos < count; ++pos) {
    TaskNode* current = nodes[pos];
    // Every node except the first has an upstream neighbor.
    if (pos > 0) {
      current->AddUpstreamTask(nodes[pos - 1]->task_id());
    }
    // Every node except the last has a downstream neighbor.
    if (pos + 1 < count) {
      current->AddDownstreamTask(nodes[pos + 1]->task_id());
    }
  }
}
// Builds an 8-stage linear pipeline (source->a->b->c->d->e->f->sink) inside
// one in-process carrier and runs it to completion. node_b and node_e use
// the "Amplifier" interceptor with reply/send cadence set to micro_steps;
// the test passes if the carrier drains and Wait() returns.
TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
// Every task id is mapped to rank 0, so the whole pipeline is local.
carrier->Init(0,
{{SOURCE_ID, 0},
{0, 0},
{1, 0},
{2, 0},
{3, 0},
{4, 0},
{5, 0},
{SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
int64_t micro_steps = 3;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink
LinkNodes({source, node_a, node_b, node_c, node_d, node_e, node_f, sink});
// LR->b(1:3)->F->B->e(3:1)->U
// node_b replies upstream once every micro_steps; node_e sends downstream
// once every micro_steps (the 1:3 / 3:1 amplification in the sketch above).
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier->SetInterceptor(1,
InterceptorFactory::Create("Amplifier", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier->SetInterceptor(4,
InterceptorFactory::Create("Amplifier", 4, node_e));
carrier->SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
// Kick the pipeline by delivering START to the source interceptor, then
// block until the run finishes and release the carrier's resources.
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Returns the configured buffer size for the edge between `from` and `to`.
// The lookup is direction-insensitive: {from, to} and {to, from} name the
// same edge. Falls back to 2 when the edge is not configured.
//
// FIX: the map is now taken by const reference — the old by-value parameter
// copied the whole map on every call — and each lookup does a single
// find instead of find + at.
int64_t GetBuffSize(
    const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>& buffs,
    TaskNode* from,
    TaskNode* to) {
  auto it = buffs.find({from, to});
  if (it != buffs.end()) {
    return it->second;
  }
  it = buffs.find({to, from});
  if (it != buffs.end()) {
    return it->second;
  }
  return 2;  // set default 2
}
// Wires the given nodes into a linear pipeline, looking up a per-edge buffer
// size in `buffs` (default 2; see GetBuffSize). Every node registers its
// predecessor as upstream and its successor as downstream.
//
// FIX: `buffs` is now taken by const reference — the old by-value parameter
// copied the whole map on every call. The three duplicated link sections
// (first node / last node / middle nodes) are folded into one loop.
void LinkNodes(
    const std::vector<TaskNode*>& nodes,
    const std::map<std::pair<TaskNode*, TaskNode*>, int64_t>& buffs) {
  const size_t count = nodes.size();
  if (count <= 1) return;
  for (size_t pos = 0; pos < count; ++pos) {
    TaskNode* current = nodes[pos];
    if (pos > 0) {
      // Link to the predecessor with the buffer size configured for that edge.
      TaskNode* prev = nodes[pos - 1];
      current->AddUpstreamTask(prev->task_id(),
                               GetBuffSize(buffs, prev, current));
    }
    if (pos + 1 < count) {
      // Link to the successor with the buffer size configured for that edge.
      TaskNode* next = nodes[pos + 1];
      current->AddDownstreamTask(next->task_id(),
                                 GetBuffSize(buffs, current, next));
    }
  }
}
// Builds a 6-stage pipeline (source->a->b->c->d->sink) where the Amplifier
// nodes a and d run once per micro_steps while b and c run 3 times, with a
// buffer of 1 on the b->c edge. Passes if the carrier drains and Wait()
// returns.
TEST(AmplifierInterceptor, Amplifier) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
// Every task id is mapped to rank 0, so the whole pipeline is local.
carrier->Init(0,
{{SOURCE_ID, 0}, {0, 0}, {1, 0}, {2, 0}, {3, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, ""}}, "");
int64_t micro_steps = 6;
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink
// LR->F->B->U
// The b->c edge gets a buffer size of 1; all other edges use the default.
LinkNodes({source, node_a, node_b, node_c, node_d, sink},
{{{node_b, node_c}, 1}});
// a and d execute once per micro_steps; d additionally runs at the last
// offset within each group of micro_steps.
node_a->SetRunPerSteps(micro_steps);
node_d->SetRunPerSteps(micro_steps);
node_d->SetRunAtOffset(micro_steps - 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0,
InterceptorFactory::Create("Amplifier", 0, node_a));
carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier->SetInterceptor(3,
InterceptorFactory::Create("Amplifier", 3, node_d));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
// Kick the pipeline with a START message, wait for completion, clean up.
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Test double standing in for a compute interceptor: on DATA_IS_READY it
// tells the source its input is consumed (DATA_IS_USELESS) and tells the
// sink new data is ready; on DATA_IS_USELESS it only logs. NOP is registered
// as the handler for every message this interceptor receives.
//
// FIX: removed the private member `int64_t step_` — it was declared but
// never read or written anywhere in this class.
class FakeInterceptor : public Interceptor {
 public:
  FakeInterceptor(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
  }
  void NOP(const InterceptorMessage& msg) {
    if (msg.message_type() == DATA_IS_READY) {
      std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
                << std::endl;
      // Release the consumed input back to the source...
      InterceptorMessage reply;
      reply.set_message_type(DATA_IS_USELESS);
      Send(SOURCE_ID, reply);
      // ...and tell the sink that output is available.
      InterceptorMessage ready;
      ready.set_message_type(DATA_IS_READY);
      Send(SINK_ID, ready);
    } else if (msg.message_type() == DATA_IS_USELESS) {
      std::cout << "FakeInterceptor remove result in scope " << msg.scope_idx()
                << std::endl;
    }
  }
};
// Runs a minimal source->fake->sink pipeline for 3 micro-batches, with the
// FakeInterceptor above acting as the only compute stage. Passes if the
// carrier drains and Wait() returns.
TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
// All three task ids live on rank 0.
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {SINK_ID, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
// Wire source->node_a->sink with a buffer size of 1 on each edge.
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
node_a->AddDownstreamTask(SINK_ID, 1);
sink->AddUpstreamTask(0, 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
carrier->SetInterceptor(SINK_ID,
InterceptorFactory::Create("Sink", SINK_ID, sink));
// start
// Kick the pipeline with a START message, wait for completion, clean up.
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Test double that terminates a pipeline with no sink node: each
// DATA_IS_READY is acknowledged back to the source with DATA_IS_USELESS,
// and once `step_` reaches the node's max_run_times the carrier is woken so
// Wait() can return. Relies on the base-class members `node_` and
// `carrier_` provided by Interceptor.
class FakeInterceptor : public Interceptor {
public:
FakeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
step_ = 0;
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
<< std::endl;
// Tell the source this micro-batch has been consumed.
InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS);
Send(SOURCE_ID, reply);
step_++;
// All micro-batches processed: unblock Carrier::Wait().
if (step_ == node_->max_run_times()) {
carrier_->WakeUp();
}
}
}
private:
// Number of DATA_IS_READY messages handled so far.
int64_t step_;
};
// Runs a two-node pipeline (source->fake, no sink) for 3 micro-batches; the
// FakeInterceptor above wakes the carrier after its last step. Passes if
// Wait() returns.
TEST(SourceInterceptor, Source) {
std::string carrier_id = "0";
Carrier* carrier =
GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
// Both task ids live on rank 0.
carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
// Wire source->node_a with a buffer size of 1.
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
carrier->SetInterceptor(
SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
// start
// Kick the pipeline with a START message, wait for completion, clean up.
InterceptorMessage msg;
msg.set_message_type(START);
msg.set_dst_id(SOURCE_ID);
carrier->EnqueueInterceptorMessage(msg);
carrier->Wait();
carrier->Release();
}
} // namespace distributed
} // namespace paddle
# Generated C++ sources for the index-dataset protobuf schema.
proto_library(index_dataset_proto SRCS index_dataset.proto)
# Tree-index loader/registry (see index_wrapper.cc).
cc_library(
index_wrapper
SRCS index_wrapper.cc
DEPS index_dataset_proto fs)
# index_sampler additionally links mkldnn when the oneDNN build is enabled.
if(WITH_MKLDNN)
cc_library(
index_sampler
SRCS index_sampler.cc
DEPS xxhash index_wrapper eigen3 mkldnn)
else()
cc_library(
index_sampler
SRCS index_sampler.cc
DEPS xxhash index_wrapper eigen3)
endif()
# Python builds also get a generated python proto module for the schema.
if(WITH_PYTHON)
py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Serialization schema for tree-index files (see TreeIndex::Load).
syntax = "proto2";
package paddle.distributed;
// One node of the index tree.
message IndexNode {
required uint64 id = 1;
required bool is_leaf = 2;
required float probability = 3;
optional string item_name = 4;
}
// Overall tree shape: height and branching factor.
message TreeMeta {
required int32 height = 1;
required int32 branch = 2;
}
// One key/value record in the on-disk file; the key is either the literal
// ".tree_meta" (value = TreeMeta) or a node code (value = IndexNode).
message KVItem {
required bytes key = 1;
required bytes value = 2;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace distributed {
// For every target id, walks its path from leaf up to start_sample_layer_
// and, per layer, emits one positive row (the path node, label 1) followed
// by layer_counts_[j] negative rows sampled from the same layer (label 0).
// Each output row is [user features..., item id, label]. When
// with_hierarchy is set, rows for non-leaf layers replace the user features
// with their ancestors at that layer.
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids,
bool with_hierarchy) {
auto input_num = target_ids.size();
auto user_feature_num = user_inputs[0].size();
// layer_counts_sum_ rows per target: per layer, 1 positive +
// layer_counts_[j] negatives. +2 columns hold the item id and the label.
std::vector<std::vector<uint64_t>> outputs(
input_num * layer_counts_sum_,
std::vector<uint64_t>(user_feature_num + 2));
auto max_layer = tree_->Height();
size_t idx = 0;
for (size_t i = 0; i < input_num; i++) {
auto travel_codes =
tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
auto travel_path = tree_->GetNodes(travel_codes);
for (size_t j = 0; j < travel_path.size(); j++) {
// user
// Fill the user-feature columns of all layer_counts_[j] + 1 rows for
// this layer (hence <= in the offset loop).
if (j > 0 && with_hierarchy) {
auto ancestor_codes =
tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
auto hierarchical_user = tree_->GetNodes(ancestor_codes);
for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
for (size_t k = 0; k < user_feature_num; k++) {
outputs[idx + idx_offset][k] = hierarchical_user[k].id();
}
}
} else {
for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
for (size_t k = 0; k < user_feature_num; k++) {
outputs[idx + idx_offset][k] = user_inputs[i][k];
}
}
}
// sampler ++
// Positive row: the node on the target's path, labeled 1.
outputs[idx][user_feature_num] = travel_path[j].id();
outputs[idx][user_feature_num + 1] = 1.0;
idx += 1;
// Negative rows: resample until the drawn node differs from the
// positive, label 0.
for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
int sample_res = 0;
do {
sample_res = sampler_vec_[j]->Sample();
} while (layer_ids_[j][sample_res].id() == travel_path[j].id());
outputs[idx + idx_offset][user_feature_num] =
layer_ids_[j][sample_res].id();
outputs[idx + idx_offset][user_feature_num + 1] = 0;
}
idx += layer_counts_[j];
}
}
return outputs;
}
// Expands each dataset record into layer-wise positive/negative records:
// the feasign in `sample_slot` is replaced per layer with the path node
// (positive; label kept as-is) or a sampled node (negative; the slot after
// sample_slot is assumed to be the label and is zeroed — TODO confirm with
// callers). Records without `sample_slot` are dropped from the output.
void LayerWiseSampler::sample_from_dataset(
const uint16_t sample_slot,
std::vector<paddle::framework::Record>* src_datas,
std::vector<paddle::framework::Record>* sample_results) {
sample_results->clear();
for (auto& data : *src_datas) {
VLOG(1) << "src data size = " << src_datas->size();
VLOG(1) << "float data size = " << data.float_feasigns_.size();
// data.Print();
uint64_t start_idx = sample_results->size();
VLOG(1) << "before sample, sample_results.size = " << start_idx;
// Locate the first feasign whose slot matches sample_slot.
uint64_t sample_feasign_idx = -1;
bool sample_sign = false;
for (unsigned int i = 0; i < data.uint64_feasigns_.size(); i++) {
VLOG(1) << "slot" << i << " = " << data.uint64_feasigns_[i].slot();
if (data.uint64_feasigns_[i].slot() == sample_slot) {
sample_sign = true;
sample_feasign_idx = i;
}
if (sample_sign) break;
}
VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx;
if (sample_sign) {
auto target_id =
data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_;
auto travel_codes = tree_->GetTravelCodes(target_id, start_sample_layer_);
auto travel_path = tree_->GetNodes(travel_codes);
for (unsigned int j = 0; j < travel_path.size(); j++) {
// Positive record: copy the source record with the target replaced by
// the path node at this layer.
paddle::framework::Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
travel_path[j].id();
sample_results->push_back(instance);
// Negative records: resample until the draw differs from the path
// node. Note this inner `instance` shadows the positive one above.
for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
int sample_res = 0;
do {
sample_res = sampler_vec_[j]->Sample();
} while (layer_ids_[j][sample_res].id() == travel_path[j].id());
paddle::framework::Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
layer_ids_[j][sample_res].id();
VLOG(1) << "layer id :" << layer_ids_[j][sample_res].id();
// sample_feasign_idx + 1 == label's id
instance.uint64_feasigns_[sample_feasign_idx + 1]
.sign()
.uint64_feasign_ = 0;
sample_results->push_back(instance);
}
VLOG(1) << "layer end!!!!!!!!!!!!!!!!!!";
}
}
}
VLOG(1) << "after sample, sample_results.size = " << sample_results->size();
return;
}
// Converts each double in `tmp` to uint64_t by truncation toward zero.
// NOTE: values that are negative or exceed uint64_t's range are undefined
// behavior for this conversion; callers must pass representable values.
//
// FIX: the vector is now taken by const reference (was copied by value),
// the output is reserved up front, and the C-style cast is a static_cast.
std::vector<uint64_t> float2int(const std::vector<double>& tmp) {
  std::vector<uint64_t> tmp_int;
  tmp_int.reserve(tmp.size());
  for (const double v : tmp) {
    tmp_int.push_back(static_cast<uint64_t>(v));
  }
  return tmp_int;
}
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// Abstract interface for samplers that draw training examples from a tree
// index; LayerWiseSampler below is the concrete implementation in this file.
class IndexSampler {
 public:
  virtual ~IndexSampler() {}
  IndexSampler() {}

  // Factory: creates a concrete sampler T bound to the tree index registered
  // under `name`. FIX: uses make_shared (single allocation, exception-safe)
  // instead of default-constructing a shared_ptr and calling reset(new T).
  template <typename T>
  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
    return std::make_shared<T>(name);
  }

  // Configures layer-wise sampling; a no-op in the base class.
  virtual void init_layerwise_conf(
      const std::vector<uint16_t>& layer_sample_counts,
      uint16_t start_sample_layer = 1,
      uint16_t seed = 0) {}
  // Configures beam-search sampling; a no-op in the base class.
  virtual void init_beamsearch_conf(const int64_t k) {}
  // Draws samples per target id; see LayerWiseSampler::sample.
  virtual std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& input_targets,
      bool with_hierarchy = false) = 0;
  // Rewrites dataset records with sampled positives/negatives; see
  // LayerWiseSampler::sample_from_dataset.
  virtual void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) = 0;
};
// Layer-wise negative sampler over a TreeIndex: for each tree layer between
// start_sample_layer_ and the leaves it keeps the layer's nodes and a
// uniform sampler sized to that layer, configured by init_layerwise_conf.
class LayerWiseSampler : public IndexSampler {
public:
virtual ~LayerWiseSampler() {}
// Binds the sampler to the tree registered under `name` in IndexWrapper.
explicit LayerWiseSampler(const std::string& name) {
tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
}
// Validates start_sample_layer, records per-layer negative counts (default
// 1 when layer_sample_counts is shorter than the layer range), and builds
// one uniform sampler per layer from the bottom of the tree upwards.
void init_layerwise_conf(const std::vector<uint16_t>& layer_sample_counts,
uint16_t start_sample_layer,
uint16_t seed) override {
seed_ = seed;
start_sample_layer_ = start_sample_layer;
PADDLE_ENFORCE_GT(
start_sample_layer_,
0,
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should greater than 0.",
start_sample_layer_));
PADDLE_ENFORCE_LT(start_sample_layer_,
tree_->Height(),
paddle::platform::errors::InvalidArgument(
"start sampler layer = [%d], it should less than "
"max_layer, which is [%d].",
start_sample_layer_,
tree_->Height()));
size_t i = 0;
layer_counts_sum_ = 0;
layer_counts_.clear();
int cur_layer = start_sample_layer_;
while (cur_layer < tree_->Height()) {
// Missing entries default to 1 negative per layer.
int layer_sample_num = 1;
if (i < layer_sample_counts.size()) {
layer_sample_num = layer_sample_counts[i];
}
// +1 accounts for the positive row emitted alongside the negatives.
layer_counts_sum_ += layer_sample_num + 1;
layer_counts_.push_back(layer_sample_num);
VLOG(3) << "[INFO] level " << cur_layer
<< " sample_layer_counts.push_back: " << layer_sample_num;
cur_layer += 1;
i += 1;
}
// Reversed so that layer_counts_[j] lines up with travel paths, which
// run leaf-first (see sample()).
reverse(layer_counts_.begin(), layer_counts_.end());
VLOG(3) << "sample counts sum: " << layer_counts_sum_;
auto max_layer = tree_->Height();
sampler_vec_.clear();
layer_ids_.clear();
// Walk layers bottom-up; each sampler draws uniform indices into that
// layer's node list.
auto layer_index = max_layer - 1;
size_t idx = 0;
while (layer_index >= start_sample_layer_) {
auto layer_codes = tree_->GetLayerCodes(layer_index);
layer_ids_.push_back(tree_->GetNodes(layer_codes));
auto sampler_temp =
std::make_shared<paddle::operators::math::UniformSampler>(
layer_ids_[idx].size() - 1, seed_);
sampler_vec_.push_back(sampler_temp);
layer_index--;
idx++;
}
}
std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids,
bool with_hierarchy) override;
void sample_from_dataset(
const uint16_t sample_slot,
std::vector<paddle::framework::Record>* src_datas,
std::vector<paddle::framework::Record>* sample_results) override;
private:
// Per-layer negative counts, leaf-layer first.
std::vector<int> layer_counts_;
// Total rows emitted per target: sum over layers of (negatives + 1).
int64_t layer_counts_sum_{0};
std::shared_ptr<TreeIndex> tree_{nullptr};
int seed_{0};
int start_sample_layer_{1};
// sampler_vec_[j] draws indices into layer_ids_[j] (leaf layer at j = 0).
std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
std::vector<std::vector<IndexNode>> layer_ids_;
};
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace distributed {
std::shared_ptr<IndexWrapper> IndexWrapper::s_instance_(nullptr);
// Loads a serialized tree index from `filename`.
// File format: repeated [int32 length][KVItem bytes] records. The special
// key ".tree_meta" carries the TreeMeta message; any other key is a node
// code (decimal string) whose value is an IndexNode. Returns 0 on success;
// aborts via PADDLE_ENFORCE on malformed input.
// FIX: corrected the truncated-read error message ("iteger[% d]" ->
// "integer [%d]", plus a missing space after the period).
int TreeIndex::Load(const std::string filename) {
  int err_no;
  auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
  PADDLE_ENFORCE_NE(
      fp,
      nullptr,
      platform::errors::InvalidArgument(
          "Open file %s failed. Please check whether the file exists.",
          filename));
  int num = 0;
  // Reset aggregate state so the load starts from a clean slate.
  max_id_ = 0;
  fake_node_.set_id(0);
  fake_node_.set_is_leaf(false);
  fake_node_.set_probability(0.0);
  max_code_ = 0;
  size_t ret = fread(&num, sizeof(num), 1, fp.get());
  while (ret == 1 && num > 0) {
    std::string content(num, '\0');
    size_t read_num =
        fread(const_cast<char*>(content.data()), 1, num, fp.get());
    PADDLE_ENFORCE_EQ(
        read_num,
        static_cast<size_t>(num),
        platform::errors::InvalidArgument(
            "Read from file: %s failed. Valid Format is "
            "an integer representing the length of the following string, "
            "and the string itself. We got an integer [%d], "
            "but the following string's length is [%d].",
            filename,
            num,
            read_num));
    KVItem item;
    PADDLE_ENFORCE_EQ(
        item.ParseFromString(content),
        true,
        platform::errors::InvalidArgument("Parse from file: %s failed. It's "
                                          "content can't be parsed by KVItem.",
                                          filename));
    if (item.key() == ".tree_meta") {
      meta_.ParseFromString(item.value());
    } else {
      auto code = std::stoull(item.key());
      IndexNode node;
      node.ParseFromString(item.value());
      // PADDLE_ENFORCE_NE(node.id(), 0,
      //                   platform::errors::InvalidArgument(
      //                       "Node'id should not be equel to zero."));
      // Only leaves participate in the id -> code reverse mapping.
      if (node.is_leaf()) {
        id_codes_map_[node.id()] = code;
      }
      data_[code] = node;
      if (node.id() > max_id_) {
        max_id_ = node.id();
      }
      if (code > max_code_) {
        max_code_ = code;
      }
    }
    ret = fread(&num, sizeof(num), 1, fp.get());
  }
  total_nodes_num_ = data_.size();
  // max_code_ becomes a one-past-the-end sentinel (see GetAncestorCodes).
  max_code_ += 1;
  return 0;
}
std::vector<IndexNode> TreeIndex::GetNodes(const std::vector<uint64_t>& codes) {
std::vector<IndexNode> nodes;
nodes.reserve(codes.size());
for (size_t i = 0; i < codes.size(); i++) {
if (CheckIsValid(codes[i])) {
nodes.push_back(data_.at(codes[i]));
} else {
nodes.push_back(fake_node_);
}
}
return nodes;
}
// Returns the codes of the nodes that exist on `level`. In a complete
// branch-ary tree, level L occupies the code range
// [branch^L - 1, branch^(L+1) - 1); codes missing from data_ are skipped.
std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
  const auto width = static_cast<uint64_t>(std::pow(meta_.branch(), level));
  const uint64_t first_code = width - 1;
  std::vector<uint64_t> codes;
  codes.reserve(width);
  for (uint64_t offset = 0; offset < width; ++offset) {
    const uint64_t candidate = first_code + offset;
    if (CheckIsValid(candidate)) {
      codes.push_back(candidate);
    }
  }
  return codes;
}
// Maps each id to the code of its ancestor at `level` by walking the leaf
// code up one parent per step. Ids unknown to the tree map to max_code_,
// the one-past-the-end sentinel that GetNodes resolves to fake_node_.
std::vector<uint64_t> TreeIndex::GetAncestorCodes(
    const std::vector<uint64_t>& ids, int level) {
  std::vector<uint64_t> ancestors;
  ancestors.reserve(ids.size());
  for (const auto id : ids) {
    const auto it = id_codes_map_.find(id);
    if (it == id_codes_map_.end()) {
      ancestors.push_back(max_code_);
      continue;
    }
    uint64_t code = it->second;
    // Climb from the leaf layer (height - 1) until reaching `level`.
    for (int depth = meta_.height() - 1; level >= 0 && depth > level;
         --depth) {
      code = (code - 1) / meta_.branch();
    }
    ancestors.push_back(code);
  }
  return ancestors;
}
// Returns the codes of `ancestor`'s descendants that lie on `level`,
// breadth-first from the ancestor downwards. The scan stops once the
// frontier pointer p_idx reaches a code inside level's range
// [branch^level - 1, branch^(level+1) - 1).
// NOTE(review): if no descendant ever lands in that range (e.g. `level`
// is above the ancestor), parent[p_idx] is read unguarded — confirm
// callers only pass levels below the ancestor.
std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
int level) {
auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
auto code_min = level_code_num - 1;
auto code_max = meta_.branch() * level_code_num - 1;
std::vector<uint64_t> parent;
parent.push_back(ancestor);
std::vector<uint64_t> res;
size_t p_idx = 0;
while (true) {
size_t p_size = parent.size();
// Expand the current frontier: append every existing child (children of
// code c are c*branch + 1 .. c*branch + branch).
for (; p_idx < p_size; p_idx++) {
for (int i = 0; i < meta_.branch(); i++) {
auto code = parent[p_idx] * meta_.branch() + i + 1;
if (data_.find(code) != data_.end()) parent.push_back(code);
}
}
// Stop once the next unexpanded code is already on the target level.
if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
break;
}
}
// Everything from p_idx onward is on the target level.
return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
}
// Returns the codes along the path from `id`'s leaf up to (and including)
// `start_level`, ordered leaf-first. Aborts if the id is not in the tree.
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
  PADDLE_ENFORCE_NE(id_codes_map_.find(id),
                    id_codes_map_.end(),
                    paddle::platform::errors::InvalidArgument(
                        "id = %d doesn't exist in Tree.", id));
  std::vector<uint64_t> path;
  uint64_t code = id_codes_map_.at(id);
  // Record the current code, then step to the parent, one layer at a time.
  for (int level = meta_.height() - 1; level >= start_level; --level) {
    path.push_back(code);
    code = (code - 1) / meta_.branch();
  }
  return path;
}
std::vector<IndexNode> TreeIndex::GetAllLeafs() {
std::vector<IndexNode> res;
res.reserve(id_codes_map_.size());
for (auto& ite : id_codes_map_) {
auto code = ite.second;
res.push_back(data_.at(code));
}
return res;
}
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// Base marker class for index structures; TreeIndex below derives from it.
// NOTE(review): the destructor is non-virtual, so deleting a derived index
// through an Index* would be undefined behavior — confirm no code does so
// before relying on Index* ownership.
class Index {
public:
Index() {}
~Index() {}
};
// In-memory representation of a complete branch-ary index tree loaded from
// disk (see Load). Nodes are addressed by code: the root is 0 and the
// children of code c are c*branch + 1 .. c*branch + branch.
class TreeIndex : public Index {
 public:
  TreeIndex() {}
  ~TreeIndex() {}
  // Tree height and branching factor from the loaded metadata.
  int Height() { return meta_.height(); }
  int Branch() { return meta_.branch(); }
  uint64_t TotalNodeNums() { return total_nodes_num_; }
  // One past the largest node id; usable as an embedding-table size.
  uint64_t EmbSize() { return max_id_ + 1; }
  // Defined in index_wrapper.cc; returns 0 on success.
  int Load(const std::string path);

  // True when `code` names a node present in this tree.
  // FIX: parameter widened from `int` to `uint64_t` — every caller passes
  // uint64_t codes, and the old signature silently truncated them.
  inline bool CheckIsValid(uint64_t code) {
    return data_.find(code) != data_.end();
  }
  std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
  std::vector<uint64_t> GetLayerCodes(int level);
  std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
                                         int level);
  std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
  std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
  std::vector<IndexNode> GetAllLeafs();

  std::unordered_map<uint64_t, IndexNode> data_;  // code -> node payload
  std::unordered_map<uint64_t, uint64_t> id_codes_map_;  // leaf id -> code
  uint64_t total_nodes_num_;
  TreeMeta meta_;
  uint64_t max_id_;
  // One past the largest code; used as a "not found" sentinel.
  uint64_t max_code_;
  // Placeholder node returned for codes that are not in the tree.
  IndexNode fake_node_;
};
// Shared-ownership handle to a loaded tree index.
using TreePtr = std::shared_ptr<TreeIndex>;
// Process-wide registry mapping a name to a loaded TreeIndex.
class IndexWrapper {
 public:
  virtual ~IndexWrapper() {}
  IndexWrapper() {}

  // Drops every registered tree.
  void clear_tree() { tree_map.clear(); }

  // Returns the tree registered under `name`; raises InvalidArgument when
  // the name was never inserted via insert_tree_index().
  TreePtr get_tree_index(const std::string name) {
    PADDLE_ENFORCE_NE(tree_map.find(name),
                      tree_map.end(),
                      paddle::platform::errors::InvalidArgument(
                          "tree [%s] doesn't exist. Please insert it firstly "
                          "by API[\' insert_tree_index \'].",
                          name));
    return tree_map[name];
  }

  // Loads a TreeIndex from `tree_path` and registers it under `name`.
  // Re-inserting an existing name is a logged no-op; a failed load raises
  // InvalidArgument.
  void insert_tree_index(const std::string name, const std::string tree_path) {
    if (tree_map.find(name) != tree_map.end()) {
      VLOG(0) << "Tree " << name << " has already existed.";
      return;
    }
    TreePtr tree = std::make_shared<TreeIndex>();
    int ret = tree->Load(tree_path);
    PADDLE_ENFORCE_EQ(ret,
                      0,
                      paddle::platform::errors::InvalidArgument(
                          "Load tree[%s] from path[%s] failed. Please "
                          "check whether the file exists.",
                          name,
                          tree_path));
    tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
  }

  // Lazily-created singleton, shared_ptr form.
  // NOTE(review): lazy creation is not thread-safe; confirm all callers
  // perform the first call from a single thread.
  static std::shared_ptr<IndexWrapper> GetInstancePtr() {
    if (nullptr == s_instance_) {  // nullptr instead of NULL (modern idiom)
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_;
  }

  // Raw-pointer view of the same singleton (see thread-safety note above).
  static IndexWrapper* GetInstance() {
    if (nullptr == s_instance_) {
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_.get();
  }

 private:
  static std::shared_ptr<IndexWrapper> s_instance_;
  std::unordered_map<std::string, TreePtr> tree_map;
};
} // end namespace distributed
} // end namespace paddle
# Publish the RPC dependency set (generated sendrecv stubs + brpc stack) as a
# global property so sibling directories can link against it.
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
add_subdirectory(table)
add_subdirectory(service)
add_subdirectory(wrapper)
# Directory overview
Table: for param storage and update
-----MemorySparseTable: table for sparse param, used in cpu async mode
-----MemoryDenseTable: table for dense param, used in cpu async/geo mode
-----MemorySparseGeoTable: table for sparse param, used in cpu geo mode
-----CommonGraphTable: table used for graph learning
-----BarrierTable: table for barrier function, used in cpu sync mode
-----TensorTable: table which run program, used for learning rate decay only
ValueAccessor: for pull param and push gradient
-----CtrCommonAccessor: pull/push value with show/click, float type
-----CtrDoubleAccessor: same as CtrCommonAccessor, except that show/click are stored as double type
-----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click
-----CommMergeAccessor: used for dense table only, for get param dim
PsService(proto): for server to handle request
-----PsBaseService
----------BrpcPsService: for cpu dnn training task
----------GraphBrpcService: for graph learning
-----HeterService: for dnn training task with heterogeneous computing resources
PSServer: recv request from trainer and handle it by service
-----BrpcPsServer: for cpu dnn training task
-----GraphBrpcServer: for graph learning
-----PsLocalServer: for GpuPS
HeterServer: for HeterPS
PSClient: pull param and push gradient for trainer
-----BrpcPsClient: for cpu dnn training task
----------GraphBrpcClient: for graph learning
-----PsLocalClient: for GpuPS
HeterClient: for HeterPS
PSCore: Wrapper for InitServer
GraphPyService: for graph learning
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
# Dependency closure for the brpc RPC stack. The HeterPS build additionally
# needs device_context and rocksdb.
# NOTE(review): gflags/glog appear twice in both lists — presumably harmless
# duplication, worth deduplicating.
if(WITH_HETERPS)
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context
      rocksdb)
else()
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context)
endif()
# Generate C++ stubs from sendrecv.proto and build them into sendrecv_rpc.
brpc_library(
  sendrecv_rpc
  SRCS
  ${BRPC_SRCS}
  PROTO
  sendrecv.proto
  DEPS
  ${BRPC_DEPS})
#set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
# All distributed sources need the distribute-specific compile flags.
set_source_files_properties(
  communicator/communicator.cc PROPERTIES COMPILE_FLAGS
                                          ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_service/service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS
                                                 ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS
                                                    ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS
                                                 ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  coordinator_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
                                            ${DISTRIBUTE_COMPILE_FLAGS})
# Shared helpers for serializing tensors over brpc.
cc_library(
  brpc_utils
  SRCS brpc_utils.cc
  DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
# Core parameter-server client/server library (brpc + graph variants).
cc_library(
  ps_service
  SRCS graph_brpc_server.cc
       brpc_ps_server.cc
       server.cc
       graph_brpc_client.cc
       brpc_ps_client.cc
       ps_local_client.cc
       coordinator_client.cc
       ps_client.cc
       communicator/communicator.cc
       ps_service/service.cc
       ps_service/graph_py_service.cc
  DEPS eigen3
       table
       brpc_utils
       simple_threadpool
       scope
       math_function
       selected_rows_functor
       ${RPC_DEPS})
# Client/server pair for heterogeneous (HeterPS) training.
cc_library(
  heter_client
  SRCS heter_client.cc
  DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(
  heter_server
  SRCS heter_server.cc
  DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
# Directory overview
* PSServer
* PSClient
* PsService
* Communicator
* MessageBusFramework
* *.proto
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/split.h"
// Upper bound of the port range probed when starting the client's listener.
static const int max_port = 65535;
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
// Tunable gflags controlling client-side merging, batching and timeouts.
DEFINE_int32(pserver_push_dense_merge_limit,
             12,
             "limit max push_dense local merge requests");
DEFINE_int32(pserver_push_sparse_merge_limit,
             12,
             "limit max push_sparse local merge requests");
// BUGFIX: help text was a copy-paste of the push_sparse flag above; this
// flag limits pull_dense requests.
DEFINE_int32(pserver_pull_dense_limit,
             12,
             "limit max pull_dense local merge requests");
DEFINE_int32(pserver_async_push_dense_interval_ms,
             10,
             "async push_dense to server interval");
DEFINE_int32(pserver_async_push_sparse_interval_ms,
             10,
             "async push_sparse to server interval");
DEFINE_bool(pserver_scale_gradient_by_merge,
            false,
            "scale dense gradient when merged");
DEFINE_int32(pserver_communicate_compress_type,
             0,
             "none:0 snappy:1 gzip:2 zlib:3 lz4:4");
DEFINE_int32(pserver_max_async_call_num,
             13,
             "max task num in async_call_server");
DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms,
             10000,
             "pserver connect server timeout_ms");
DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");
DEFINE_int32(pserver_sparse_table_shard_num,
             1000,
             "sparse table shard for save & load");
// Maps a feasign key to the index of the pserver owning its shard.
// Each server owns ceil(shard_num / server_num) consecutive shards; a key
// first maps onto shard (key % shard_num), then onto that shard's server.
inline size_t get_sparse_shard(uint32_t shard_num,
                               uint32_t server_num,
                               uint64_t key) {
  const size_t shards_per_server =
      (static_cast<size_t>(shard_num) + server_num - 1) / server_num;
  return (key % shard_num) / shards_per_server;
}
// brpc handler for client-to-client messages: dispatches the request to the
// owning client and reports success/failure in the response.
void DownpourPsClientService::service(
    ::google::protobuf::RpcController *controller,
    const PsRequestMessage *request,
    PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  // Guarantees `done` runs exactly once when this handler returns.
  brpc::ClosureGuard done_guard(done);
  const int rc = _client->HandleClient2ClientMsg(
      request->cmd_id(), request->client_id(), request->data());
  if (rc == 0) {
    response->set_err_code(0);
    response->set_err_msg("");
  } else {
    response->set_err_code(-1);
    response->set_err_msg("handle_client2client_msg failed");
  }
}
// Starts the client-side RPC service used for client-to-client data exchange.
int32_t BrpcPsClient::StartClientService() {
  if (_service.Configure(this, _client_id) != 0) {
    LOG(ERROR)
        << "service initialize failed, service_name:DownpourPsClientService";
    return -1;
  }
  _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  int start_port = 8500;
  options.num_threads = 24;
  // Probe ports in [start_port, max_port] until one can be bound.
  if (_server.Start(butil::my_ip_cstr(),
                    brpc::PortRange(start_port, max_port),
                    &options) != 0) {
    LOG(ERROR) << "BrpcPsServer start failed";
    return -1;
  }
  _server_started = true;
  // Publish the actually-bound endpoint so peer clients can reach us.
  _env->RegistePsClient(
      butil::my_ip_cstr(), _server.listen_address().port, _client_id);
  VLOG(0) << "BrpcPsClient Service addr: " << butil::my_ip_cstr() << ", "
          << _server.listen_address().port << ", " << _client_id;
  return 0;
}
// Starts the FL client service, used to receive data from the coordinator.
int32_t BrpcPsClient::StartFlClientService(const std::string &self_endpoint) {
  _fl_server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (self_endpoint.empty()) {
    LOG(ERROR) << "fl-ps > fl client endpoint not set";
    return -1;
  }
  if (_fl_server.Start(self_endpoint.c_str(), &options) != 0) {
    VLOG(0) << "fl-ps > StartFlClientService failed. Try again.";
    // Hostname form failed to bind; retry with the numeric ip:port form.
    auto ip_port = paddle::string::Split(self_endpoint, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_fl_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "fl-ps > StartFlClientService failed, ip_port= "
                 << int_ip_port;
      return -1;
    }
  } else {
    VLOG(0) << "fl-ps > StartFlClientService succeed! listen on "
            << self_endpoint;
  }
  return 0;
}
// Opens a pooled brpc channel to every other registered client so that
// client-to-client messages can be sent. Returns 0 on success, -1 when a
// connection cannot be established even via the numeric ip:port fallback.
int32_t BrpcPsClient::CreateClient2ClientConnection(
    int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = pserver_connect_timeout_ms;
  options.max_retry = max_retry;
  std::vector<PSHost> client_list = _env->GetPsClients();
  VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
          << client_list.size();
  // PERF: iterate by const reference -- the original `for (auto cc : ...)`
  // copied a PSHost per iteration just to log it.
  for (const auto &cc : client_list) {
    VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
            << cc.ToString();
  }
  _client_channels.resize(client_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < client_list.size(); ++i) {
    server_ip_port.assign(client_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(client_list[i].port));
    _client_channels[i].reset(new brpc::Channel());
    if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options)) {
      VLOG(0) << "BrpcPSClient connect to Client:" << server_ip_port
              << " Failed! Try again.";
      // Hostname endpoint failed; retry with the numeric ip:port form.
      std::string int_ip_port =
          GetIntTypeEndpoint(client_list[i].ip, client_list[i].port);
      if (_client_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "BrpcPSClient connect to Client:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "Client connect success:" << os.str();
  return 0;
}
// FL worker bootstrap: connects to every coordinator and then starts the
// local FL client service listening on `self_endpoint`.
int32_t BrpcPsClient::InitializeFlWorker(const std::string &self_endpoint) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms =
      paddle::distributed::FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  // Fetch the coordinator list and connect to each of them.
  std::string coordinator_ip_port;
  std::vector<PSHost> coordinator_list = _env->GetCoordinators();
  _coordinator_channels.resize(coordinator_list.size());
  for (size_t i = 0; i < coordinator_list.size(); ++i) {
    coordinator_ip_port.assign(coordinator_list[i].ip.c_str());
    coordinator_ip_port.append(":");
    coordinator_ip_port.append(std::to_string(coordinator_list[i].port));
    VLOG(0) << "fl-ps > BrpcFlclient connetcting to coordinator: "
            << coordinator_ip_port;
    for (size_t j = 0; j < _coordinator_channels[i].size(); ++j) {
      _coordinator_channels[i][j].reset(new brpc::Channel());
      if (_coordinator_channels[i][j]->Init(
              coordinator_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
                   << coordinator_ip_port << " Failed! Try again.";
        // Hostname endpoint failed; retry with the numeric ip:port form.
        std::string int_ip_port = GetIntTypeEndpoint(coordinator_list[i].ip,
                                                     coordinator_list[i].port);
        if (_coordinator_channels[i][j]->Init(
                int_ip_port.c_str(), "", &options) != 0) {
          LOG(ERROR) << "fl-ps > BrpcFlclient connect to coordinator:"
                     << int_ip_port << " Failed!";
          return -1;
        }
      }
    }
  }
  // NOTE(review): StartFlClientService's return value is ignored here --
  // confirm whether a failed local listener should abort initialization.
  StartFlClientService(self_endpoint);
  VLOG(0) << "fl-ps > InitializeFlWorker finished!";
  return 0;
}
// Synchronously pushes this client's FL metadata to the coordinator(s),
// blocking until each RPC completes.
void BrpcPsClient::PushFLClientInfoSync(const std::string &fl_client_info) {
  size_t request_call_num = _coordinator_channels.size();
  // Completion callback: validate every reply, then publish the aggregate
  // result through the closure's promise.
  FlClientBrpcClosure *closure =
      new FlClientBrpcClosure(request_call_num, [request_call_num](void *done) {
        auto *closure = reinterpret_cast<FlClientBrpcClosure *>(done);
        int ret = 0;
        for (size_t i = 0; i < request_call_num; i++) {
          if (closure->check_response(i, PUSH_FL_CLIENT_INFO_SYNC) != 0) {
            LOG(ERROR) << "fl-ps > PushFLClientInfoSync response from "
                          "coordinator is failed";
            // BUGFIX: this used to `return`, skipping set_promise_value()
            // below and leaving the caller blocked in fut.wait() forever.
            ret = -1;
            break;
          } else {
            VLOG(0) << "fl-ps > rpc service call cost time: "
                    << (closure->cntl(i)->latency_us() / 1000) << " ms";
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
  closure->add_promise(promise);
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PUSH_FL_CLIENT_INFO_SYNC);
    closure->request(i)->set_client_id(_client_id);
    closure->request(i)->set_str_params(fl_client_info);
    // NOTE(review): always channel [0][0] regardless of `i` -- looks
    // intentional for the single-coordinator case; confirm for multiple
    // coordinators.
    brpc::Channel *rpc_channel = _coordinator_channels[0][0].get();
    if (rpc_channel == nullptr) {
      LOG(ERROR) << "_coordinator_channels is null";
      return;
    }
    PsService_Stub rpc_stub(rpc_channel);  // CoordinatorService
    rpc_stub.FLService(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
    fut.wait();
  }
  VLOG(0) << "fl-ps > PushFLClientInfoSync finished, client id: " << _client_id;
  return;
}
// Blocks (polling once per second) until the coordinator has delivered an FL
// strategy to the local service, then consumes and returns it.
std::string BrpcPsClient::PullFlStrategy() {
  while (!_service._is_fl_strategy_ready) {
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
    VLOG(0) << "fl-ps > waiting for fl strategy returned from coordinator";
  }
  // Reset the ready flag so the next round blocks again.
  _service._is_fl_strategy_ready =
      false;  // only support single thread, no need for multi-threads
  return _service._fl_strategy;
}
// One-time client bootstrap: connects to every pserver, starts the local
// client service, builds the per-table async push queues and profilers, and
// spawns the background push-consumer threads. Returns 0 on success.
int32_t BrpcPsClient::Initialize() {
  _async_call_num = 0;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  std::ostringstream os;
  std::string server_ip_port;
  std::string client_ip(butil::my_ip_cstr());
  // Fetch the pserver list and connect to every server.
  std::vector<PSHost> server_list = _env->GetPsServers();
  _server_channels.resize(server_list.size());
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    for (size_t j = 0; j < _server_channels[i].size(); ++j) {
      _server_channels[i][j].reset(new brpc::Channel());
      if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        VLOG(0) << "BrpcPSclient connect to Server:" << server_ip_port
                << " Failed! Try again.";
        // Hostname endpoint failed; retry with the numeric ip:port form.
        std::string int_ip_port =
            GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
        if (_server_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "BrpcPSclient connect to Server:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
    os << server_ip_port << ",";
  }
  // Start the client's own listener and register it for c2c connections.
  StartClientService();
  // Initialize the async push request queues, one per configured table.
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    auto type = worker_param.downpour_table_param(i).type();
    auto table_id = worker_param.downpour_table_param(i).table_id();
    if (type == PS_DENSE_TABLE) {
      _push_dense_task_queue_map[table_id] =
          paddle::framework::MakeChannel<DenseAsyncTask *>();
    }
    if (type == PS_SPARSE_TABLE) {
      _push_sparse_task_queue_map[table_id] =
          paddle::framework::MakeChannel<SparseAsyncTask *>();
      _push_sparse_merge_count_map[table_id] = 0;
    }
  }
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_client_pull_dense");
  profiler.register_profiler("pserver_client_pull_sparse");
  profiler.register_profiler("pserver_client_pull_sparse_param");
  profiler.register_profiler("pserver_client_pull_sparse_local");
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_parse");
  profiler.register_profiler("client_push_sparse_put");
  // NOTE(review): "pserver_client_push_sparse" is registered a second time
  // here -- presumably harmless, but worth deduplicating.
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_merge");
  profiler.register_profiler("pserver_client_push_sparse_rpc");
  profiler.register_profiler("pserver_client_push_dense");
  profiler.register_profiler("pserver_client_push_dense_parse");
  profiler.register_profiler("push_dense_put");
  profiler.register_profiler("pserver_client_push_dense_merge");
  profiler.register_profiler("pserver_client_push_dense_rpc");
  profiler.register_profiler("pserver_client_push_dense_send");
  _running = true;
  _flushing = false;
  // Spawn the background threads that drain the async push queues.
  _async_push_sparse_thread =
      std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
  // _async_push_sparse_thread.detach();
  _async_push_dense_thread =
      std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
  // for debug
  // _print_thread =
  //     std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
  return 0;
}
int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}
int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) {
int32_t feasign_size = 0;
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
feasign_size = _responses[request_idx].err_code();
if (feasign_size < 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return feasign_size;
}
// Returns a copy of the payload of response `request_idx`. `cmd_id` is
// unused and kept only for signature parity with check_response().
std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
  return _responses[request_idx].data();
}
int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}
// Asks every pserver for the stats of `table_id` and prints the summed
// feasign/mf sizes to stdout once all replies arrive.
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, table_id](void *done) {
        int ret = 0;
        uint64_t feasign_size = 0;
        uint64_t mf_size = 0;
        paddle::framework::BinaryArchive ar;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
            ret = -1;
            break;
          }
          // Each response carries two uint64s: feasign count and mf count.
          std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
          ar.SetReadBuffer(
              const_cast<char *>(resp.c_str()), resp.length(), nullptr);
          feasign_size += ar.Get<uint64_t>();
          mf_size += ar.Get<uint64_t>();
        }
        closure->set_promise_value(ret);
        std::cout << "table id: " << table_id
                  << ", feasign size: " << feasign_size
                  << ", mf size: " << mf_size << std::endl;
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcasts control command `cmd_id` (with string params) to every pserver.
// The returned future resolves to 0 iff all servers answered successfully.
std::future<int32_t> BrpcPsClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000 * 2);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Like SendCmd, but for save-style commands whose responses carry a feasign
// count: the future resolves to the total count on success, -1 on failure.
std::future<int32_t> BrpcPsClient::SendSaveCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        uint32_t feasign_size = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          // PERF/clarity: call check_save_response once per server and reuse
          // the result (the original called it twice per iteration).
          const int save_ret = closure->check_save_response(i, cmd_id);
          if (save_ret < 0) {
            ret = -1;
            break;
          }
          feasign_size += save_ret;
        }
        if (ret == 0) {
          closure->set_promise_value(feasign_size);
        } else {
          closure->set_promise_value(ret);
        }
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Asks all pservers to shrink `table_id`, evicting entries below `threshold`.
std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
                                          const std::string threshold) {
  return SendCmd(table_id, PS_SHRINK_TABLE, {threshold});
}
// Loads every table from checkpoint directory `epoch` using load `mode`.
std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
                                        const std::string &mode) {
  return SendCmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode});
}
// Loads a single table from checkpoint directory `epoch` using load `mode`.
std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  return SendCmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
// Saves every table into directory `epoch`; the future resolves to the total
// saved feasign count (see SendSaveCmd).
std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save path " << epoch;
  return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode});
}
// Saves a single table into directory `epoch`; resolves to its feasign count.
std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
          << table_id;
  return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
// Triggers a cache shuffle for `table_id` on all pservers.
std::future<int32_t> BrpcPsClient::CacheShuffle(
    uint32_t table_id,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
  return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}
// Cache-shuffles several tables into one path: the table ids are appended
// (stringified) after path/mode/threshold in the command parameters.
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
    std::vector<int> tables,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
  std::vector<std::string> shuffle_params{path, mode, cache_threshold};
  for (const int table : tables) {
    shuffle_params.push_back(std::to_string(table));
  }
  return SendSaveCmd(0, PS_CACHE_SHUFFLE, shuffle_params);
}
// Saves the cached portion of `table_id` to `path`.
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
                                             const std::string &path,
                                             const std::string &mode) {
  return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}
// Queries every pserver for its cache threshold and writes the average of
// the non-negative replies into `cache_threshold`.
// NOTE(review): the closure captures `cache_threshold` by reference and runs
// asynchronously -- the caller must keep the variable alive until the
// returned future resolves.
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
                                                     double &cache_threshold) {
  int cmd_id = PS_GET_CACHE_THRESHOLD;
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, cmd_id, &cache_threshold](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        std::vector<double> cache_thresholds(request_call_num, 0);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
          std::string cur_res = closure->get_response(i, cmd_id);
          cache_thresholds[i] = std::stod(cur_res);
        }
        // Average only the valid (>= 0) thresholds.
        double sum_threshold = 0.0;
        int count = 0;
        for (auto t : cache_thresholds) {
          if (t >= 0) {
            sum_threshold += t;
            ++count;
          }
        }
        if (count == 0) {
          cache_threshold = 0;
        } else {
          cache_threshold = sum_threshold / count;
        }
        VLOG(1) << "client get cache threshold: " << cache_threshold;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Clears every table on all pservers.
std::future<int32_t> BrpcPsClient::Clear() {
  return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
// Clears a single table on all pservers.
std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
  return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
// Asks all pservers to revert to their last saved state.
std::future<int32_t> BrpcPsClient::Revert() {
  return SendCmd(-1, PS_REVERT, {});
}
// Asks all pservers whether the pre-patch save phase has completed.
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
  return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {});
}
// Blocks until every in-flight async push has completed, then returns an
// already-resolved future (value 0).
std::future<int32_t> BrpcPsClient::Flush() {
  VLOG(0) << "BrpcPsClient::flush begin";
  _flushing = true;
  std::promise<int> promise;
  std::future<int32_t> fut = promise.get_future();
  // Poll the async-call counter every 100 ms until it drains.
  do {
    VLOG(3) << "wait _async_call_num:" << _async_call_num;
    usleep(100000);  // sleep 100ms wait async end
  } while (_async_call_num > 0);
  VLOG(1) << "flush _async_call_num = 0";
  promise.set_value(0);
  _flushing = false;
  VLOG(0) << "BrpcPsClient::flush done";
  PrintQueueSize();
  return fut;
}
// Logs the pending-task count of every sparse and dense push queue.
void BrpcPsClient::PrintQueueSize() {
  for (const auto &entry : _push_sparse_task_queue_map) {
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << entry.first
            << " size: " << entry.second->Size();
  }
  for (const auto &entry : _push_dense_task_queue_map) {
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << entry.first
            << " size: " << entry.second->Size();
  }
}
// Debug helper thread body: dumps queue sizes every two minutes while the
// client is running.
void BrpcPsClient::PrintQueueSizeThread() {
  while (_running) {
    usleep(1000000 * 60 * 2);
    PrintQueueSize();
  }
}
// Shutdown sequence: drain all async pushes, stop the consumer threads, then
// stop and join the client's brpc server.
void BrpcPsClient::FinalizeWorker() {
  Flush();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
  _running = false;
  _async_push_dense_thread.join();
  _async_push_sparse_thread.join();
  // _print_thread.join();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
  _server.Stop(1000);
  _server.Join();
  _server_started = false;
  VLOG(0) << "BrpcPsClient::FinalizeWorker done";
}
// Tells every pserver process to stop.
std::future<int32_t> BrpcPsClient::StopServer() {
  return SendCmd(-1, PS_STOP_SERVER, {});
}
// Enables profiling on every pserver.
std::future<int32_t> BrpcPsClient::StartProfiler() {
  return SendCmd(-1, PS_START_PROFILER, {});
}
// Disables profiling on every pserver.
std::future<int32_t> BrpcPsClient::StopProfiler() {
  return SendCmd(-1, PS_STOP_PROFILER, {});
}
// Issues a barrier of the given type on `table_id` across all pservers.
std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
                                           uint32_t barrier_type) {
  return SendCmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
// GEO mode: pulls the sparse parameters accumulated on pserver `pserver_idx`
// for `table_id`. On completion *keys/*values are resized and filled from
// the response attachment.
std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
                                                std::vector<float> *values,
                                                std::vector<uint64_t> *keys,
                                                int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        uint32_t shard_nums;
        if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
          ret = -1;
        }
        // Attachment layout: [uint32 count][count x uint64 keys]
        // [count x update_size bytes of values].
        // NOTE(review): the attachment is parsed even when the check above
        // failed -- confirm the server always attaches a valid payload.
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        io_buffer_itr.copy_and_forward(reinterpret_cast<void *>(&shard_nums),
                                       sizeof(uint32_t));
        keys->resize(shard_nums);
        values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
        io_buffer_itr.copy_and_forward((void *)(keys->data()),  // NOLINT
                                       sizeof(uint64_t) * shard_nums);
        io_buffer_itr.copy_and_forward(
            (void *)(values->data()),  // NOLINT
            shard_nums * accessor->GetAccessorInfo().update_size);
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// GEO mode: pushes full sparse parameter values for `num` keys. Keys are
// sharded onto pservers by key % server_num; `done` must be a
// DownpourBrpcClosure sized for one request per server.
std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
                                                   const uint64_t *keys,
                                                   const float **update_values,
                                                   size_t num,
                                                   void *done) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  // Bucket keys and value pointers by destination pserver.
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = keys[i] % request_call_num;
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  // Serialize one request per shard and fire the RPCs.
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    // Payload layout: [kv_size x uint64 keys][kv_size x value_size bytes].
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Pulls the dense values of `table_id` from every pserver shard and copies
// them, in shard order, into the caller-provided `regions`.
//
// Each shard is expected to reply with exactly
// `num_per_shard * select_size` bytes in its response attachment; the
// response callback walks the regions sequentially, splitting a shard's
// buffer across region boundaries when needed. The returned future
// resolves to 0 on success and -1 on any shard failure or size mismatch.
std::future<int32_t> BrpcPsClient::PullDense(Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
  auto *accessor = GetTableAccessor(table_id);
  auto fea_dim = accessor->GetAccessorInfo().fea_dim;
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Callback: fill each shard's result into `regions`, in shard order.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, num_per_shard, regions, region_num, accessor](
          void *done) {
        int ret = 0;
        size_t region_idx = 0;       // index of the region being filled
        size_t region_data_idx = 0;  // write offset inside that region
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        size_t shard_data_size =
            num_per_shard * accessor->GetAccessorInfo().select_size;
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          size_t shard_buffer_remain = res_io_buffer.size();
          // Every shard must return exactly shard_data_size bytes.
          if (shard_buffer_remain != shard_data_size) {
            LOG(ERROR) << "expect res_size:" << shard_data_size
                       << ", but size:" << shard_buffer_remain
                       << ", ignore this response";
            ret = -1;
            break;
          }
          while (shard_buffer_remain > 0 && region_idx < region_num) {
            auto &region = regions[region_idx];
            if (region.size - region_data_idx >= shard_buffer_remain) {
              // Remaining region space >= shard buffer: copy it all in.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  shard_buffer_remain);
              region_data_idx += shard_buffer_remain;
              shard_buffer_remain = 0;
            } else if (region.size - region_data_idx == 0) {
              // Region is full; switch to the next region.
              ++region_idx;
              region_data_idx = 0;
            } else {
              // Region cannot hold all remaining data: copy what fits.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  region.size - region_data_idx);
              shard_buffer_remain -= (region.size - region_data_idx);
              ++region_idx;
              region_data_idx = 0;
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    closure->request(i)->add_params((char *)&num_per_shard,  // NOLINT
                                    sizeof(num_per_shard));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pushes dense parameter values to every pserver shard.
//
// The caller's `regions` are treated as one contiguous byte stream that is
// split into `request_call_num` fixed-size shard payloads of
// `num_per_shard * update_size` bytes each. Shards whose payload is not
// fully covered by the regions are padded (with a static, zero-initialized
// buffer) so that every shard receives exactly the same payload size.
std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
                                                  size_t region_num,
                                                  size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  auto accessor_info = accessor->GetAccessorInfo();
  size_t request_call_num = _server_channels.size();
  // 1. Split the region data into per-shard partitions; the partitions are
  //    copied into the per-shard request buffers below.
  std::vector<std::vector<Region>> regions_partition(request_call_num);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor_info.fea_dim, request_call_num);
  size_t shard_data_size = num_per_shard * accessor_info.update_size;
  size_t current_region_idx = 0;
  size_t current_region_data_idx = 0;
  for (size_t i = 0; i < request_call_num; ++i) {
    size_t shard_data_remain_size = shard_data_size;
    while (shard_data_remain_size > 0 && current_region_idx < region_num) {
      const auto &region = regions[current_region_idx];
      size_t region_remain_size = region.size - current_region_data_idx;
      if (shard_data_remain_size >= region_remain_size) {
        // The rest of this region fits into the current shard.
        regions_partition[i].push_back(
            Region(region.data + current_region_data_idx, region_remain_size));
        ++current_region_idx;
        current_region_data_idx = 0;
        shard_data_remain_size -= region_remain_size;
      } else {
        // The region spills over; take only what this shard still needs.
        regions_partition[i].push_back(Region(
            region.data + current_region_data_idx, shard_data_remain_size));
        current_region_data_idx += shard_data_remain_size;
        shard_data_remain_size = 0;
      }
    }
  }
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];  // pad source
  // 2. Copy each shard's partition into its request attachment and send.
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&num_per_shard),
                          sizeof(uint32_t));
    auto &region_list = regions_partition[i];
    size_t fill_remain_size = shard_data_size;
    for (auto &region : region_list) {
      fill_remain_size -= region.size;
      request_buffer.append(reinterpret_cast<void *>(region.data), region.size);
    }
    // Pad so every shard payload contains exactly shard_data_size bytes.
    while (fill_remain_size > 0) {
      size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
                            ? REGION_ASSIGN_BUFFER_SIZE
                            : fill_remain_size;
      request_buffer.append(reinterpret_cast<void *>(region_assign_buffer),
                            fill_num);
      fill_remain_size -= fill_num;
    }
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pushes raw sparse gradients for `keys` to the pserver shards, routing
// each key via `get_sparse_shard` using the table's configured shard_num.
// `done` must be a DownpourBrpcClosure sized for one request per pserver;
// its result is delivered through the returned future (0 ok, -1 failure).
std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    size_t num,
    void *done) {
  auto *accessor = GetTableAccessor(table_id);
  // Build and send one RPC request per pserver shard.
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  // Look up this table's shard_num; fall back to the global flag.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket keys (and their gradient pointers) by destination pserver.
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Fill the RPC request for this shard.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    // Payload layout: [kv_size uint64 keys][kv_size gradient blobs].
    auto *push_data = push_request->mutable_data();
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Pushes a pre-flattened dense gradient buffer to every dense shard.
// `total_send_data` holds `request_call_num` consecutive slices of
// `num_per_shard` floats, one per shard; `done` must be a
// DownpourBrpcClosure sized for one request per shard.
std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
    int table_id,
    float *total_send_data,
    size_t total_send_data_size,
    void *done) {
  auto *rpc_closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto result_promise = std::make_shared<std::promise<int32_t>>();
  rpc_closure->add_promise(result_promise);
  std::future<int> result = result_promise->get_future();
  const size_t shard_count = _server_channels.size();
  auto *accessor = GetTableAccessor(table_id);
  const uint32_t dim_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, shard_count);
  for (size_t shard = 0; shard < shard_count; ++shard) {
    auto *request = rpc_closure->request(shard);
    request->set_cmd_id(PS_PUSH_DENSE_TABLE);
    request->set_table_id(table_id);
    request->set_client_id(_client_id);
    // Payload layout: [uint32 dim_per_shard][dim_per_shard floats].
    auto *payload = request->mutable_data();
    payload->clear();
    payload->resize(sizeof(uint32_t) + dim_per_shard * sizeof(float));
    char *cursor = const_cast<char *>(payload->data());
    memcpy(cursor, &dim_per_shard, sizeof(uint32_t));
    memcpy(cursor + sizeof(uint32_t),
           total_send_data + shard * dim_per_shard,
           dim_per_shard * sizeof(float));
    // NOTE: request compression is not enabled on this path.
    PsService_Stub rpc_stub(GetDenseChannel(shard));
    rpc_stub.service(rpc_closure->cntl(shard),
                     request,
                     rpc_closure->response(shard),
                     rpc_closure);
  }
  return result;
}
// Pushes the global step counter (a single int64 value) to every dense
// shard. `done` must be a DownpourBrpcClosure sized for one request per
// shard; its result is delivered through the returned future.
std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
                                                  int64_t *total_send_data,
                                                  void *done) {
  auto *rpc_closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto result_promise = std::make_shared<std::promise<int32_t>>();
  rpc_closure->add_promise(result_promise);
  std::future<int> result = result_promise->get_future();
  const size_t shard_count = _server_channels.size();
  for (size_t shard = 0; shard < shard_count; ++shard) {
    auto *request = rpc_closure->request(shard);
    request->set_cmd_id(PS_PUSH_GLOBAL_STEP);
    request->set_table_id(table_id);
    request->set_client_id(_client_id);
    // Each shard receives exactly one int64 value.
    const int32_t num_per_shard = 1;
    // Payload layout: [uint32 count][count int64 values].
    auto *payload = request->mutable_data();
    payload->clear();
    payload->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
    char *cursor = const_cast<char *>(payload->data());
    memcpy(cursor, &num_per_shard, sizeof(uint32_t));
    memcpy(cursor + sizeof(uint32_t),
           total_send_data,
           num_per_shard * sizeof(int64_t));
    PsService_Stub rpc_stub(GetDenseChannel(shard));
    rpc_stub.service(rpc_closure->cntl(shard),
                     request,
                     rpc_closure->response(shard),
                     rpc_closure);
  }
  return result;
}
// Pulls sparse values for `keys` from the pservers, writing each result
// into the matching `select_values[i]` buffer.
//
// Keys are bucketed per shard, sorted, and de-duplicated before sending:
// each distinct key is requested once, together with a per-key
// multiplicity list. On response, values are streamed out of the
// attachment in sorted order; duplicate keys are satisfied locally by
// copying the value fetched for the first occurrence. The future resolves
// to 0 on success, -1 on any failure.
std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
                                              size_t table_id,
                                              const uint64_t *keys,
                                              size_t num,
                                              bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
  auto local_timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
  size_t request_call_num = _server_channels.size();
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  // Look up this table's shard_num; fall back to the global flag.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket each key (with its destination buffer) by shard.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  // Response callback: stream values from each shard's attachment in
  // sorted-key order; duplicates reuse the value of the first occurrence.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              // Duplicate key: copy the value fetched for the previous
              // occurrence instead of consuming more response bytes.
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data), value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    // Request attachment layout: [bool is_training][distinct uint64 keys]
    // followed by a uint32 multiplicity per distinct key.
    request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse runs of duplicate keys into one request entry + a count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // No keys for this shard: complete its slot without an RPC.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(
          closure->cntl(i), closure->request(i), closure->response(i), closure);
    }
  }
  return fut;
}
// for GEO
// GEO-mode variant of PullSparse: pulls sparse parameter values for `keys`
// into `select_values`. Unlike PullSparse, keys are routed to shards by a
// plain modulo over the number of pservers rather than via
// get_sparse_shard. The sort/dedupe and response streaming logic is
// otherwise identical; the future resolves to 0 on success, -1 on failure.
std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
                                                   size_t table_id,
                                                   const uint64_t *keys,
                                                   size_t num,
                                                   bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
  size_t request_call_num = _server_channels.size();
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  // Bucket each key (with its destination buffer) by shard (simple modulo).
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = keys[i] % request_call_num;
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  // Response callback: stream values from each shard's attachment in
  // sorted-key order; duplicates reuse the value of the first occurrence.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // can remove sort&unique
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              // Duplicate key: copy the previously fetched value.
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data), value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    // Request attachment layout: [bool is_training][distinct uint64 keys]
    // followed by a uint32 multiplicity per distinct key.
    request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse runs of duplicate keys into one request entry + a count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // No keys for this shard: complete its slot without an RPC.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(
          closure->cntl(i), closure->request(i), closure->response(i), closure);
    }
  }
  return fut;
}
// Sends a client-to-client message of `msg_type` to `to_client_id` and
// returns a future carrying the peer's response code (or -1 immediately if
// the target id is invalid).
std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
    int msg_type, int to_client_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  // Reject any id outside [0, _client_channels.size()). The previous guard
  // (`to_client_id >= 0 && ...`) let negative ids fall through to a
  // negative vector index below, which is undefined behavior.
  if (to_client_id < 0 ||
      static_cast<size_t>(to_client_id) >= _client_channels.size()) {
    VLOG(0) << "to_client_id is out of range clients, which size is "
            << _client_channels.size();
    promise->set_value(-1);
    return fut;
  }
  // The peer acknowledges with cmd_id == msg_type + 1000.
  auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(msg_type);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_client_channels[to_client_id].get());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pushes raw sparse gradients for `keys` to a single pserver
// (`pserver_idx`) instead of sharding across all servers. `done` must be a
// DownpourBrpcClosure sized for exactly one request.
std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    uint32_t num,
    void *done,
    int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  const size_t value_bytes = accessor->GetAccessorInfo().update_size;
  auto *rpc_closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto result_promise = std::make_shared<std::promise<int32_t>>();
  rpc_closure->add_promise(result_promise);
  std::future<int> result = result_promise->get_future();
  // Build the single push request for the target pserver.
  auto *request = rpc_closure->request(0);
  request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  request->set_table_id(table_id);
  request->set_client_id(_client_id);
  request->add_params((char *)&num, sizeof(uint32_t));  // NOLINT
  // Payload layout: [num uint64 keys][num gradient blobs of value_bytes].
  auto *payload = request->mutable_data();
  payload->resize(num * (sizeof(uint64_t) + value_bytes));
  char *cursor = const_cast<char *>(payload->data());
  memcpy(cursor, keys, num * sizeof(uint64_t));
  cursor += num * sizeof(uint64_t);
  for (uint32_t idx = 0; idx < num; ++idx) {
    memcpy(cursor, update_values[idx], value_bytes);
    cursor += value_bytes;
  }
  PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
  rpc_closure->cntl(0)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(rpc_closure->cntl(0),
                   request,
                   rpc_closure->response(0),
                   rpc_closure);
  return result;
}
// Pulls an entire sparse table from the pservers (keys 0..table_num-1) and
// saves it under `path` as a serialized LoDTensor named after the table's
// configured variable name. Returns 0 on completion.
int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
                                       const std::string &path) {
  // get var information (name / row count / dim) from the worker config
  std::string var_name = "";
  int64_t var_num = 0;
  int64_t var_shape = 0;
  std::string table_class;
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    if (worker_param.downpour_table_param(i).table_id() == table_id) {
      var_name = worker_param.downpour_table_param(i).common().table_name();
      var_num = worker_param.downpour_table_param(i).common().table_num();
      var_shape = worker_param.downpour_table_param(i).common().table_dim();
      table_class = worker_param.downpour_table_param(i).table_class();
      break;
    }
  }
  PADDLE_ENFORCE_NE(
      var_name,
      "",
      platform::errors::InvalidArgument(
          "Cannot find table id %d to save variables.", table_id));
  std::string var_store = string::Sprintf("%s", path);
  MkDirRecursively(var_store.c_str());
  // pull sparse from server: one flat buffer, one row pointer per key
  std::vector<float> save_huge_vec(var_num * var_shape);
  std::vector<uint64_t> save_key(var_num);
  std::vector<float *> save_vec;
  for (size_t i = 0; i < save_key.size(); ++i) {
    save_key[i] = i;
    save_vec.push_back(save_huge_vec.data() + i * var_shape);
  }
  VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
  // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
  // RecvAndSaveTable
  if (table_class == "MemorySparseGeoTable") {
    // GEO tables route by modulo, so use the GEO pull path.
    auto status = PullSparseParam(reinterpret_cast<float **>(save_vec.data()),
                                  table_id,
                                  save_key.data(),
                                  save_key.size(),
                                  true);
    status.wait();
  } else {
    auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
                             table_id,
                             save_key.data(),
                             save_key.size(),
                             true);
    status.wait();
  }
  // create lod tensor sized (var_num, var_shape) on CPU
  std::shared_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  auto place = platform::CPUPlace();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);
  framework::Variable *var = scope->Var(var_name);
  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int64_t> vec_dim = {var_num, var_shape};
  var_tensor->Resize(phi::make_ddim(vec_dim));
  // copy the pulled values into the tensor and serialize it to disk
  float *tensor_data = var_tensor->mutable_data<float>(place);
  memcpy(
      tensor_data, save_huge_vec.data(), var_num * var_shape * sizeof(float));
  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
  std::ofstream fout(file_name, std::ios::binary);
  PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
                    true,
                    platform::errors::Unavailable(
                        "Cannot open %s to save variables.", file_name));
  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
  fout.close();
  return 0;
}
// Asynchronously pushes sparse gradients: copies `keys`/`update_values`
// into per-shard buffers and enqueues a SparseAsyncTask, which is later
// merged and sent by the PushSparseTaskConsume background loop.
//
// Blocks (sleeping 5ms at a time) while the table's pending task queue
// exceeds FLAGS_pserver_max_async_call_num, which throttles producers.
std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
                                              const uint64_t *keys,
                                              const float **update_values,
                                              size_t num) {
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
  CostTimer parse_timer("pserver_client_push_sparse_parse");
  // Back-pressure: wait until the async queue drains below the limit.
  int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
    // task_num:"
    //           << push_sparse_async_num
    //           << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  }
  auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
  // thread_local scratch space, reused across calls on the same thread.
  thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
      shard_sorted_kv_list;
  auto *accessor = GetTableAccessor(table_id);
  size_t request_call_num = _server_channels.size();
  shard_sorted_kv_list.resize(request_call_num);
  for (auto &x : shard_sorted_kv_list) {
    x.clear();
  }
  // Look up this table's shard_num; fall back to the global flag.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  // Bucket keys (and gradient pointers) by destination shard.
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]});
  }
  // Copy the bucketed kvs into an owned task payload (the caller's buffers
  // are not kept alive past this call).
  auto sparse_task_data = _sparse_task_pool.get();
  sparse_task_data->shared_data.resize(request_call_num);
  auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer);
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kv_list = shard_sorted_kv_list[i];
    size_t sorted_kv_size = sorted_kv_list.size();
    auto &shard_kv_data = async_task->data()->shared_data[i];
    shard_kv_data.key_list.resize(sorted_kv_size);
    shard_kv_data.value_list.resize(sorted_kv_size);
    if (sorted_kv_size == 0) {
      shard_kv_data.kv_num = 0;
      continue;
    }
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
      shard_kv_data.value_list[kv_idx].assign(
          (const char *)sorted_kv_list[kv_idx].second, value_size);
    }
    shard_kv_data.kv_num = sorted_kv_size;
  }
  std::future<int> fut = async_task->get_future();
  _push_sparse_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background loop that drains the per-table sparse push task queues.
//
// For each table: dequeued tasks are merged per shard (in a thread pool).
// Once the accumulated merge count reaches FLAGS_pserver_push_sparse_merge_limit
// (or the client is flushing), the merged result is sent via
// PushSparseAsyncShardPush; otherwise the merged task is re-queued so it
// can absorb more gradients. Runs until _running is cleared.
void BrpcPsClient::PushSparseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
  std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
  size_t request_call_num = _server_channels.size();
  ::ThreadPool async_push_sparse_shard_threads(
      FLAGS_pserver_sparse_merge_thread);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    // Process the pending push tasks of every sparse table.
    for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
      auto table_id = push_sparse_task_itr.first;
      auto *accessor = GetTableAccessor(table_id);
      auto &task_queue = push_sparse_task_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      // With merging enabled, wait for >1 task unless a flush is pending.
      if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) {
        continue;
      }
      ++_async_call_num;
      int merge_count = 0;
      // Return the previous round's task payloads to the pool.
      for (size_t i = 0; i < task_list.size(); ++i) {
        if (task_list[i]->data()) {
          _sparse_task_pool.push(task_list[i]->data());
        }
      }
      auto sparse_task_data = _sparse_task_pool.get();
      task_list.clear();
      int cur_meger_size = task_queue->Size();
      // task_list[0] is a fresh, empty SparseAsyncTask; the per-shard async
      // merge results are accumulated into it.
      sparse_task_data->shared_data.resize(request_call_num);
      auto push_timer =
          std::make_shared<CostTimer>("pserver_client_push_sparse");
      auto async_task =
          new SparseAsyncTask(sparse_task_data, table_id, push_timer);
      task_list.reserve(cur_meger_size + 1);
      task_list.push_back(
          std::move(std::shared_ptr<SparseAsyncTask>(async_task)));
      // Drain up to cur_meger_size tasks from the queue for this round.
      while (!task_queue->Empty() && merge_count < cur_meger_size) {
        ++merge_count;
        SparseAsyncTask *task;
        task_queue->Get(task);
        task_list.push_back(std::shared_ptr<SparseAsyncTask>(task));
      }
      _push_sparse_merge_count_map[table_id] += merge_count;
      // Send once merge_size is reached (or while flushing); otherwise only
      // merge and re-queue below.
      std::vector<int> request_kv_num(request_call_num, 0);
      if (_push_sparse_merge_count_map[table_id] >= merge_size ||
          _flushing == true) {
        DownpourBrpcClosure *closure = new DownpourBrpcClosure(
            request_call_num, [this, request_call_num](void *done) {
              int ret = 0;
              auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
              for (size_t i = 0; i < request_call_num; ++i) {
                if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
                  ret = -1;
                  break;
                }
              }
              closure->set_promise_value(ret);
              --_async_call_num;
            });
        // Attach every merged task's timer/promise to the shared closure so
        // all producers are notified when the RPC round completes.
        for_each(task_list.begin() + 1,
                 task_list.end(),
                 [&request_kv_num, request_call_num, closure](
                     std::shared_ptr<SparseAsyncTask> &task) {
                   closure->add_timer(task->timer());
                   closure->add_promise(task->promise());
                 });
        CostTimer merge_timer("pserver_client_push_sparse_merge");
        auto rpc_timer =
            std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
        closure->add_timer(rpc_timer);
        // Merge + send each shard in parallel, then wait for all shards.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardPush,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        closure,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
        _push_sparse_merge_count_map[table_id] = 0;
      } else {  // Threshold not reached: only do the multi-way merge.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardMerge,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        // Everything was merged into task_list[0]; re-queue it so later
        // rounds can keep accumulating.
        auto async_task = new SparseAsyncTask(*(task_list[0].get()));
        task_queue->Put(std::move(async_task));
        --_async_call_num;
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
      }
    }
    // Sleep out the remainder of the configured consume interval.
    auto wait_ms = FLAGS_pserver_async_push_sparse_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Merges one sparse row (`another_data`) into another (`merge_data`),
// column by column, using the accessor's Merge rule.
//
// `ValueAccessor::Merge` expects arrays of per-column pointers, so build
// shell arrays pointing at each float of the two rows and merge a single
// kv. Uses std::vector instead of the original C-style VLAs
// (`float *shell[col_num]`), which are a non-standard extension in ISO C++.
void sparse_local_merge(ValueAccessor *accessor,
                        float *merge_data,
                        const float *another_data) {
  size_t col_num = accessor->GetAccessorInfo().update_dim;
  std::vector<float *> merge_data_shell(col_num);
  std::vector<const float *> another_data_shell(col_num);
  for (size_t i = 0; i < col_num; ++i) {
    merge_data_shell[i] = merge_data + i;
    another_data_shell[i] = another_data + i;
  }
  accessor->Merge(merge_data_shell.data(), another_data_shell.data(), 1);
}
// Merges shard `shard_idx` of every queued task in `task_list[1..]` into
// `task_list[0]`: gathers all kvs, sorts by key, and collapses duplicate
// keys by accumulating their gradients via sparse_local_merge. Always
// returns 0; task_list[0]'s shard data holds the deduped result.
int BrpcPsClient::PushSparseAsyncShardMerge(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    ValueAccessor *accessor) {
  size_t merged_kv_count = 0;
  uint32_t value_size = accessor->GetAccessorInfo().update_size;
  // thread_local scratch space, reused across calls on the same thread.
  thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
  sorted_kv_list.clear();
  // Gather this shard's kvs from every task (task_list[0] is the output).
  for (size_t i = 1; i < task_list.size(); ++i) {
    size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
    auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
    auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;
    for (size_t j = 0; j < kv_num; ++j) {
      if (value_list[j].size() < value_size) {
        LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
                     << "is invalid.";
        continue;
      }
      char *task_data_ptr = const_cast<char *>(value_list[j].data());
      sorted_kv_list.push_back(
          {key_list[j], reinterpret_cast<float *>(task_data_ptr)});
    }
  }
  // Sort by key so duplicates become adjacent for deduplication.
  std::sort(sorted_kv_list.begin(),
            sorted_kv_list.end(),
            [](const std::pair<uint64_t, const float *> &k1,
               const std::pair<uint64_t, const float *> &k2) {
              return k1.first < k2.first;
            });
  auto &async_task = task_list[0];
  size_t sorted_kv_size = sorted_kv_list.size();
  auto &shard_kv_data = async_task->data()->shared_data[shard_idx];
  shard_kv_data.key_list.resize(sorted_kv_size);
  shard_kv_data.value_list.resize(sorted_kv_size);
  // Write the deduplicated data into the per-shard package.
  if (sorted_kv_size == 0) {
    shard_kv_data.kv_num = 0;
    return 0;
  } else if (sorted_kv_size == 1) {
    // Single kv: no merging needed, copy it straight through.
    shard_kv_data.kv_num = 1;
    shard_kv_data.key_list[0] = sorted_kv_list[0].first;
    shard_kv_data.value_list[0].assign((const char *)(sorted_kv_list[0].second),
                                       value_size);
    return 0;
  }
  // Deduplicate with a local merge: runs of equal keys are accumulated
  // into merger_buffer via sparse_local_merge before being emitted once.
  uint64_t last_key = sorted_kv_list[0].first;
  const float *last_value_data = sorted_kv_list[0].second;
  float *last_merge_data = NULL;
  std::shared_ptr<char> merger_buffer(new char[value_size],
                                      array_deleter<char>());
  for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) {
    while (kv_idx < sorted_kv_size &&
           last_key == sorted_kv_list[kv_idx].first) {
      if (last_merge_data == NULL) {
        // First duplicate seen for this key: seed the merge buffer.
        last_merge_data = reinterpret_cast<float *>(merger_buffer.get());
        memcpy(last_merge_data, last_value_data, value_size);
      }
      sparse_local_merge(
          accessor, last_merge_data, sorted_kv_list[kv_idx].second);
      ++kv_idx;
    }
    // Emit the (possibly merged) value for last_key.
    if (last_merge_data != NULL) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_merge_data, value_size);
      last_merge_data = NULL;
    } else {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)sorted_kv_list[kv_idx - 1].second, value_size);
    }
    shard_kv_data.key_list[merged_kv_count++] = last_key;
    if (kv_idx < sorted_kv_size) {
      last_key = sorted_kv_list[kv_idx].first;
      last_value_data = sorted_kv_list[kv_idx].second;
    }
    // The final element (when unique) is emitted here.
    if (kv_idx == sorted_kv_size - 1) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_value_data, value_size);
      shard_kv_data.key_list[merged_kv_count++] = last_key;
    }
  }
  shard_kv_data.kv_num = merged_kv_count;
  return 0;
}
// Merges all queued sparse tasks for one shard (via PushSparseAsyncShardMerge)
// and sends the merged result to the corresponding pserver shard as a
// PS_PUSH_SPARSE_TABLE request. `closure` owns the per-shard request/response
// slots; `shard_idx` selects both the data shard and the RPC channel.
// Always returns 0 (the RPC outcome is reported through `closure`).
int BrpcPsClient::PushSparseAsyncShardPush(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    DownpourBrpcClosure *closure,
    ValueAccessor *accessor) {
  PushSparseAsyncShardMerge(
      task_list, request_kv_num, table_id, shard_idx, accessor);
  // After the merge, task_list[0] holds the deduplicated kv set for the shard.
  size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num;
  auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list;
  auto &merged_value_list =
      task_list[0]->data()->shared_data[shard_idx].value_list;
  // Send the RPC request.
  auto *push_request = closure->request(shard_idx);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  push_request->add_params(reinterpret_cast<char *>(&merged_kv_count),
                           sizeof(uint32_t));  // NOLINT
  auto *push_data = push_request->mutable_data();
  int update_size = accessor->GetAccessorInfo().update_size;
  // Wire layout: [keys: merged_kv_count * uint64_t][values: count * update_size].
  push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  memcpy(push_data_ptr,
         merged_key_list.data(),
         merged_kv_count * sizeof(uint64_t));
  push_data_ptr += merged_kv_count * sizeof(uint64_t);
  for (size_t i = 0; i < merged_kv_count; ++i) {
    const char *task_data_ptr = merged_value_list[i].data();
    memcpy(push_data_ptr,
           (float *)(task_data_ptr),  // NOLINT
           update_size);
    push_data_ptr += update_size;
  }
  PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
  closure->cntl(shard_idx)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(closure->cntl(shard_idx),
                   closure->request(shard_idx),
                   closure->response(shard_idx),
                   closure);
  // Reset the merge counter so the next batch starts a fresh merge window.
  _push_sparse_merge_count_map[table_id] = 0;
  return 0;
}
// Enqueues a dense gradient push: copies the caller's regions into one flat
// buffer and puts an async task on the per-table queue, which
// PushDenseTaskConsume drains. Applies backpressure when too many async
// calls are in flight. The returned future resolves when the merged push RPC
// for this task completes.
std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  int fea_dim = accessor->GetAccessorInfo().fea_dim;
  int update_dim = accessor->GetAccessorInfo().update_dim;
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
  auto parse_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_parse");
  int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  // Backpressure: wait until the pending-task count drops below the limit.
  while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushDense Waiting for async_call_num comsume,
    // task_num:"
    //  << push_dense_async_num
    //  << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  }
  auto push_dense_timer = std::make_shared<CostTimer>("push_dense_put");
  // auto dense_data = _dense_matrix_obj_pool.get();
  auto dense_data = std::make_shared<std::vector<float>>();
  auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Copy region data into the flat (per-shard laid out) buffer.
  async_task->data()->resize(num_per_shard * request_call_num * update_dim);
  float *data = async_task->data()->data();
  size_t data_size = async_task->data()->size();
  uint32_t pos = 0;
  for (size_t i = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    CHECK(pos + data_num <= data_size)
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << data_num << "] size[" << data_size << "]";
    const float *region_data = (const float *)(regions[i].data);
    memcpy(data + pos, region_data, regions[i].size);
    pos += data_num;
  }
  std::future<int> fut = async_task->get_future();
  // Queue ownership of async_task transfers to the consumer thread.
  _push_dense_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
void BrpcPsClient::PushDenseTaskConsume() {
uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
::ThreadPool async_merge_dense_threads(10);
while (_running) {
auto async_start_time_ms = butil::gettimeofday_ms();
for (auto &task_queue_itr : _push_dense_task_queue_map) {
auto &task_queue = task_queue_itr.second;
auto queue_size = task_queue->Size();
if (queue_size == 0) {
continue;
}
if (queue_size <= merge_size && _flushing == false) {
continue;
}
++_async_call_num;
DenseAsyncTask *task;
task_queue->Get(task);
auto *accessor = GetTableAccessor(task->table_id());
// 设置请求回调
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [this, request_call_num](void *done) {
int ret = 0;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
--_async_call_num;
});
auto &total_send_data_vec = *(task->data());
float *total_send_data =
reinterpret_cast<float *>(total_send_data_vec.data());
size_t total_send_data_size = total_send_data_vec.size();
{
CostTimer merge_timer("pserver_client_push_dense_merge");
uint32_t merge_count = 0;
std::vector<std::future<int>> merge_status(merge_size);
while (!task_queue->Empty() && merge_count < merge_size) {
auto *async_task = new DenseAsyncTask();
task_queue->Get(async_task);
closure->add_timer(async_task->timer());
closure->add_promise(async_task->promise());
merge_status[merge_count] =
async_merge_dense_threads.enqueue([closure,
accessor,
&total_send_data,
total_send_data_size,
async_task]() -> int {
auto &tmp_task_vec = *(async_task->data());
const float *merge_data = tmp_task_vec.data();
accessor->Merge(
&total_send_data, &merge_data, total_send_data_size);
#pragma optimize("", off)
delete async_task;
#pragma optimize("", on)
return 0;
});
++merge_count;
}
for (size_t i = 0; i < merge_count; ++i) {
merge_status[i].wait();
}
VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< total_send_data[0] << " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1];
if (scale_gradient && merge_count > 1) {
Eigen::Map<Eigen::MatrixXf> mat(
total_send_data, 1, total_send_data_size);
mat *= (1.0 / (merge_count + 1));
}
VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge "
"total_send_data[0]"
<< total_send_data[0] << " total_send_data[-2]"
<< total_send_data[total_send_data_size - 2]
<< " total_send_data[-1]"
<< total_send_data[total_send_data_size - 1] << " merge_count "
<< merge_count;
}
std::shared_ptr<DenseAsyncTask> task_ptr(task);
PushDenseRawGradient(
task_ptr, total_send_data, total_send_data_size, closure);
}
auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms -
(butil::gettimeofday_ms() - async_start_time_ms);
if (wait_ms > 0) {
usleep(wait_ms * 1000);
}
}
}
// Splits the merged dense gradient buffer into per-server shards and sends
// one PS_PUSH_DENSE_TABLE request per server. Each request payload is
// [num_per_shard: uint32_t][num_per_shard floats]. Completion is reported
// through `closure`.
void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
                                        float *total_send_data,
                                        size_t total_send_data_size,
                                        DownpourBrpcClosure *closure) {
  auto *accessor = GetTableAccessor(task->table_id());
  size_t request_call_num = _server_channels.size();
  // Copy the data into the per-request buffers.
  auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
  closure->add_timer(timer);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  auto send_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_send");
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(task->table_id());
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    // Payload: element count followed by this shard's slice of the buffer.
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    closure->cntl(i)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace brpc {
class Channel;
class Controller;
} // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
struct Region;
// Client-side brpc service: lets servers / the FL coordinator call back into
// a worker. `service` handles generic PS messages; `FLService` receives the
// federated-learning strategy string pushed by the coordinator.
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}
  // Binds this service to its owning client and records the worker rank.
  virtual int32_t Configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done);
  // Receives an FL strategy from the coordinator; stores it and raises the
  // ready flag that PullFlStrategy presumably polls — confirm against caller.
  virtual void FLService(::google::protobuf::RpcController *controller,
                         const CoordinatorReqMessage *request,
                         CoordinatorResMessage *response,
                         ::google::protobuf::Closure *done) {
    brpc::ClosureGuard done_guard(done);
    size_t client_id = request->client_id();
    // Messages addressed to a different client id indicate a routing bug.
    CHECK(_client->_client_id == client_id)
        << "request client id not matched self";
    _fl_strategy = request->str_params();
    _is_fl_strategy_ready = true;
    response->set_err_code(0);
    response->set_err_msg("");
    VLOG(0) << "fl-ps > DownpourPsClientService::FLService finished!";
    return;
  }
 public:
  std::string _fl_strategy;            // last strategy payload received
  bool _is_fl_strategy_ready = false;  // set once _fl_strategy is valid
 protected:
  size_t _rank;
  PSClient *_client;
};
// Closure tracking `num` parallel coordinator RPCs; the user callback fires
// (and the closure self-deletes) when the last in-flight call completes.
class FlClientBrpcClosure : public PSClientClosure {
 public:
  FlClientBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~FlClientBrpcClosure() {}
  // Called once per finished sub-request; the last one runs the callback.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request's protobuf slots and controller.
  CoordinatorReqMessage *request(size_t i) { return &_requests[i]; }
  CoordinatorResMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);
 private:
  std::atomic<int32_t> _waiting_num;  // outstanding sub-requests
  std::vector<CoordinatorReqMessage> _requests;
  std::vector<CoordinatorResMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Closure tracking `num` parallel pserver RPCs (one slot per server shard);
// the user callback fires (and the closure self-deletes) when the last
// in-flight call completes.
class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
  // Called once per finished sub-request; the last one runs the callback.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request's protobuf slots and controller.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);
 private:
  std::atomic<int32_t> _waiting_num;  // outstanding sub-requests
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Per-shard staging buffer for one async sparse-push request.
// kv_num was previously left uninitialized, so a freshly constructed (or
// pool-recycled) instance read an indeterminate count; it is now
// zero-initialized in-class.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  size_t kv_num = 0;                    // number of valid kv entries below
  std::vector<uint64_t> key_list;       // sparse feature keys
  std::vector<std::string> value_list;  // serialized update value per key
};
// Payload of one async sparse-push task.
struct SparsePushTaskData {
  std::vector<SharedSparsePushData> shared_data;  // sparse data sharded by key hash
};
// Object pool for sparse-push task payloads: recycles SparsePushTaskData
// instances to avoid re-allocating their shard buffers. Thread-safe via an
// internal mutex.
struct SparseTaskPool {
  // Hand out a recycled object when one is available, otherwise allocate.
  std::shared_ptr<SparsePushTaskData> get() {
    std::lock_guard<std::mutex> guard(_mutex);
    if (!_pool.empty()) {
      auto recycled = _pool.back();
      _pool.pop_back();
      return recycled;
    }
    return std::make_shared<SparsePushTaskData>();
  }
  // Return an object to the pool for later reuse.
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> guard(_mutex);
    _pool.push_back(std::move(data));
  }
  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;
};
// Deleter for smart pointers that own `new T[]` allocations (the default
// deleter would call `delete` on them, which is undefined behavior).
// The pointer is now taken by value instead of `T *&`: a deleter never needs
// to mutate the caller's pointer, and the reference form rejected rvalue
// pointers.
template <class T>
struct array_deleter {
  void operator()(T *x) const { delete[] x; }
};
// brpc-based parameter-server client. Implements the PSClient interface:
// table load/save/clear control commands, dense/sparse pull & push (with
// background merge threads for async pushes), client-to-client messaging,
// and the federated-learning worker hooks.
class BrpcPsClient : public PSClient {
 public:
  BrpcPsClient() {}
  // Flushes pending async work, joins consumer threads and stops the
  // embedded client-side server before destruction.
  virtual ~BrpcPsClient() {
    if (_running) {
      Flush();
      _running = false;
    }
    if (_async_push_dense_thread.joinable()) {
      _async_push_dense_thread.join();
    }
    if (_async_push_sparse_thread.joinable()) {
      _async_push_sparse_thread.join();
    }
    if (_server_started) {
      _server.Stop(1000);
      _server.Join();
      _server_started = false;
    }
  }
  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry);
  // ---- table control commands (executed on the pservers) ----
  std::future<int32_t> Shrink(uint32_t table_id,
                              const std::string threshold) override;
  std::future<int32_t> Load(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Load(uint32_t table_id,
                            const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Save(const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Save(uint32_t table_id,
                            const std::string &epoch,
                            const std::string &mode) override;
  std::future<int32_t> Clear() override;
  std::future<int32_t> Clear(uint32_t table_id) override;
  std::future<int32_t> Revert() override;
  std::future<int32_t> CheckSavePrePatchDone() override;
  std::future<int32_t> StopServer() override;
  std::future<int32_t> StartProfiler() override;
  std::future<int32_t> StopProfiler() override;
  void FinalizeWorker() override;
  // ---- dense parameter pull / push ----
  virtual std::future<int32_t> PullDense(Region *regions,
                                         size_t region_num,
                                         size_t table_id);
  virtual std::future<int32_t> PushDenseParam(const Region *regions,
                                              size_t region_num,
                                              size_t table_id);
  virtual std::future<int32_t> PushDense(const Region *regions,
                                         size_t region_num,
                                         size_t table_id);
  // Background loop draining _push_dense_task_queue_map.
  void PushDenseTaskConsume();
  // ---- sparse parameter pull / push ----
  virtual std::future<int32_t> PullSparse(float **select_values,
                                          size_t table_id,
                                          const uint64_t *keys,
                                          size_t num,
                                          bool is_training);
  virtual std::future<int32_t> PullSparseParam(float **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num,
                                               bool is_training);
  virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float> *values,
                                            std::vector<uint64_t> *keys,
                                            int pserver_idx);
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t *total_send_data,
                                              void *done);
  virtual std::future<int32_t> Flush();
  std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                            int to_client_id,
                                            const std::string &msg) override;
  // for local save sparse
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path);
  // ---- cache shuffle / save ----
  std::future<int32_t> CacheShuffle(
      uint32_t table_id,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) override;
  std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold);
  std::future<int32_t> SaveCache(uint32_t table_id,
                                 const std::string &path,
                                 const std::string &mode) override;
  std::future<int32_t> GetCacheThreshold(uint32_t table_id,
                                         double &cache_threshold) override;
  void PrintQueueSize();
  void PrintQueueSizeThread();
 protected:
  virtual size_t GetServerNums() { return _server_channels.size(); }
  // Each server exposes three channels: [0] sparse, [1] dense, [2] cmd.
  inline brpc::Channel *GetSparseChannel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  inline brpc::Channel *GetDenseChannel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  inline brpc::Channel *GetCmdChannel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
  int32_t Initialize() override;
  // for fl
 public:
  virtual int32_t InitializeFlWorker(const std::string &self_endpoint);
  int32_t StartFlClientService(const std::string &self_endpoint);
  virtual void PushFLClientInfoSync(const std::string &fl_client_info);
  std::string PullFlStrategy();
  // for fl
 private:
  // Rows of a dense table assigned to each shard (rounded up, so the last
  // shard may be padded).
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }
  std::future<int32_t> SendCmd(uint32_t table_id,
                               int cmd_id,
                               const std::vector<std::string> &param);
  std::future<int32_t> SendSaveCmd(uint32_t table_id,
                                   int cmd_id,
                                   const std::vector<std::string> &param);
  bool _running = false;
  bool _flushing = false;
  std::atomic<uint32_t> _async_call_num;  // in-flight async request count
  // Async push-dense machinery: one task queue per table, drained by
  // _async_push_dense_thread (PushDenseTaskConsume).
  std::thread _async_push_dense_thread;
  typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
      _push_dense_task_queue_map;
  // Async push-sparse machinery, analogous to the dense path.
  std::thread _async_push_sparse_thread;
  typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
  std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
      _push_sparse_task_queue_map;
  std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
  std::thread _print_thread;
  int PushSparseAsyncShardMerge(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,  // NOLINT
      std::vector<int> &request_kv_num,                          // NOLINT
      int table_id,
      int shard_idx,
      ValueAccessor *accessor);
  int PushSparseAsyncShardPush(
      std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,  // NOLINT
      std::vector<int> &request_kv_num,                          // NOLINT
      int table_id,
      int shard_idx,
      DownpourBrpcClosure *closure,
      ValueAccessor *accessor);
  SparseTaskPool _sparse_task_pool;
  std::vector<std::shared_ptr<brpc::Channel>>
      _client_channels;  // client2client
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
      _server_channels;  // client2server
  std::vector<std::array<std::shared_ptr<brpc::Channel>, 1>>
      _coordinator_channels;  // client2coordinator
  std::future<int32_t> PushDenseRawGradient(int table_id,
                                            float *total_send_data,
                                            size_t total_send_data_size,
                                            void *done) override;
  std::future<int32_t> PushSparseRawGradient(size_t table_id,
                                             const uint64_t *keys,
                                             const float **update_values,
                                             size_t num,
                                             void *done) override;
  std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
                                                    const uint64_t *keys,
                                                    const float **update_values,
                                                    uint32_t num,
                                                    void *done,
                                                    int pserver_idx) override;
  std::future<int32_t> PushSparseParam(size_t table_id,
                                       const uint64_t *keys,
                                       const float **update_values,
                                       size_t num,
                                       void *done) override;
  std::future<int32_t> PushSparse(size_t table_id,
                                  const uint64_t *keys,
                                  const float **update_values,
                                  size_t num) override;
  void PushSparseTaskConsume();
 private:
  int32_t StartClientService();
  void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,  // NOLINT
                            float *total_send_data,
                            size_t total_send_data_size,
                            DownpourBrpcClosure *closure);
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
  brpc::Server _server;     // embedded server for client2client messages
  brpc::Server _fl_server;  // embedded server for FL coordinator callbacks
  DownpourPsClientService _service;
  bool _server_started = false;
  std::atomic_uint grad_num_{0};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
// Server-to-server (S2S) channel flags, used by BrpcPsServer::StartS2S.
DEFINE_int32(pserver_timeout_ms_s2s,
             10000,
             "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s,
             10000,
             "pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s,
              "pooled",
              "pserver connection_type[pooled:single]");
namespace paddle {
namespace distributed {
// Instantiates the configured PsBaseService subclass, wires it to this
// server, and registers it with the embedded brpc server. Returns 0 on
// success, -1 on any configuration/registration failure.
int32_t BrpcPsServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  // Service class is created through the pscore class registry by name.
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service (a unique/shared ptr) already owns the object, so brpc must not
  // delete it.
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Starts the brpc server on ip:port (retrying with a numeric-IP endpoint if
// the first bind fails), registers this server in the environment, then
// BLOCKS until `stoped_` is signalled. Returns the server's rank.
// NOTE(review): the failure path returns 0, which is indistinguishable from
// rank 0 on success — callers apparently cannot detect bind failure here;
// confirm before relying on the return value.
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
  brpc::ServerOptions options;
  // Size the worker pool by max(trainers, hardware threads).
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";
    // Retry with a resolved numeric endpoint (hostname may not bind).
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
  }
  _environment->RegistePsServer(ip, port, _rank);
  // Block here until stop is requested; serving happens on brpc threads.
  cv_.wait(lock, [&] { return stoped_; });
  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}
// Opens a brpc channel from this pserver to every other pserver (including
// itself) for server-to-server messaging. Channel failures are logged but do
// not abort startup. Always returns 0.
int32_t BrpcPsServer::StartS2S() {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
  options.connection_type = FLAGS_pserver_connection_type_s2s;
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
  options.max_retry = 3;
  std::vector<PSHost> pserver_list = _environment->GetPsServers();
  _pserver_channels.resize(pserver_list.size());
  VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size();
  std::ostringstream os;
  for (size_t i = 0; i < pserver_list.size(); ++i) {
    // Build "ip:port" directly; the old assign/append chain also did a
    // needless c_str() round-trip through the ip string.
    const std::string server_ip_port =
        pserver_list[i].ip + ":" + std::to_string(pserver_list[i].port);
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
                 << " Failed!";
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "pserver connect success: " << os.str();
  return 0;
}
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
if (static_cast<size_t>(to_pserver_id) >= _pserver_channels.size()) {
LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
<< _pserver_channels.size();
promise->set_value(-1);
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
closure->add_promise(promise);
closure->request(0)->set_cmd_id(101);
closure->request(0)->set_client_id(_rank);
closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(
closure->cntl(0), closure->request(0), closure->response(0), closure);
return fut;
}
// Handles an S2S message carrying shuffled (key, value) pairs: deserializes
// the binary archive and appends the batch to _shuffled_ins. Empty messages
// and empty archives mark the end of an S2S response stream. Returns 0.
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
                                         int pserver_id,
                                         const std::string &msg) {
  if (msg.length() == 0) {
    LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
    return 0;
  }
  // The archive reads directly from the message buffer (no copy).
  paddle::framework::BinaryArchive ar;
  ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
    return 0;
  }
  std::vector<std::pair<uint64_t, std::string>> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
  }
  // The archive must be fully consumed; leftovers indicate corruption.
  CHECK(ar.Cursor() == ar.Finish());
  this->_shuffled_ins->Write(std::move(data));
  return 0;
}
// Port the embedded brpc server is currently listening on.
int32_t BrpcPsServer::Port() {
  return _server.listen_address().port;
}
// Builds the cmd_id -> handler dispatch table used by service(), registers
// cost profilers, and eagerly initializes shard info. Returns 0.
int32_t BrpcPsService::Initialize() {
  _is_initialize_shard_info = false;
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
  // for save cache
  _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
      &BrpcPsService::SaveCacheTable;
  _service_handler_map[PS_GET_CACHE_THRESHOLD] =
      &BrpcPsService::GetCacheThreshold;
  _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
  _service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
  _service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
      &BrpcPsService::CheckSavePrePatchDone;
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
  // Shard initialization: the server_list shard info only becomes available
  // from the environment after the servers have started.
  InitializeShardInfo();
  return 0;
}
// Guard used at the top of every table-scoped handler: if the looked-up
// table is missing, fill `response` with an error naming the requested
// table_id and abort the handler with -1.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Lazily propagates (rank, shard_num) to every table, exactly once.
// Uses double-checked locking: the unsynchronized fast path skips the mutex
// after initialization. Returns 0.
// Fix: the range-for previously iterated `auto itr`, copying every map entry
// (key plus table handle) per iteration; it now iterates by reference
// (clang-tidy performance-for-range-copy).
int32_t BrpcPsService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, shard_num);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// Top-level RPC entry point. cmd_ids below 100 are dispatched through
// _service_handler_map to table handlers; cmd_ids >= 100 are forwarded as
// pserver-to-pserver messages. Errors are reported via the response's
// err_code/err_msg.
// Fix: the "table_id is required" error message misspelled the field as
// "tabel_id".
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }
  // Default to success; handlers overwrite on failure.
  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  if (request->cmd_id() < 100) {
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      set_response_code(*response, -1, err_msg.c_str());
      return;
    }
    serviceHandlerFunc handler_func = itr->second;
    int service_ret = (this->*handler_func)(table, *request, *response, cntl);
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  } else {
    // cmd_id >= 100: server-to-server message, routed by the owning server.
    int service_ret = _server->HandlePServer2PServerMsg(
        request->cmd_id(), request->client_id(), request->data());
    if (service_ret != 0) {
      response->set_err_code(-1);
      response->set_err_msg("handle_pserver2pserver_msg failed");
    }
  }
}
// Handles PS_PULL_DENSE_TABLE: reads the requested element count from
// params(0), pulls that many dense values from the table, and returns them
// in the response attachment as raw floats. Returns 0 (errors are reported
// through the response code).
// Fix: the params-missing error message misspelled "required" as "requeired".
int32_t BrpcPsService::PullDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is required, at least 1 for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");
  // params(0) carries the element count as a raw uint32_t.
  uint32_t num = *(const uint32_t *)request.params(0).c_str();
  // Pooled buffer, returned to the pool below after copying into the
  // response attachment.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);
  cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Pushes dense *parameter* values (is_param = true, as opposed to
  // gradients). Attachment wire format: |--num(4B)--|--float values--|
  platform::RecordEvent record_event(
      "PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  // Bug fix: the buffer must be *sized* before fetch() may copy into it --
  // the previous resize(0) + reserve() pattern wrote through data() past
  // size(), which is undefined behavior. resize() still reuses the
  // thread_local capacity across calls.
  push_buffer.resize(req_buffer_size);
  // fetch() returns a pointer to contiguous request bytes, copying into
  // push_buffer only when the IOBuf is fragmented.
  const char *data = (const char *)cntl->request_attachment().fetch(
      const_cast<char *>(push_buffer.data()), req_buffer_size);
  uint32_t num = *(const uint32_t *)data;
  const float *values = (const float *)(data + sizeof(uint32_t));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}
int32_t BrpcPsService::PushDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  const auto &payload = request.data();
  if (payload.size() < 1) {
    // An empty payload is treated as a successful no-op.
    return 0;
  }
  CostTimer timer("pserver_server_push_dense");
  // Wire format of the payload:
  //   |--num--|---valuesData---|
  //   |--4B---|----------------|
  const char *raw = payload.data();
  TableContext ctx;
  ctx.value_type = Dense;
  ctx.num = *(const uint32_t *)raw;
  ctx.push_context.values = (const float *)(raw + sizeof(uint32_t));
  if (table->Push(ctx) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }
  return 0;
}
int32_t BrpcPsService::Barrier(Table *table,
                               const PsRequestMessage &request,
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  // params(0) names the barrier type; the requesting trainer is identified
  // by client_id.
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  table->Barrier(request.client_id(), request.params(0));
  return 0;
}
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  // Pushes sparse *parameter* values (is_param = true). The key count comes
  // from params(0) as raw uint32 bytes; keys and values are packed in
  // request.data().
  platform::RecordEvent record_event("PsService->PushSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // Empty push is a no-op success.
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  // NOTE(review): no check that push_data.size() >= 8 * num; a malformed
  // request would read past the buffer -- confirm callers always validate.
  const uint64_t *keys = (const uint64_t *)push_data.data();
  const float *values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  // if (table->PushSparseParam(keys, values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Geo-async path: pulls the ids/values accumulated for the requesting
  // trainer and streams them back as |num(4B)|ids(8B*num)|values| in the
  // response attachment.
  platform::RecordEvent record_event(
      "PsService->pull_geo_param", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  thread_local std::string push_sparse_request_buffer;
  auto trainer_id = request.client_id();
  std::vector<float> values;
  std::vector<uint64_t> ids;
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);
  // table->PullGeoParam(trainer_id, &values, &ids);
  uint32_t num = ids.size();
  cntl->response_attachment().append(reinterpret_cast<char *>(&num),
                                     sizeof(uint32_t));
  cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
                                     ids.size() * sizeof(uint64_t));
  cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
                                     values.size() * sizeof(float));
  return 0;
}
int32_t BrpcPsService::PullSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Pulls `num` sparse rows. The serialized keys arrive in the request
  // attachment; the pulled float values (num * select_dim) are returned
  // through the response attachment.
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_pull_sparse");
  // params(0) carries the key count as raw uint32 bytes.
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
  thread_local std::string req_buffer;
  // Bug fix: fetch() may copy req_buffer_size bytes into the string, so it
  // must be resized (not merely reserved) first -- writing through data()
  // past size() is undefined behavior.
  req_buffer.resize(req_buffer_size);
  const void *data = cntl->request_attachment().fetch(
      const_cast<char *>(req_buffer.data()), req_buffer_size);
  auto value = PullSparseValue(num, dim);
  value.DeserializeFromBytes(const_cast<void *>(data));
  // Pooled result buffer; returned to the pool after the bytes are copied
  // into the response attachment.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * dim);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);
  // table->PullSparse(res_data->data(), value);
  cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
int32_t BrpcPsService::PushSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Pushes sparse gradients. Key count comes from params(0) as raw uint32
  // bytes; keys and values are packed in request.data().
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // Empty push is a no-op success.
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");
  const uint32_t num =
      *(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = (const uint64_t *)push_data.data();
  table_context.push_context.values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  table_context.num = num;
  // const uint64_t *keys = (const uint64_t *)push_data.data();
  // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
  // num);
  if (table->Push(table_context) != 0) {
    // if (table->PushSparse(keys, values, num) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  // Serialize the table's (first, second) stat pair into a binary archive
  // and hand it back through the response payload.
  const auto stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  // Loading needs two positional params: params(0) = path,
  // params(1) = load_param.
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  const auto &path = request.params(0);
  const auto &load_param = request.params(1);
  if (table->Load(path, load_param) != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Delegate to LoadOneTable for every registered table; abort on the
  // first failure. The `table` argument is unused (server-wide command).
  for (auto &entry : *(_server->GetTable())) {
    if (LoadOneTable(entry.second.get(), request, response, cntl) != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Saves one table. params(0) = path, params(1) = mode.
  // Returns the feasign count (>= 0) on success, -1 on failure.
  // NOTE(review): a positive feasign count propagates to service(), which
  // treats any nonzero handler return as an error -- confirm this is the
  // intended contract for the save command.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  // Flush pending updates so the snapshot on disk is complete.
  table->Flush();
  int32_t feasign_size = 0;
  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  feasign_size = table->Save(request.params(0), request.params(1));
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Save every registered table; SaveOneTable returns the feasign count
  // (>= 0) on success and a negative value on failure. The `table`
  // argument is unused (server-wide command).
  for (auto &entry : *(_server->GetTable())) {
    if (SaveOneTable(entry.second.get(), request, response, cntl) < 0) {
      LOG(ERROR) << "save table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::SaveCacheTable(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Persists the cached (shuffled) portion of a table.
  // params(0) = path, params(1) = mode; the shuffled instances come from
  // the server-wide _shuffled_ins channel filled by CacheShuffle.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    // Bug fix: the message previously claimed "at least 3" while the check
    // only requires 2 params (now consistent with SaveOneTable).
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  table->Flush();
  int32_t feasign_size = 0;
  // if (_server->_shuffled_ins->size() <= 0) {
  //  LOG(WARNING) << "shuffled ins size <= 0";
  //}
  feasign_size = table->SaveCache(
      request.params(0), request.params(1), _server->_shuffled_ins);
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
int32_t BrpcPsService::CacheShuffle(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // start cache shuffle
  // Shuffles hot entries across pservers before a cache save.
  // params: (0) = path, (1) = mode, (2) = cache_threshold, (3..) = optional
  // extra table ids to shuffle together with this table.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.datas is requeired at least 3, "
                      "path&mode&cache_threshold");
    return -1;
  }
  table->Flush();
  double cache_threshold = std::stod(request.params(2));
  LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
  // auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
  // std::string>>();
  // shuffled_ins->set_block_size(80000);
  // Open the server-to-server channel before shuffling so entries can be
  // re-routed to their owning pserver.
  _server->StartS2S();
  // Callback used by the table to forward shuffled entries to a peer
  // pserver; captures `this`, so it must not outlive the service.
  std::function<std::future<int32_t>(
      int msg_type, int to_pserver_id, const std::string &msg)>
      send_msg_func = [this](int msg_type,
                             int to_pserver_id,
                             const std::string &msg) -> std::future<int32_t> {
    return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };
  std::vector<Table *> table_ptrs;
  for (int i = 3; i < request.params_size(); ++i) {
    int table_id = std::stoi(request.params(i));
    Table *table_ptr = _server->GetTable(table_id);
    table_ptrs.push_back(table_ptr);
  }
  // Fall back to shuffling only the addressed table.
  if (table_ptrs.empty()) {
    table_ptrs.push_back(table);
  }
  table->CacheShuffle(request.params(0),
                      request.params(1),
                      cache_threshold,
                      send_msg_func,
                      _server->_shuffled_ins,
                      table_ptrs);
  return 0;
}
int32_t BrpcPsService::GetCacheThreshold(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  const double cache_threshold = table->GetCacheThreshold();
  if (cache_threshold < 0) {
    LOG(WARNING) << "wrong threshold: " << cache_threshold;
  }
  // Return the threshold as text with enough digits to round-trip.
  std::ostringstream formatter;
  formatter << std::setprecision(15) << cache_threshold;
  response.set_data(formatter.str());
  return 0;
}
int32_t BrpcPsService::Revert(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl) {
  // Server-wide command: flush then revert every registered table. The
  // `table` argument is ignored.
  for (auto &entry : *(_server->GetTable())) {
    entry.second->Flush();
    entry.second->Revert();
  }
  return 0;
}
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
                                             const PsRequestMessage &request,
                                             PsResponseMessage &response,
                                             brpc::Controller *cntl) {
  // Server-wide command: ask every registered table to verify its
  // pre-patch save has completed. The `table` argument is ignored.
  for (auto &entry : *(_server->GetTable())) {
    entry.second->CheckSavePrePatchDone();
  }
  return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  // Shrinks (evicts stale entries from) one table.
  // params(0) carries the shrink threshold as a string.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }
  // Flush pending updates before shrinking so eviction sees current state.
  table->Flush();
  if (table->Shrink(request.params(0)) != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }
  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Empties one table. Flush first so buffered updates are not resurrected
  // after Clear().
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  table->Clear();
  return 0;
}
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Clear every registered table; stop at the first failure. The `table`
  // argument is unused (server-wide command).
  for (auto &entry : *(_server->GetTable())) {
    if (ClearOneTable(entry.second.get(), request, response, cntl) != 0) {
      return -1;
    }
  }
  return 0;
}
int32_t BrpcPsService::StopServer(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  // Stops the server asynchronously: Stop() is run on a detached thread so
  // this RPC can reply before the brpc server tears down.
  // NOTE(review): the detached thread captures the raw _server pointer --
  // assumes the server object outlives the stop sequence; confirm.
  auto *p_server = _server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    VLOG(3) << "Server Stoped";
  });
  t_stop.detach();
  return 0;
}
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // Stops the platform profiler and dumps results to a per-rank file.
  // NOTE(review): "%s" is used with the (presumably integral) _rank --
  // relies on Sprintf accepting any type for %s; confirm.
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  // Starts CPU-state profiling; paired with StopProfiler above.
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Pushes the trainer's global step counter(s) to the table.
  CHECK_TABLE_EXIST(table, request, response);
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  // The first 4 bytes of the payload are skipped before the int64 step
  // values. NOTE(review): presumably a count prefix like the other push
  // paths, but it is never read here -- confirm. Also note context.num is
  // left unset, unlike the dense/sparse push handlers.
  const int64_t *values =
      (const int64_t *)(request.data().data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();
  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;
  // if (table->PushDense(values, trainer_id) != 0) {
  if (table->Push(context) != 0) {
    set_response_code(response, -1, "run_program failed");
  }
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace brpc {
class Controller;
} // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
class PsRequestMessage;
class PsResponseMessage;
class Table;
// brpc-based parameter-server implementation: owns the brpc::Server
// instance, the request-handling service, and the channels used for
// pserver-to-pserver communication.
class BrpcPsServer : public PSServer {
 public:
  BrpcPsServer() {}
  virtual ~BrpcPsServer() {}
  // Starts the brpc server on ip:port. NOTE(review): presumably blocks on
  // cv_ until Stop() is called -- confirm against the .cc definition.
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Signals the start thread, then stops and joins the brpc server
  // (1000 ms graceful-stop timeout).
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    stoped_ = true;
    cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  int32_t Port();
  // Opens the server-to-server channels used during cache shuffle.
  int32_t StartS2S() override;
  ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string &msg) override;
  int32_t ReceiveFromPServer(int msg_type,
                             int pserver_id,
                             const std::string &msg) override;
 private:
  virtual int32_t Initialize();
  // mutex_/cv_/stoped_ coordinate Start() with Stop().
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer pserver for S2S messaging.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class BrpcPsService;
// Pointer-to-member type for the table-command handlers that service()
// looks up in _service_handler_map and invokes by cmd_id.
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,  // NOLINT
    brpc::Controller *cntl);
// The brpc service implementation: service() dispatches each incoming
// PsRequestMessage to one of the private per-command handlers below,
// keyed by cmd_id via _service_handler_map.
class BrpcPsService : public PsBaseService {
 public:
  int32_t Initialize() override;
  // Entry point invoked by brpc for every request.
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;
 private:
  int32_t InitializeShardInfo();
  // ---- dense/sparse pull & push handlers ----
  int32_t PullDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,  // NOLINT
                    brpc::Controller *cntl);
  int32_t PushDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,  // NOLINT
                    brpc::Controller *cntl);
  int32_t PushDenseParam(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t PushSparseParam(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,  // NOLINT
                          brpc::Controller *cntl);
  int32_t PullSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  int32_t PullGeoParam(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,  // NOLINT
                  brpc::Controller *cntl);
  int32_t PushSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  // ---- table lifecycle: load / save / shrink / clear ----
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t ShrinkTable(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,  // NOLINT
                      brpc::Controller *cntl);
  int32_t ClearOneTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  int32_t ClearAllTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  // ---- server control & profiling ----
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,  // NOLINT
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,  // NOLINT
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t PushGlobalStep(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  // ---- cache shuffle / cache save ----
  int32_t CacheShuffle(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,  // NOLINT
                       brpc::Controller *cntl);
  int32_t SaveCacheTable(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,  // NOLINT
                         brpc::Controller *cntl);
  int32_t GetCacheThreshold(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,  // NOLINT
                            brpc::Controller *cntl);
  int32_t Revert(Table *table,
                 const PsRequestMessage &request,
                 PsResponseMessage &response,  // NOLINT
                 brpc::Controller *cntl);
  int32_t CheckSavePrePatchDone(Table *table,
                                const PsRequestMessage &request,
                                PsResponseMessage &response,  // NOLINT
                                brpc::Controller *cntl);
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  // cmd_id -> handler dispatch tables, filled in Initialize().
  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
};
// Closure that aggregates `num` parallel brpc sub-requests: the user
// callback fires (and the closure deletes itself) once every sub-request
// has invoked Run(). Must therefore always be heap-allocated.
class DownpourPServerBrpcClosure : public PServerClosure {
 public:
  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
      : PServerClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourPServerBrpcClosure() {}
  // Called once per finished sub-request; the last caller (atomic
  // fetch_sub reaches 1) runs the callback and self-deletes.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request's message/controller slots.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // NOTE(review): both checks unconditionally report success (1) --
  // confirm whether real response validation was intended here.
  int check_response(size_t request_idx, int cmd_id) { return 1; }
  int check_save_response(size_t request_idx, int cmd_id) { return 1; }
 private:
  std::atomic<int32_t> _waiting_num;
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"

#include <arpa/inet.h>
#include <netdb.h>

#include <cstdint>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace distributed {
// Maps a wire-level VariableMessage::Type to the framework's VarType enum.
// Throws InvalidArgument for any type not listed (only the five scalar
// dtypes below are supported).
framework::proto::VarType::Type VarMessageToVarType(
    VariableMessage::Type type) {
  switch (type) {
    case VariableMessage::FP32:
      return framework::proto::VarType::FP32;  // NOLINT
    case VariableMessage::FP64:
      return framework::proto::VarType::FP64;  // NOLINT
    case VariableMessage::INT32:
      return framework::proto::VarType::INT32;  // NOLINT
    case VariableMessage::INT64:
      return framework::proto::VarType::INT64;  // NOLINT
    case VariableMessage::BOOL:
      return framework::proto::VarType::BOOL;  // NOLINT
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "VarMessageToVarType:Unsupported type %d", type));
  }
}
// Serializes the named variables from `scope` into a MultiVarMsg (metadata)
// plus an IOBuf (raw tensor bytes, concatenated in send-var order).
// Variables that are neither LoDTensor nor SelectedRows are silently
// skipped (their VarMsg is still appended, but empty).
void SerializeToMultiVarMsgAndIOBuf(
    const std::string& message_name,
    const std::vector<std::string>& send_var_name_val,
    const std::vector<std::string>& recv_var_name_val,
    const platform::DeviceContext& ctx,
    const framework::Scope* scope,
    MultiVarMsg* request,
    butil::IOBuf* iobuf) {
  // 1. message_name
  request->set_message_name(message_name);
  // 2. var_names
  for (auto& send_var_name : send_var_name_val) {
    request->add_send_var_names(send_var_name);
  }
  for (auto& recv_var_name : recv_var_name_val) {
    request->add_recv_var_names(recv_var_name);
  }
  // 3. VarMessage: one VarMsg + one payload segment per send variable.
  for (auto& send_var_name : send_var_name_val) {
    auto* send_var_msg = request->add_var_messages();
    butil::IOBuf temp_iobuf;
    send_var_msg->set_varname(send_var_name);
    framework::Variable* var = scope->FindVar(send_var_name);
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    iobuf->append(temp_iobuf);
  }
}
// Serializes a LoDTensor: lod/dims/dtype go into `var_msg`; the raw bytes,
// prefixed by an 8-byte length, are appended to `iobuf`.
void SerializeLodTensor(framework::Variable* var,
                        const platform::DeviceContext& ctx,
                        VarMsg* var_msg,
                        butil::IOBuf* iobuf) {
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  var_msg->set_type(::paddle::distributed::LOD_TENSOR);
  const framework::LoD lod = tensor->lod();
  if (lod.size() > 0) {
    var_msg->set_lod_level(lod.size());
    for (auto& each : lod) {
      VarMsg::LodData* lod_inner = var_msg->add_lod();
      for (auto& d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto& dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer: the length prefix is hard-coded to 8 bytes -- the matching
  // Deserialize* readers copy exactly 8 bytes back out.
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Stage device memory through a host buffer before appending.
    // NOTE(review): nothing is appended when the tensor is on a non-CPU
    // place and CUDA is not compiled in -- confirm that path is unreachable.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Serializes a SelectedRows variable: the row indices go into
// VarMsg::data, the value tensor's raw bytes (prefixed by an 8-byte
// length, matching the Deserialize* readers) go into `iobuf`.
void SerializeSelectedRows(framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           VarMsg* var_msg,
                           butil::IOBuf* iobuf) {
  phi::SelectedRows* slr = var->GetMutable<phi::SelectedRows>();
  auto* tensor = slr->mutable_value();
  auto* rows = slr->mutable_rows();
  var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
  var_msg->set_slr_height(slr->height());
  auto* var_data = var_msg->mutable_data();
  var_data->clear();
  var_data->resize(rows->size() * sizeof(int64_t));
  char* data_ptr = const_cast<char*>(var_data->data());
  // Bug fix: guard the copy -- evaluating &((*rows)[0]) on an empty rows
  // vector is undefined behavior.
  if (!rows->empty()) {
    memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto& dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer: 8-byte length prefix followed by the tensor bytes.
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Stage device memory through a host buffer before appending.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Deserializes each VarMsg + payload segment into `scope`, creating the
// variables if they do not exist (mutable-Scope overload).
// NOTE(review): the loop bound uses send_var_names_size() while reading
// per-variable messages -- assumes one var_message per send var; confirm.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        framework::Scope* scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
       ++recv_var_index) {
    const auto& msg = multi_msg.var_messages(recv_var_index);
    auto* var = scope->Var(msg.varname());
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Const-Scope overload: identical to the overload above except that every
// target variable must already exist in `scope` (enforced with
// PADDLE_ENFORCE_NE instead of being created on demand).
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        const framework::Scope* scope) {
  butil::IOBufBytesIterator io_buffer_itr(*iobuf);
  // size_t shard_buffer_remain = res_io_buffer.size();
  for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
       ++recv_var_index) {
    const auto& msg = multi_msg.var_messages(recv_var_index);
    auto* var = scope->FindVar(msg.varname());
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable %s in scope.", msg.varname()));
    if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
    } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
    }
  }
}
// Rebuilds a LoDTensor from its VarMsg metadata (dims, lod, dtype) and the
// raw payload in the IOBuf: an 8-byte length prefix followed by the tensor
// bytes (matching SerializeLodTensor).
void DeserializeLodTensor(framework::Variable* var,
                          const VarMsg& msg,
                          butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
                          const platform::DeviceContext& ctx) {
  const auto place = ctx.GetPlace();
  framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int> vec_dim;
  for (auto& x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  framework::LoD lod;
  for (int i = 0; i < msg.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
      v.push_back(msg.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);
  void* tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer: the wire length prefix is exactly 8 bytes, so read it into a
  // fixed-width uint64_t. Bug fix: the previous `unsigned long` is only
  // 4 bytes on some platforms (and was left uninitialized), which would
  // leave half the length garbage after the 8-byte copy.
  if (platform::is_cpu_place(place)) {
    uint64_t data_len = 0;
    io_buffer_itr.copy_and_forward(&data_len, sizeof(data_len));
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    uint64_t data_len = 0;
    // Stage the payload in a host buffer, then copy to the device.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    io_buffer_itr.copy_and_forward(&data_len, sizeof(data_len));
    io_buffer_itr.copy_and_forward(temp_ptr, data_len);
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 temp_ptr,
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
void DeserializeSelectedRows(
    framework::Variable* var,
    const VarMsg& msg,
    butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
    const platform::DeviceContext& ctx) {
  // Rebuilds a SelectedRows in `var`: the row indices travel inside
  // `msg.data()` (dims()[0] int64 entries), while the dense value tensor's
  // bytes follow in `io_buffer_itr` behind an 8-byte length prefix.
  const auto place = ctx.GetPlace();
  auto* slr = var->GetMutable<phi::SelectedRows>();
  framework::Tensor* tensor = slr->mutable_value();
  slr->set_height(msg.slr_height());
  std::vector<int64_t> tmp_rows(msg.dims()[0]);
  memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t));
  slr->set_rows(tmp_rows);
  std::vector<int> vec_dim;
  vec_dim.reserve(msg.dims().size());
  for (auto& x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  void* tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    // Initialize in case the iterator is exhausted and copies < 8 bytes.
    unsigned long data_len = 0;  // NOLINT
    io_buffer_itr.copy_and_forward(&data_len, 8);
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    const size_t tensor_bytes =
        tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // RAII staging buffer: released on every path. The original raw
    // `new char[]` leaked if copy_and_forward or memory::Copy threw.
    std::vector<char> staging(tensor_bytes);
    unsigned long data_len = 0;  // NOLINT
    io_buffer_itr.copy_and_forward(&data_len, 8);
    io_buffer_itr.copy_and_forward(staging.data(), data_len);
    auto stream = reinterpret_cast<const phi::GPUContext&>(ctx).stream();
    // NOTE(review): copy is issued on `stream`; releasing `staging` at scope
    // exit matches the original delete[] timing — confirm memory::Copy
    // consumes pageable host memory synchronously.
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 staging.data(),
                 tensor_bytes,
                 stream);
#endif
  }
}
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port) {
// There are usually two forms of IP address: ip(int) / ip (hostname)
// If there're some problem with DNS, or ip triggers the bug of Brpc
// We will try to get the IP address of the domain name manually again
std::string ip_port = ip + ":" + std::to_string(port);
struct hostent* hp = NULL;
hp = gethostbyname(ip.c_str());
if (NULL == hp) {
LOG(ERROR) << "Brpc Start failed, ip_port= " << ip_port
<< " , Error infomation: " << hstrerror(h_errno);
}
int i = 0;
char* int_ip = NULL;
while (hp->h_addr_list[i] != NULL) {
int_ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
VLOG(3) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
break;
}
std::string str_ip = int_ip;
std::string int_ip_port = str_ip + ":" + std::to_string(port);
return int_ip_port;
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <netdb.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/phi/backends/dynload/port.h"
namespace butil {
class IOBuf;
class IOBufBytesIterator;
} // namespace butil
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
// Protobuf message aliases used throughout the brpc send/recv path.
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
// Serialize the named variables from `scope` into `var_msg` (metadata) and
// `iobuf` (raw tensor bytes) for transmission over brpc.
void SerializeToMultiVarMsgAndIOBuf(
    const std::string& message_name,
    const std::vector<std::string>& send_var_name_val,
    const std::vector<std::string>& recv_var_name_val,
    const platform::DeviceContext& ctx,
    const framework::Scope* scope,
    MultiVarMsg* var_msg,
    butil::IOBuf* iobuf);
// Serialize one LoDTensor variable: metadata into `var_msg`, bytes into
// `iobuf`.
void SerializeLodTensor(framework::Variable* var,
                        const platform::DeviceContext& ctx,
                        VarMsg* var_msg,
                        butil::IOBuf* iobuf);
// Serialize one SelectedRows variable: metadata/rows into `request`, dense
// value bytes into `iobuf`.
void SerializeSelectedRows(framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           VarMsg* request,
                           butil::IOBuf* iobuf);
// Deserialize for Server
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        framework::Scope* scope);
// Deserialize for Client
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        const framework::Scope* scope);
// Reconstruct a LoDTensor in `var` from `msg` metadata and the raw bytes
// behind an 8-byte length prefix in `iobuf`.
void DeserializeLodTensor(framework::Variable* var,
                          const VarMsg& msg,
                          butil::IOBufBytesIterator& iobuf,  // NOLINT
                          const platform::DeviceContext& ctx);
// Reconstruct a SelectedRows in `var` from `msg` metadata and the raw bytes
// behind an 8-byte length prefix in `iobuf`.
void DeserializeSelectedRows(framework::Variable* var,
                             const VarMsg& msg,
                             butil::IOBufBytesIterator& iobuf,  // NOLINT
                             const platform::DeviceContext& ctx);
// Resolve a hostname-form "ip" to a dotted-quad form and return "ip:port".
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port);
} // namespace distributed
} // namespace paddle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment