Commit d2d32668 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.3.0-dtk-22.04.2

parent ad08b8ce
Pipeline #226 failed with stages
in 0 seconds
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/global.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
class FakeInterceptor : public Interceptor {
public:
FakeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
step_ = 0;
RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); });
}
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
std::cout << "FakeInterceptor run in scope " << msg.scope_idx()
<< std::endl;
InterceptorMessage reply;
reply.set_message_type(DATA_IS_USELESS);
Send(SOURCE_ID, reply);
step_++;
if (step_ == node_->max_run_times()) {
carrier_->WakeUp();
}
}
}
private:
int64_t step_;
};
// Integration test: a Source interceptor drives a FakeInterceptor.
// The carrier hosts two interceptors (SOURCE_ID and 0); sending START to the
// source should trigger max_run_times (3) DATA_IS_READY -> DATA_IS_USELESS
// round trips, after which FakeInterceptor wakes the carrier.
TEST(SourceInterceptor, Source) {
  std::string carrier_id = "0";
  Carrier* carrier =
      GlobalMap<std::string, Carrier>::Create(carrier_id, carrier_id);
  // Both interceptor ids map to rank 0 (single-process test).
  carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}});
  MessageBus* msg_bus = GlobalVal<MessageBus>::Create();
  msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
  // NOTE: don't delete, otherwise interceptor will use undefined node
  // NOTE(review): the ctor takes 5 args but the trailing comment names only
  // 3; presumably (role, rank, task_id, max_run_times, max_slot) -- confirm
  // against TaskNode's constructor.
  TaskNode* source =
      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
  // Wire source -> node_a (buffer size 1 in each direction).
  source->AddDownstreamTask(0, 1);
  node_a->AddUpstreamTask(SOURCE_ID, 1);
  carrier->SetInterceptor(
      SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source));
  carrier->SetInterceptor(0, std::make_unique<FakeInterceptor>(0, node_a));
  // start: kick the source; Wait() blocks until FakeInterceptor calls WakeUp.
  InterceptorMessage msg;
  msg.set_message_type(START);
  msg.set_dst_id(SOURCE_ID);
  carrier->EnqueueInterceptorMessage(msg);
  carrier->Wait();
  carrier->Release();
}
} // namespace distributed
} // namespace paddle
# Protobuf messages for the tree index (IndexNode / TreeMeta / KVItem).
proto_library(index_dataset_proto SRCS index_dataset.proto)
# Loads and serves tree indexes (TreeIndex / IndexWrapper).
cc_library(
  index_wrapper
  SRCS index_wrapper.cc
  DEPS index_dataset_proto fs)
# Layer-wise negative sampling over a tree index; the MKLDNN build links the
# extra mkldnn dependency, otherwise the target is identical.
if(WITH_MKLDNN)
  cc_library(
    index_sampler
    SRCS index_sampler.cc
    DEPS xxhash index_wrapper eigen3 mkldnn)
else()
  cc_library(
    index_sampler
    SRCS index_sampler.cc
    DEPS xxhash index_wrapper eigen3)
endif()
# Python bindings need the generated _pb2 module as well.
if(WITH_PYTHON)
  py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Wire format for serialized tree indexes (consumed by TreeIndex::Load).
syntax = "proto2";
package paddle.distributed;
// One node of the tree index.
message IndexNode {
  required uint64 id = 1;           // node/item id
  required bool is_leaf = 2;        // true for leaf (real item) nodes
  required float probability = 3;   // sampling probability of this node
  optional string item_name = 4;    // optional human-readable name
}
// Global tree shape stored under the ".tree_meta" key.
message TreeMeta {
  required int32 height = 1;  // number of levels in the tree
  required int32 branch = 2;  // branching factor
}
// Generic key/value record; the index file is a sequence of length-prefixed
// KVItem messages (key ".tree_meta" -> TreeMeta, otherwise key = node code,
// value = IndexNode).
message KVItem {
  required bytes key = 1;
  required bytes value = 2;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace distributed {
// Layer-wise negative sampling (TDM-style).  For every target item, walk its
// tree path from start_sample_layer_ up to the root; at each layer emit one
// positive row (the path node, label 1) followed by layer_counts_[j] negative
// rows (nodes uniformly sampled from the same layer, label 0).
// Each output row is laid out as: [user features..., item id, label].
// NOTE(review): assumes every row of user_inputs has the same feature count
// as user_inputs[0] -- confirm with callers.
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
    const std::vector<std::vector<uint64_t>>& user_inputs,
    const std::vector<uint64_t>& target_ids,
    bool with_hierarchy) {
  auto input_num = target_ids.size();
  auto user_feature_num = user_inputs[0].size();
  // layer_counts_sum_ == sum_j(layer_counts_[j] + 1): one positive plus the
  // negatives per visited layer, for every input.
  std::vector<std::vector<uint64_t>> outputs(
      input_num * layer_counts_sum_,
      std::vector<uint64_t>(user_feature_num + 2));
  auto max_layer = tree_->Height();
  size_t idx = 0;  // next free row in outputs
  for (size_t i = 0; i < input_num; i++) {
    // Codes of the target's ancestors from the deepest layer upward.
    auto travel_codes =
        tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
    auto travel_path = tree_->GetNodes(travel_codes);
    for (size_t j = 0; j < travel_path.size(); j++) {
      // user
      // Fill the user-feature columns of the positive row at `idx` plus the
      // layer_counts_[j] negative rows after it (hence `<=`).
      if (j > 0 && with_hierarchy) {
        // Replace raw user items by their ancestors at the current layer.
        auto ancestor_codes =
            tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
        auto hierarchical_user = tree_->GetNodes(ancestor_codes);
        for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = hierarchical_user[k].id();
          }
        }
      } else {
        for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
          for (size_t k = 0; k < user_feature_num; k++) {
            outputs[idx + idx_offset][k] = user_inputs[i][k];
          }
        }
      }
      // sampler ++
      // Positive sample: path node id, label 1 (stored in a uint64_t, so the
      // 1.0 literal truncates to 1).
      outputs[idx][user_feature_num] = travel_path[j].id();
      outputs[idx][user_feature_num + 1] = 1.0;
      idx += 1;
      // Negative samples: rejection-sample layer nodes until one differs from
      // the positive.  NOTE(review): this loops forever if the layer holds
      // only the positive node -- presumably branch >= 2 guarantees more;
      // confirm.
      for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
        int sample_res = 0;
        do {
          sample_res = sampler_vec_[j]->Sample();
        } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
        outputs[idx + idx_offset][user_feature_num] =
            layer_ids_[j][sample_res].id();
        outputs[idx + idx_offset][user_feature_num + 1] = 0;
      }
      idx += layer_counts_[j];
    }
  }
  return outputs;
}
// Expands dataset records with layer-wise positives/negatives.  For every
// record whose uint64 feasigns contain slot `sample_slot`, the target id in
// that slot is replaced by each node on its tree path (positives) and by
// sampled same-layer nodes with the following slot zeroed (negatives).
// Records without the slot contribute nothing to sample_results.
void LayerWiseSampler::sample_from_dataset(
    const uint16_t sample_slot,
    std::vector<paddle::framework::Record>* src_datas,
    std::vector<paddle::framework::Record>* sample_results) {
  sample_results->clear();
  for (auto& data : *src_datas) {
    VLOG(1) << "src data size = " << src_datas->size();
    VLOG(1) << "float data size = " << data.float_feasigns_.size();
    // data.Print();
    uint64_t start_idx = sample_results->size();
    VLOG(1) << "before sample, sample_results.size = " << start_idx;
    // Index of the first feasign whose slot matches sample_slot; the -1
    // sentinel wraps to UINT64_MAX and is only read when sample_sign is true.
    uint64_t sample_feasign_idx = -1;
    bool sample_sign = false;
    for (unsigned int i = 0; i < data.uint64_feasigns_.size(); i++) {
      VLOG(1) << "slot" << i << " = " << data.uint64_feasigns_[i].slot();
      if (data.uint64_feasigns_[i].slot() == sample_slot) {
        sample_sign = true;
        sample_feasign_idx = i;
      }
      // Stop at the first match.
      if (sample_sign) break;
    }
    VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx;
    if (sample_sign) {
      auto target_id =
          data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_;
      // Ancestors of the target from the deepest layer up to
      // start_sample_layer_.
      auto travel_codes = tree_->GetTravelCodes(target_id, start_sample_layer_);
      auto travel_path = tree_->GetNodes(travel_codes);
      for (unsigned int j = 0; j < travel_path.size(); j++) {
        // Positive: clone the record, substitute the path-node id.
        paddle::framework::Record instance(data);
        instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
            travel_path[j].id();
        sample_results->push_back(instance);
        // Negatives: rejection-sample layer nodes differing from the positive.
        for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
          int sample_res = 0;
          do {
            sample_res = sampler_vec_[j]->Sample();
          } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
          paddle::framework::Record instance(data);
          instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
              layer_ids_[j][sample_res].id();
          VLOG(1) << "layer id :" << layer_ids_[j][sample_res].id();
          // sample_feasign_idx + 1 == label's id
          // NOTE(review): assumes the slot right after the sample slot holds
          // the label -- confirm with the dataset layout.
          instance.uint64_feasigns_[sample_feasign_idx + 1]
              .sign()
              .uint64_feasign_ = 0;
          sample_results->push_back(instance);
        }
        VLOG(1) << "layer end!!!!!!!!!!!!!!!!!!";
      }
    }
  }
  VLOG(1) << "after sample, sample_results.size = " << sample_results->size();
  return;
}
// Truncates each double in `tmp` toward zero and returns the values as
// uint64_t.
// NOTE(review): converting a negative (or >= 2^64) double to uint64_t is
// undefined behavior; callers presumably pass small non-negative values --
// confirm.
std::vector<uint64_t> float2int(std::vector<double> tmp) {
  std::vector<uint64_t> tmp_int;
  tmp_int.reserve(tmp.size());  // single allocation instead of regrowth
  for (auto i : tmp) tmp_int.push_back(static_cast<uint64_t>(i));
  return tmp_int;
}
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// Abstract interface for negative samplers built on top of a tree index.
// Concrete samplers are created by name via the Init<T> factory; the name is
// the tree registered with IndexWrapper.
class IndexSampler {
 public:
  virtual ~IndexSampler() {}
  IndexSampler() {}

  // Factory: constructs a T(name) and returns it as the base interface.
  template <typename T>
  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
    std::shared_ptr<IndexSampler> instance = nullptr;
    instance.reset(new T(name));
    return instance;
  }

  // Configure layer-wise sampling; no-op in the base class.
  // layer_sample_counts: negatives per layer (deep to shallow);
  // start_sample_layer: first (shallowest) layer to sample from;
  // seed: RNG seed for the per-layer uniform samplers.
  virtual void init_layerwise_conf(
      const std::vector<uint16_t>& layer_sample_counts,
      uint16_t start_sample_layer = 1,
      uint16_t seed = 0) {}
  // Configure beam-search sampling; no-op in the base class.
  virtual void init_beamsearch_conf(const int64_t k) {}
  // Dense sampling: returns rows of [user features..., item id, label].
  virtual std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& input_targets,
      bool with_hierarchy = false) = 0;
  // Record-based sampling: expands src_datas into sample_results.
  virtual void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) = 0;
};
// TDM-style layer-wise sampler over a named tree index.  For each target it
// emits one positive plus a configurable number of uniformly-sampled
// negatives per tree layer (see the .cc for sample()/sample_from_dataset()).
class LayerWiseSampler : public IndexSampler {
 public:
  virtual ~LayerWiseSampler() {}
  // `name` must already be registered with IndexWrapper::insert_tree_index.
  explicit LayerWiseSampler(const std::string& name) {
    tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
  }

  // Validates the sampling range, derives the per-layer negative counts, and
  // builds one uniform sampler plus the node list for every sampled layer.
  void init_layerwise_conf(const std::vector<uint16_t>& layer_sample_counts,
                           uint16_t start_sample_layer,
                           uint16_t seed) override {
    seed_ = seed;
    start_sample_layer_ = start_sample_layer;

    // Sampling must start strictly inside the tree (above the root, below
    // the deepest layer).
    PADDLE_ENFORCE_GT(
        start_sample_layer_,
        0,
        paddle::platform::errors::InvalidArgument(
            "start sampler layer = [%d], it should greater than 0.",
            start_sample_layer_));
    PADDLE_ENFORCE_LT(start_sample_layer_,
                      tree_->Height(),
                      paddle::platform::errors::InvalidArgument(
                          "start sampler layer = [%d], it should less than "
                          "max_layer, which is [%d].",
                          start_sample_layer_,
                          tree_->Height()));

    size_t i = 0;
    layer_counts_sum_ = 0;
    layer_counts_.clear();
    int cur_layer = start_sample_layer_;
    while (cur_layer < tree_->Height()) {
      // Layers beyond the configured list default to 1 negative.
      int layer_sample_num = 1;
      if (i < layer_sample_counts.size()) {
        layer_sample_num = layer_sample_counts[i];
      }
      // +1 accounts for the positive row emitted alongside the negatives.
      layer_counts_sum_ += layer_sample_num + 1;
      layer_counts_.push_back(layer_sample_num);
      VLOG(3) << "[INFO] level " << cur_layer
              << " sample_layer_counts.push_back: " << layer_sample_num;
      cur_layer += 1;
      i += 1;
    }
    // sample() walks paths deepest-layer-first, so flip to match.
    reverse(layer_counts_.begin(), layer_counts_.end());
    VLOG(3) << "sample counts sum: " << layer_counts_sum_;

    // One uniform sampler + node list per layer, deepest first (index j in
    // sample() corresponds to layer max_layer - 1 - j).
    auto max_layer = tree_->Height();
    sampler_vec_.clear();
    layer_ids_.clear();
    auto layer_index = max_layer - 1;
    size_t idx = 0;
    while (layer_index >= start_sample_layer_) {
      auto layer_codes = tree_->GetLayerCodes(layer_index);
      layer_ids_.push_back(tree_->GetNodes(layer_codes));
      auto sampler_temp =
          std::make_shared<paddle::operators::math::UniformSampler>(
              layer_ids_[idx].size() - 1, seed_);
      sampler_vec_.push_back(sampler_temp);
      layer_index--;
      idx++;
    }
  }

  std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& target_ids,
      bool with_hierarchy) override;

  void sample_from_dataset(
      const uint16_t sample_slot,
      std::vector<paddle::framework::Record>* src_datas,
      std::vector<paddle::framework::Record>* sample_results) override;

 private:
  std::vector<int> layer_counts_;  // negatives per layer, deepest first
  int64_t layer_counts_sum_{0};    // rows emitted per input
  std::shared_ptr<TreeIndex> tree_{nullptr};
  int seed_{0};
  int start_sample_layer_{1};
  // Per-layer uniform samplers and node lists, deepest layer first.
  std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
  std::vector<std::vector<IndexNode>> layer_ids_;
};
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace distributed {
// Storage for the IndexWrapper singleton; lazily created by
// GetInstance()/GetInstancePtr().
std::shared_ptr<IndexWrapper> IndexWrapper::s_instance_(nullptr);
// Loads a tree index from `filename`.  The file is a sequence of
// [int32 length][serialized KVItem] records; the ".tree_meta" key carries the
// TreeMeta, every other key is a node code (decimal string) whose value is an
// IndexNode.  Populates data_, id_codes_map_ (leaves only), meta_, max_id_,
// max_code_ and total_nodes_num_.  Returns 0 on success; enforcement macros
// throw on malformed input.
int TreeIndex::Load(const std::string filename) {
  int err_no;
  auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
  PADDLE_ENFORCE_NE(
      fp,
      nullptr,
      platform::errors::InvalidArgument(
          "Open file %s failed. Please check whether the file exists.",
          filename));

  int num = 0;
  max_id_ = 0;
  // Sentinel node returned by GetNodes() for invalid codes.
  fake_node_.set_id(0);
  fake_node_.set_is_leaf(false);
  fake_node_.set_probability(0.0);
  max_code_ = 0;
  size_t ret = fread(&num, sizeof(num), 1, fp.get());
  while (ret == 1 && num > 0) {
    // Read exactly `num` bytes of serialized KVItem payload.
    std::string content(num, '\0');
    size_t read_num =
        fread(const_cast<char*>(content.data()), 1, num, fp.get());
    PADDLE_ENFORCE_EQ(
        read_num,
        static_cast<size_t>(num),
        platform::errors::InvalidArgument(
            "Read from file: %s failed. Valid Format is "
            "an integer representing the length of the following string, "
            "and the string itself.We got an iteger[% d], "
            "but the following string's length is [%d].",
            filename,
            num,
            read_num));

    KVItem item;
    PADDLE_ENFORCE_EQ(
        item.ParseFromString(content),
        true,
        platform::errors::InvalidArgument("Parse from file: %s failed. It's "
                                          "content can't be parsed by KVItem.",
                                          filename));

    if (item.key() == ".tree_meta") {
      meta_.ParseFromString(item.value());
    } else {
      // Key is the node's code within the (conceptually complete) tree.
      auto code = std::stoull(item.key());
      IndexNode node;
      node.ParseFromString(item.value());
      // PADDLE_ENFORCE_NE(node.id(), 0,
      //                   platform::errors::InvalidArgument(
      //                       "Node'id should not be equel to zero."));
      if (node.is_leaf()) {
        id_codes_map_[node.id()] = code;
      }
      data_[code] = node;
      if (node.id() > max_id_) {
        max_id_ = node.id();
      }
      if (code > max_code_) {
        max_code_ = code;
      }
    }
    // Next record's length prefix (ret != 1 at EOF).
    ret = fread(&num, sizeof(num), 1, fp.get());
  }
  total_nodes_num_ = data_.size();
  // max_code_ now points one past the largest real code; used as the "not
  // found" sentinel code by GetAncestorCodes().
  max_code_ += 1;
  return 0;
}
std::vector<IndexNode> TreeIndex::GetNodes(const std::vector<uint64_t>& codes) {
std::vector<IndexNode> nodes;
nodes.reserve(codes.size());
for (size_t i = 0; i < codes.size(); i++) {
if (CheckIsValid(codes[i])) {
nodes.push_back(data_.at(codes[i]));
} else {
nodes.push_back(fake_node_);
}
}
return nodes;
}
// Returns the codes of all existing nodes at `level`.  In a complete tree,
// level l occupies codes [branch^l - 1, 2 * branch^l - 1); holes (missing
// nodes) are filtered out.
std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
  const uint64_t level_num =
      static_cast<uint64_t>(std::pow(meta_.branch(), level));
  const uint64_t first_code = level_num - 1;
  std::vector<uint64_t> codes;
  codes.reserve(level_num);
  for (uint64_t offset = 0; offset < level_num; ++offset) {
    const uint64_t candidate = first_code + offset;
    if (CheckIsValid(candidate)) {
      codes.push_back(candidate);
    }
  }
  return codes;
}
// Maps each leaf id to the code of its ancestor at `level`.  Unknown ids map
// to max_code_ (an intentionally-invalid sentinel code).
// Fixes: single map lookup instead of find()+at(), and cur_level is now
// declared (initialized) at first use instead of uninitialized at function
// scope.
std::vector<uint64_t> TreeIndex::GetAncestorCodes(
    const std::vector<uint64_t>& ids, int level) {
  std::vector<uint64_t> res;
  res.reserve(ids.size());
  for (size_t i = 0; i < ids.size(); i++) {
    auto it = id_codes_map_.find(ids[i]);
    if (it == id_codes_map_.end()) {
      // Unknown id: emit the sentinel code (resolves to fake_node_).
      res.push_back(it == id_codes_map_.end() ? max_code_ : it->second);
    } else {
      // Climb from the deepest layer until reaching `level`.
      auto code = it->second;
      int cur_level = meta_.height() - 1;
      while (level >= 0 && cur_level > level) {
        code = (code - 1) / meta_.branch();
        cur_level--;
      }
      res.push_back(code);
    }
  }
  return res;
}
// Returns the codes of `ancestor`'s descendants that lie at `level`
// (level's code range is [branch^level - 1, branch^(level+1) - 1)).
// Performs a BFS from `ancestor`, appending existing children to `parent`
// until the frontier pointer p_idx reaches the target level's code range.
// NOTE(review): if no descendant of `ancestor` exists at `level` (sparse
// tree), p_idx can run past parent.size() and index out of bounds --
// presumably callers only pass levels that are populated; confirm.
std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
                                                  int level) {
  auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
  auto code_min = level_code_num - 1;
  auto code_max = meta_.branch() * level_code_num - 1;

  std::vector<uint64_t> parent;
  parent.push_back(ancestor);
  std::vector<uint64_t> res;
  size_t p_idx = 0;
  while (true) {
    size_t p_size = parent.size();
    // Expand the current frontier by one generation.
    for (; p_idx < p_size; p_idx++) {
      for (int i = 0; i < meta_.branch(); i++) {
        auto code = parent[p_idx] * meta_.branch() + i + 1;
        if (data_.find(code) != data_.end()) parent.push_back(code);
      }
    }
    // Stop once the next unexpanded code falls inside the target level.
    if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
      break;
    }
  }
  // Everything from p_idx onward belongs to the requested level.
  return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
}
// Returns the codes on the path of leaf `id`, from the deepest layer up to
// (and including) start_level.  Enforces that `id` is a known leaf.
std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
  std::vector<uint64_t> path;
  PADDLE_ENFORCE_NE(id_codes_map_.find(id),
                    id_codes_map_.end(),
                    paddle::platform::errors::InvalidArgument(
                        "id = %d doesn't exist in Tree.", id));
  uint64_t code = id_codes_map_.at(id);
  for (int level = meta_.height() - 1; level >= start_level; --level) {
    path.push_back(code);
    // Parent of code c is (c - 1) / branch.
    code = (code - 1) / meta_.branch();
  }
  return path;
}
std::vector<IndexNode> TreeIndex::GetAllLeafs() {
std::vector<IndexNode> res;
res.reserve(id_codes_map_.size());
for (auto& ite : id_codes_map_) {
auto code = ite.second;
res.push_back(data_.at(code));
}
return res;
}
} // end namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// Base type for index structures (TreeIndex derives from it).  Carries no
// state and no virtual interface.
// NOTE(review): the destructor is non-virtual; deleting a derived index
// through an Index* would be UB -- presumably indexes are only held by their
// concrete type (TreePtr); confirm before adding polymorphic use.
class Index {
 public:
  Index() = default;
  ~Index() = default;
};
// In-memory tree index loaded from a length-prefixed KVItem file (see
// Load()).  Nodes are addressed by "codes": the root is 0 and the children of
// code c are c*branch + 1 ... c*branch + branch, so in a complete tree level
// l occupies codes [branch^l - 1, branch^(l+1) - 1).
class TreeIndex : public Index {
 public:
  TreeIndex() {}
  ~TreeIndex() {}

  // Tree shape from the ".tree_meta" record.
  int Height() { return meta_.height(); }
  int Branch() { return meta_.branch(); }
  uint64_t TotalNodeNums() { return total_nodes_num_; }
  // One embedding row per possible id in [0, max_id_].
  uint64_t EmbSize() { return max_id_ + 1; }
  // Parses the index file; returns 0 on success (throws on malformed input).
  int Load(const std::string path);

  // True iff `code` maps to a real node.
  // Fix: the parameter was `int`, silently truncating codes above INT_MAX
  // even though codes are uint64_t everywhere else; take uint64_t.
  inline bool CheckIsValid(uint64_t code) {
    return data_.find(code) != data_.end();
  }

  // Resolve codes to nodes (invalid codes yield fake_node_).
  std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
  // Codes of all existing nodes at `level`.
  std::vector<uint64_t> GetLayerCodes(int level);
  // Ancestor code at `level` for each leaf id (max_code_ if id unknown).
  std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
                                         int level);
  // Descendant codes of `ancestor` at `level`.
  std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
  // Path codes of leaf `id` from the deepest layer up to start_level.
  std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
  // All leaf nodes.
  std::vector<IndexNode> GetAllLeafs();

  std::unordered_map<uint64_t, IndexNode> data_;      // code -> node
  std::unordered_map<uint64_t, uint64_t> id_codes_map_;  // leaf id -> code
  uint64_t total_nodes_num_;
  TreeMeta meta_;
  uint64_t max_id_;    // largest node id seen while loading
  uint64_t max_code_;  // one past the largest code; "not found" sentinel
  IndexNode fake_node_;  // placeholder returned for invalid codes
};
// Shared handle to a loaded tree index, as stored in IndexWrapper::tree_map.
using TreePtr = std::shared_ptr<TreeIndex>;
// Process-wide registry of named tree indexes.  Trees are loaded once via
// insert_tree_index() and looked up by samplers via get_tree_index().
// NOTE(review): the lazy singleton init below is not synchronized -- assumes
// first use happens on a single thread; confirm before concurrent use.
class IndexWrapper {
 public:
  virtual ~IndexWrapper() {}
  IndexWrapper() {}

  // Drops every registered tree.
  void clear_tree() { tree_map.clear(); }

  // Returns the tree registered under `name`; throws if it was never
  // inserted.
  TreePtr get_tree_index(const std::string name) {
    PADDLE_ENFORCE_NE(tree_map.find(name),
                      tree_map.end(),
                      paddle::platform::errors::InvalidArgument(
                          "tree [%s] doesn't exist. Please insert it firstly "
                          "by API[\' insert_tree_index \'].",
                          name));
    return tree_map[name];
  }

  // Loads the tree at `tree_path` and registers it under `name`.
  // Re-inserting an existing name is a logged no-op.
  void insert_tree_index(const std::string name, const std::string tree_path) {
    if (tree_map.find(name) != tree_map.end()) {
      VLOG(0) << "Tree " << name << " has already existed.";
      return;
    }
    TreePtr tree = std::make_shared<TreeIndex>();
    int ret = tree->Load(tree_path);
    PADDLE_ENFORCE_EQ(ret,
                      0,
                      paddle::platform::errors::InvalidArgument(
                          "Load tree[%s] from path[%s] failed. Please "
                          "check whether the file exists.",
                          name,
                          tree_path));
    tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
  }

  // Lazily-created singleton accessors (use nullptr, not NULL).
  static std::shared_ptr<IndexWrapper> GetInstancePtr() {
    if (nullptr == s_instance_) {
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_;
  }

  static IndexWrapper* GetInstance() {
    if (nullptr == s_instance_) {
      s_instance_.reset(new paddle::distributed::IndexWrapper());
    }
    return s_instance_.get();
  }

 private:
  static std::shared_ptr<IndexWrapper> s_instance_;
  std::unordered_map<std::string, TreePtr> tree_map;  // name -> loaded tree
};
} // end namespace distributed
} // end namespace paddle
# Export the RPC dependency set so sibling directories can link against the
# generated sendrecv service and its BRPC deps.
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
# Parameter-server components: storage tables, RPC services, and wrappers.
add_subdirectory(table)
add_subdirectory(service)
add_subdirectory(wrapper)
# 目录说明
Table: for param storage and update
-----MemorySparseTable: table for sparse param, used in cpu async mode
-----MemoryDenseTable: table for dense param, used in cpu async/geo mode
-----MemorySparseGeoTable: table for sparse param, used in cpu geo mode
-----CommonGraphTable: table used for graph learning
-----BarrierTable: table for barrier function, used in cpu sync mode
-----TensorTable: table which run program, used for learning rate decay only
ValueAccessor: for pull param and push gradient
-----CtrCommonAccessor: pull/push value with show/click, float type
-----CtrDoubleAccessor: same as CtrCommonAccessor, except that show/click are stored as double type
-----SparseAccessor: used for common embedding, pull value without show/click, push value with show/click
-----CommMergeAccessor: used for dense table only, for get param dim
PsService(proto): for server to handle request
-----PsBaseService
----------BrpcPsService: for cpu dnn training task
----------GraphBrpcService: for graph learning
-----HeterService: for dnn training task with heterogeneous computing resources
PSServer: recv request from trainer and handle it by service
-----BrpcPsServer: for cpu dnn training task
-----GraphBrpcServer: for graph learning
-----PsLocalServer: for GpuPS
HeterServer: for HeterPS
PSClient: pull param and push gradient for trainer
-----BrpcPsClient: for cpu dnn training task
----------GraphBrpcClient: for graph learning
-----PsLocalClient: for GpuPS
HeterClient: for HeterPS
PSCore: Wrapper for InitServer
GraphPyService: for graph learning
# Sources compiled into the generated sendrecv RPC library.
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
# BRPC link dependencies; HeterPS builds additionally need rocksdb.
# NOTE(review): gflags/glog appear twice in both lists -- harmless for
# linking but presumably unintentional; confirm before cleaning up.
if(WITH_HETERPS)
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context
      rocksdb)
else()
  set(BRPC_DEPS
      brpc
      ssl
      crypto
      protobuf
      gflags
      glog
      zlib
      leveldb
      snappy
      gflags
      glog
      device_context)
endif()
# Generates the sendrecv.proto stubs and builds them with the client/server
# sources into the sendrecv_rpc target.
brpc_library(
  sendrecv_rpc
  SRCS
  ${BRPC_SRCS}
  PROTO
  sendrecv.proto
  DEPS
  ${BRPC_DEPS})
#set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
# RPC_DEPS is exported by the parent directory's CMakeLists.
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
# All distributed sources need the distribute-specific compile flags.
set_source_files_properties(
  communicator/communicator.cc PROPERTIES COMPILE_FLAGS
  ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_service/service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS
                                      ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS
                                         ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS
                                      ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
  graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
# Shared BRPC helpers (serialization, endpoint utils).
cc_library(
  brpc_utils
  SRCS brpc_utils.cc
  DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
# Server side: BrpcPsServer + GraphBrpcServer.
cc_library(
  downpour_server
  SRCS graph_brpc_server.cc brpc_ps_server.cc
  DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
# Client side: BrpcPsClient, GraphBrpcClient and the local (GpuPS) client.
cc_library(
  downpour_client
  SRCS graph_brpc_client.cc brpc_ps_client.cc ps_local_client.cc
  DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
# Thin facades over the downpour client/server implementations.
cc_library(
  client
  SRCS ps_client.cc
  DEPS downpour_client boost ${RPC_DEPS})
cc_library(
  server
  SRCS server.cc
  DEPS downpour_server boost ${RPC_DEPS})
# Trainer-side communicator (async/geo push-pull loops).
cc_library(
  communicator
  SRCS communicator/communicator.cc
  DEPS scope
       client
       boost
       table
       math_function
       selected_rows_functor
       ${RPC_DEPS})
# PSCore entry points wrapping client + server initialization.
cc_library(
  ps_service
  SRCS ps_service/service.cc
  DEPS communicator client server boost ${RPC_DEPS})
# Heterogeneous-computing (HeterPS) client/server.
cc_library(
  heter_client
  SRCS heter_client.cc
  DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(
  heter_server
  SRCS heter_server.cc
  DEPS heter_client brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
# Graph-learning python service.
set_source_files_properties(
  ps_service/graph_py_service.cc PROPERTIES COMPILE_FLAGS
                                            ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
  graph_py_service
  SRCS ps_service/graph_py_service.cc
  DEPS ps_service)
#add_subdirectory(communicator)
# 目录说明
* PSServer
* PSClient
* PsService
* Communicator
* MessageBusFramework
* *.proto
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/archive.h"
// Upper bound of the port range probed when starting the client RPC service.
static const int max_port = 65535;

// Merge/limit knobs for locally batching push/pull requests.
DEFINE_int32(pserver_push_dense_merge_limit,
             12,
             "limit max push_dense local merge requests");

DEFINE_int32(pserver_push_sparse_merge_limit,
             12,
             "limit max push_sparse local merge requests");

// Fix: the help text was copy-pasted from the push_sparse flag; this flag
// limits pull_dense merges.
DEFINE_int32(pserver_pull_dense_limit,
             12,
             "limit max pull_dense local merge requests");

// Async push intervals (milliseconds).
DEFINE_int32(pserver_async_push_dense_interval_ms,
             10,
             "async push_dense to server interval");

DEFINE_int32(pserver_async_push_sparse_interval_ms,
             10,
             "async push_sparse to server interval");

DEFINE_bool(pserver_scale_gradient_by_merge,
            false,
            "scale dense gradient when merged");

DEFINE_int32(pserver_communicate_compress_type,
             0,
             "none:0 snappy:1 gzip:2 zlib:3 lz4:4");

DEFINE_int32(pserver_max_async_call_num,
             13,
             "max task num in async_call_server");

// RPC timeouts.
DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms");

DEFINE_int32(pserver_connect_timeout_ms,
             10000,
             "pserver connect server timeout_ms");

DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num");

DEFINE_int32(pserver_sparse_table_shard_num,
             1000,
             "sparse table shard for save & load");
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
// Maps a sparse key to the server-local shard index it belongs to.
// Shards are distributed contiguously: each server hosts
// ceil(shard_num / server_num) shards, and the key is first hashed into
// [0, shard_num) before being divided down to its hosting slot.
inline size_t get_sparse_shard(uint32_t shard_num,
                               uint32_t server_num,
                               uint64_t key) {
  const size_t shards_per_server = (shard_num % server_num == 0)
                                       ? (shard_num / server_num)
                                       : (shard_num / server_num + 1);
  return (key % shard_num) / shards_per_server;
}
// brpc handler for client-to-client messages: forwards the request payload to
// the owning client's registered msg handler and reports success/failure in
// the response's error fields.  done_guard guarantees the closure runs even
// if the handler throws nothing but returns early.
void DownpourPsClientService::service(
    ::google::protobuf::RpcController *controller,
    const PsRequestMessage *request,
    PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  int ret = _client->HandleClient2ClientMsg(
      request->cmd_id(), request->client_id(), request->data());
  // Default to success; overwrite on handler failure.
  response->set_err_code(0);
  response->set_err_msg("");
  if (ret != 0) {
    response->set_err_code(-1);
    response->set_err_msg("handle_client2client_msg failed");
  }
}
// 启动client端RpcService 用于数据互发等操作
// Starts the client-side RPC service used for client-to-client data exchange.
// Binds to the first free port in [8500, 65535] on the local IP, then
// registers the resulting endpoint with the PS environment so peers can find
// it.  Returns 0 on success, -1 on configure/start failure.
// (Original comment: 启动client端RpcService 用于数据互发等操作)
int32_t BrpcPsClient::StartClientService() {
  if (_service.Configure(this, _client_id) != 0) {
    LOG(ERROR)
        << "service initialize failed, service_name:DownpourPsClientService";
    return -1;
  }
  // The server must not delete _service (it is a member of this client).
  _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  int start_port = 8500;
  options.num_threads = 24;

  if (_server.Start(butil::my_ip_cstr(),
                    brpc::PortRange(start_port, max_port),
                    &options) != 0) {
    LOG(ERROR) << "BrpcPsServer start failed";
    return -1;
  }
  _server_started = true;
  // Publish ip:port so other clients can open channels to us.
  _env->RegistePsClient(
      butil::my_ip_cstr(), _server.listen_address().port, _client_id);
  return 0;
}
// Opens a pooled baidu_std channel to every registered PS client so this
// client can exchange messages with its peers.  If the hostname endpoint
// fails to init, retries with the numeric ip:port form.  Returns 0 when all
// channels are up, -1 on the first unrecoverable failure.
int32_t BrpcPsClient::CreateClient2ClientConnection(
    int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) {
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = pserver_connect_timeout_ms;
  options.max_retry = max_retry;

  std::vector<PSHost> client_list = _env->GetPsClients();
  VLOG(1) << "BrpcPsClient::create_c2c_connection client_list size: "
          << client_list.size();
  // Fix: iterate by const reference -- the old `auto cc` copied a PSHost per
  // iteration just to log it.
  for (const auto &cc : client_list) {
    VLOG(1) << "BrpcPsClient::create_c2c_connection client_list: "
            << cc.ToString();
  }
  _client_channels.resize(client_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < client_list.size(); ++i) {
    server_ip_port.assign(client_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(client_list[i].port));
    _client_channels[i].reset(new brpc::Channel());
    if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options)) {
      VLOG(0) << "BrpcPSClient connect to Client:" << server_ip_port
              << " Failed! Try again.";
      // Fall back to the numeric endpoint in case the hostname form failed.
      std::string int_ip_port =
          GetIntTypeEndpoint(client_list[i].ip, client_list[i].port);
      if (_client_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "BrpcPSClient connect to Client:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "Client connect success:" << os.str();
  return 0;
}
// One-time client setup: connects to every pserver, starts the local
// client service, creates the per-table async push queues, registers cost
// profilers and launches the async push consumer threads.
// Returns 0 on success, -1 if any server channel cannot be initialized.
int32_t BrpcPsClient::Initialize() {
  _async_call_num = 0;
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = FLAGS_pserver_timeout_ms;
  options.connection_type = "pooled";
  options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms;
  options.max_retry = 3;
  std::ostringstream os;
  std::string server_ip_port;
  std::string client_ip(butil::my_ip_cstr());
  // Fetch the pserver list and connect to each server.
  std::vector<PSHost> server_list = _env->GetPsServers();
  _server_channels.resize(server_list.size());
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    // NOTE(review): the inner dimension of _server_channels appears to be a
    // fixed-size channel group per server — confirm against the header.
    for (size_t j = 0; j < _server_channels[i].size(); ++j) {
      _server_channels[i][j].reset(new brpc::Channel());
      if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
          0) {
        VLOG(0) << "BrpcPSclient connect to Server:" << server_ip_port
                << " Failed! Try again.";
        // Retry with the numeric endpoint in case the hostname form fails.
        std::string int_ip_port =
            GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
        if (_server_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "BrpcPSclient connect to Server:" << int_ip_port
                     << " Failed!";
          return -1;
        }
      }
    }
    os << server_ip_port << ",";
  }
  // Start the client's listening service and build client<->client links.
  StartClientService();
  // Initialize the async push request queues (one per configured table).
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    auto type = worker_param.downpour_table_param(i).type();
    auto table_id = worker_param.downpour_table_param(i).table_id();
    if (type == PS_DENSE_TABLE) {
      _push_dense_task_queue_map[table_id] =
          paddle::framework::MakeChannel<DenseAsyncTask *>();
    }
    if (type == PS_SPARSE_TABLE) {
      _push_sparse_task_queue_map[table_id] =
          paddle::framework::MakeChannel<SparseAsyncTask *>();
      _push_sparse_merge_count_map[table_id] = 0;
    }
  }
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_client_pull_dense");
  profiler.register_profiler("pserver_client_pull_sparse");
  profiler.register_profiler("pserver_client_pull_sparse_param");
  profiler.register_profiler("pserver_client_pull_sparse_local");
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_parse");
  profiler.register_profiler("client_push_sparse_put");
  // NOTE(review): "pserver_client_push_sparse" is registered a second time
  // here (see five lines above) — presumably harmless, but worth confirming.
  profiler.register_profiler("pserver_client_push_sparse");
  profiler.register_profiler("pserver_client_push_sparse_merge");
  profiler.register_profiler("pserver_client_push_sparse_rpc");
  profiler.register_profiler("pserver_client_push_dense");
  profiler.register_profiler("pserver_client_push_dense_parse");
  profiler.register_profiler("push_dense_put");
  profiler.register_profiler("pserver_client_push_dense_merge");
  profiler.register_profiler("pserver_client_push_dense_rpc");
  profiler.register_profiler("pserver_client_push_dense_send");
  _running = true;
  _flushing = false;
  // Launch the async push consumer threads.
  _async_push_sparse_thread =
      std::thread(std::bind(&BrpcPsClient::PushSparseTaskConsume, this));
  // _async_push_sparse_thread.detach();
  _async_push_dense_thread =
      std::thread(std::bind(&BrpcPsClient::PushDenseTaskConsume, this));
  // for debug
  // _print_thread =
  //     std::thread(std::bind(&BrpcPsClient::PrintQueueSizeThread, this));
  return 0;
}
int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) {
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
if (_responses[request_idx].err_code() != 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return 0;
}
int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) {
int32_t feasign_size = 0;
if (_cntls[request_idx]->Failed()) {
LOG(ERROR) << "resquest cmd_id:" << cmd_id
<< " failed, "
"err:"
<< _cntls[request_idx]->ErrorText();
return -1;
}
feasign_size = _responses[request_idx].err_code();
if (feasign_size < 0) {
LOG(ERROR) << "response ret bad, server_idx:" << request_idx
<< "cmd_id:" << cmd_id
<< " err_code:" << _responses[request_idx].err_code()
<< " err_msg:" << _responses[request_idx].err_msg();
return -1;
}
return feasign_size;
}
// Returns a copy of the raw payload carried by the response at
// `request_idx`. `cmd_id` is accepted for interface symmetry with
// check_response()/check_save_response() and is not used here.
std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) {
  return _responses[request_idx].data();
}
// Queries every pserver for the statistics of `table_id`, sums the two
// uint64 counters each server returns (feasign size and mf size) and prints
// the totals to stdout. The future resolves to 0 on success, -1 otherwise.
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, table_id](void *done) {
        int ret = 0;
        uint64_t feasign_size = 0;
        uint64_t mf_size = 0;
        paddle::framework::BinaryArchive ar;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) {
            ret = -1;
            break;
          }
          // Each response payload is a BinaryArchive carrying two uint64
          // values, read back in the same order they were written.
          std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT);
          ar.SetReadBuffer(
              const_cast<char *>(resp.c_str()), resp.length(), nullptr);
          feasign_size += ar.Get<uint64_t>();
          mf_size += ar.Get<uint64_t>();
        }
        closure->set_promise_value(ret);
        std::cout << "table id: " << table_id
                  << ", feasign size: " << feasign_size
                  << ", mf size: " << mf_size << std::endl;
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a generic command `cmd_id` with string `params` to every
// pserver for `table_id`. The returned future resolves to 0 only when all
// servers reply without error, -1 otherwise.
std::future<int32_t> BrpcPsClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    // NOTE(review): this timeout is double the 10800000 ms used by the
    // sibling SendSaveCmd/PrintTableStat — confirm whether intentional.
    closure->cntl(i)->set_timeout_ms(
        10800000 * 2);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Broadcasts a save-style command (`cmd_id`, e.g. save/cache-shuffle) with
// string `params` to every pserver. On success the future resolves to the
// summed feasign size reported by the servers; on any failure it is -1.
std::future<int32_t> BrpcPsClient::SendSaveCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string> &params) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [request_call_num, cmd_id](void *done) {
        int ret = 0;
        uint32_t feasign_size = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          // Evaluate check_save_response once per response and reuse the
          // result (previously it was called twice: once for the error
          // check, once again for the sum).
          int save_ret = closure->check_save_response(i, cmd_id);
          if (save_ret < 0) {
            ret = -1;
            break;
          }
          feasign_size += save_ret;
        }
        // Promise carries the feasign total on success, -1 on failure.
        if (ret == 0) {
          closure->set_promise_value(feasign_size);
        } else {
          closure->set_promise_value(ret);
        }
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    for (const auto &param : params) {
      closure->request(i)->add_params(param);
    }
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(
        10800000);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
std::future<int32_t> BrpcPsClient::Shrink(uint32_t table_id,
                                          const std::string threshold) {
  // Ask every pserver to shrink `table_id`; the threshold is forwarded as a
  // single string parameter.
  const std::vector<std::string> args = {threshold};
  return SendCmd(table_id, PS_SHRINK_TABLE, args);
}
std::future<int32_t> BrpcPsClient::Load(const std::string &epoch,
                                        const std::string &mode) {
  // Load every table's checkpoint: a table_id of -1 addresses all tables.
  const std::vector<std::string> args = {epoch, mode};
  return SendCmd(-1, PS_LOAD_ALL_TABLE, args);
}
std::future<int32_t> BrpcPsClient::Load(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  // Load a single table's checkpoint from `epoch` with save/load mode `mode`.
  const std::vector<std::string> args = {epoch, mode};
  return SendCmd(table_id, PS_LOAD_ONE_TABLE, args);
}
std::future<int32_t> BrpcPsClient::Save(const std::string &epoch,
                                        const std::string &mode) {
  // Save every table (-1 == all tables); the future carries the summed
  // feasign size on success.
  VLOG(1) << "BrpcPsClient::save path " << epoch;
  const std::vector<std::string> args = {epoch, mode};
  return SendSaveCmd(-1, PS_SAVE_ALL_TABLE, args);
}
std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
                                        const std::string &epoch,
                                        const std::string &mode) {
  // Save a single table's checkpoint to `epoch` with save mode `mode`.
  VLOG(1) << "BrpcPsClient::save one table path " << epoch << " table_id "
          << table_id;
  const std::vector<std::string> args = {epoch, mode};
  return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, args);
}
std::future<int32_t> BrpcPsClient::CacheShuffle(
    uint32_t table_id,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  // Trigger a cache shuffle on one table, forwarding path/mode/threshold.
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
  const std::vector<std::string> args = {path, mode, cache_threshold};
  return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, args);
}
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
    std::vector<int> tables,
    const std::string &path,
    const std::string &mode,
    const std::string &cache_threshold) {
  // Cache-shuffle several tables through one path. The parameter list is
  // [path, mode, cache_threshold, table_id...].
  VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
  std::vector<std::string> param;
  param.reserve(tables.size() + 3);
  param.push_back(path);
  param.push_back(mode);
  param.push_back(cache_threshold);
  for (const int table : tables) {
    param.push_back(std::to_string(table));
  }
  return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
                                             const std::string &path,
                                             const std::string &mode) {
  // Persist one table's cache to `path` with save mode `mode`.
  const std::vector<std::string> args = {path, mode};
  return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, args);
}
// Queries every pserver for its cache threshold of `table_id` and writes
// the average of the non-negative values into `cache_threshold`.
// NOTE(review): `cache_threshold` is captured by reference inside the async
// callback — the caller must keep it alive until the future resolves.
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
                                                     double &cache_threshold) {
  int cmd_id = PS_GET_CACHE_THRESHOLD;
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, cmd_id, &cache_threshold](void *done) {
        int ret = 0;
        // Named cast for consistency with every other closure in this file
        // (previously a C-style cast).
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        std::vector<double> cache_thresholds(request_call_num, 0);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, cmd_id) != 0) {
            ret = -1;
            break;
          }
          // Each server returns its threshold as a decimal string.
          std::string cur_res = closure->get_response(i, cmd_id);
          cache_thresholds[i] = std::stod(cur_res);
        }
        // Average the valid (non-negative) thresholds.
        double sum_threshold = 0.0;
        int count = 0;
        for (auto t : cache_thresholds) {
          if (t >= 0) {
            sum_threshold += t;
            ++count;
          }
        }
        if (count == 0) {
          cache_threshold = 0;
        } else {
          cache_threshold = sum_threshold / count;
        }
        VLOG(1) << "client get cache threshold: " << cache_threshold;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    PsService_Stub rpc_stub(GetCmdChannel(i));
    closure->cntl(i)->set_timeout_ms(10800000);
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
std::future<int32_t> BrpcPsClient::Clear() {
  // Clear every table on every pserver (table_id -1 == all tables).
  const std::vector<std::string> no_params;
  return SendCmd(-1, PS_CLEAR_ALL_TABLE, no_params);
}
std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
  // Clear a single table on every pserver.
  const std::vector<std::string> no_params;
  return SendCmd(table_id, PS_CLEAR_ONE_TABLE, no_params);
}
std::future<int32_t> BrpcPsClient::Flush() {
  // Block until every outstanding async push RPC has drained, then return a
  // ready future carrying 0. _flushing gates the async producers meanwhile.
  VLOG(0) << "BrpcPsClient::flush begin";
  _flushing = true;
  std::promise<int> promise;
  std::future<int32_t> fut = promise.get_future();
  // Poll the async counter; always sleeps at least once (same as the
  // original do/while loop).
  while (true) {
    VLOG(3) << "wait _async_call_num:" << _async_call_num;
    usleep(100000);  // sleep 100ms wait async end
    if (!(_async_call_num > 0)) {
      break;
    }
  }
  VLOG(1) << "flush _async_call_num = 0";
  promise.set_value(0);
  _flushing = false;
  VLOG(0) << "BrpcPsClient::flush done";
  PrintQueueSize();
  return fut;
}
// Logs the current depth of every async push queue, sparse tables first,
// then dense tables (debugging aid used by Flush and the print thread).
void BrpcPsClient::PrintQueueSize() {
  for (auto &sparse_entry : _push_sparse_task_queue_map) {
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << sparse_entry.first
            << " size: " << sparse_entry.second->Size();
  }
  for (auto &dense_entry : _push_dense_task_queue_map) {
    VLOG(0) << "BrpcPsClient::PrintQueueSize: table " << dense_entry.first
            << " size: " << dense_entry.second->Size();
  }
}
// Debug helper thread body: while the client is running, logs the async
// push queue sizes every two minutes. Only used when the commented-out
// _print_thread in Initialize() is enabled.
void BrpcPsClient::PrintQueueSizeThread() {
  while (_running) {
    usleep(1000000 * 60 * 2);  // sleep two minutes between reports
    PrintQueueSize();
  }
}
// Shuts the client down: drains pending async pushes (Flush), stops and
// joins the consumer threads, then stops and joins the local brpc server.
void BrpcPsClient::FinalizeWorker() {
  // Flush() must complete before clearing _running so the consumer threads
  // see empty queues rather than being stopped mid-drain.
  Flush();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join thread";
  _running = false;
  _async_push_dense_thread.join();
  _async_push_sparse_thread.join();
  // _print_thread.join();
  VLOG(0) << "BrpcPsClient::FinalizeWorker begin join server";
  _server.Stop(1000);
  _server.Join();
  _server_started = false;
  VLOG(0) << "BrpcPsClient::FinalizeWorker done";
}
std::future<int32_t> BrpcPsClient::StopServer() {
  // Broadcast a stop command to every pserver (-1 == not table-specific).
  return SendCmd(-1, PS_STOP_SERVER, std::vector<std::string>());
}
std::future<int32_t> BrpcPsClient::StartProfiler() {
  // Turn on server-side profiling on every pserver.
  return SendCmd(-1, PS_START_PROFILER, std::vector<std::string>());
}
std::future<int32_t> BrpcPsClient::StopProfiler() {
  // Turn off server-side profiling on every pserver.
  return SendCmd(-1, PS_STOP_PROFILER, std::vector<std::string>());
}
std::future<int32_t> BrpcPsClient::Barrier(size_t table_id,
                                           uint32_t barrier_type) {
  // Synchronize on `table_id`; the barrier kind travels as a stringified
  // integer parameter.
  const std::vector<std::string> args = {std::to_string(barrier_type)};
  return SendCmd(table_id, PS_BARRIER, args);
}
// (GEO mode) Pulls one pserver shard's sparse keys and values for
// `table_id`. The response attachment is: [uint32 shard_nums]
// [shard_nums * uint64 keys][shard_nums * update-value blobs].
// `keys` and `values` are resized to fit; the future resolves to 0/-1.
std::future<int32_t> BrpcPsClient::PullGeoParam(size_t table_id,
                                                std::vector<float> *values,
                                                std::vector<uint64_t> *keys,
                                                int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(1, [keys, values, accessor](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        uint32_t shard_nums = 0;
        if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) {
          ret = -1;
        } else {
          // Only parse the attachment on success. Previously shard_nums was
          // left uninitialized and used to resize the output vectors even
          // when the RPC had failed.
          auto &res_io_buffer = closure->cntl(0)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          io_buffer_itr.copy_and_forward(
              reinterpret_cast<void *>(&shard_nums), sizeof(uint32_t));
          keys->resize(shard_nums);
          values->resize(shard_nums * accessor->GetAccessorInfo().update_dim);
          io_buffer_itr.copy_and_forward((void *)(keys->data()),  // NOLINT
                                         sizeof(uint64_t) * shard_nums);
          io_buffer_itr.copy_and_forward(
              (void *)(values->data()),  // NOLINT
              shard_nums * accessor->GetAccessorInfo().update_size);
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  PsService_Stub rpc_stub(GetCmdChannel(pserver_idx));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// for GEO: pushes full sparse parameter *values* (not gradients).
// `keys`/`update_values` hold `num` entries; `done` is a caller-prepared
// DownpourBrpcClosure that owns the per-shard requests and the callback.
std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
                                                   const uint64_t *keys,
                                                   const float **update_values,
                                                   size_t num,
                                                   void *done) {
  auto *accessor = GetTableAccessor(table_id);
  // Build and send the RPC requests.
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  // Bucket each key (and its value pointer) to its pserver by modulo.
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = keys[i] % request_call_num;
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Build and send this shard's RPC request.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    auto *push_data = push_request->mutable_data();
    // Payload layout: [kv_size * uint64 keys][kv_size * value blobs].
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Pulls the dense table `table_id` from all pservers and scatters the
// concatenated shard buffers into the caller-provided `regions` in order.
// The future resolves to 0 on success, -1 on any bad or short response.
std::future<int32_t> BrpcPsClient::PullDense(Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_dense");
  auto *accessor = GetTableAccessor(table_id);
  auto fea_dim = accessor->GetAccessorInfo().fea_dim;
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Callback: copy each shard's result into the regions, in order.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [request_call_num, num_per_shard, regions, region_num, accessor](
          void *done) {
        int ret = 0;
        size_t region_idx = 0;       // index of the region being filled
        size_t region_data_idx = 0;  // write offset inside that region
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        size_t shard_data_size =
            num_per_shard * accessor->GetAccessorInfo().select_size;
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          size_t shard_buffer_remain = res_io_buffer.size();
          // Every shard must return exactly shard_data_size bytes.
          if (shard_buffer_remain != shard_data_size) {
            LOG(ERROR) << "expect res_size:" << shard_data_size
                       << ", but size:" << shard_buffer_remain
                       << ", ignore this response";
            ret = -1;
            break;
          }
          while (shard_buffer_remain > 0 && region_idx < region_num) {
            auto &region = regions[region_idx];
            if (region.size - region_data_idx >= shard_buffer_remain) {
              // Remaining region space >= remaining shard data: copy all.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  shard_buffer_remain);
              region_data_idx += shard_buffer_remain;
              shard_buffer_remain = 0;
            } else if (region.size - region_data_idx == 0) {
              // Region is full: advance to the next region.
              ++region_idx;
              region_data_idx = 0;
            } else {
              // Region too small for all remaining data: fill what fits.
              io_buffer_itr.copy_and_forward(
                  reinterpret_cast<void *>(region.data + region_data_idx),
                  region.size - region_data_idx);
              shard_buffer_remain -= (region.size - region_data_idx);
              ++region_idx;
              region_data_idx = 0;
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    closure->request(i)->add_params((char *)&num_per_shard,  // NOLINT
                                    sizeof(num_per_shard));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pushes full dense parameter values for `table_id` to all pservers. The
// caller's regions are split evenly across shards; short shards are padded
// so every server receives exactly shard_data_size bytes.
std::future<int32_t> BrpcPsClient::PushDenseParam(const Region *regions,
                                                  size_t region_num,
                                                  size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  auto accessor_info = accessor->GetAccessorInfo();
  size_t request_call_num = _server_channels.size();
  // 1. Partition the region data across shards so the per-shard copies
  //    below can proceed independently.
  std::vector<std::vector<Region>> regions_partition(request_call_num);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor_info.fea_dim, request_call_num);
  size_t shard_data_size = num_per_shard * accessor_info.update_size;
  size_t current_region_idx = 0;
  size_t current_region_data_idx = 0;
  for (size_t i = 0; i < request_call_num; ++i) {
    size_t shard_data_remain_size = shard_data_size;
    while (shard_data_remain_size > 0 && current_region_idx < region_num) {
      const auto &region = regions[current_region_idx];
      size_t region_remain_size = region.size - current_region_data_idx;
      if (shard_data_remain_size >= region_remain_size) {
        // Shard can absorb the rest of this region; move to the next one.
        regions_partition[i].push_back(
            Region(region.data + current_region_data_idx, region_remain_size));
        ++current_region_idx;
        current_region_data_idx = 0;
        shard_data_remain_size -= region_remain_size;
      } else {
        // Region larger than the shard's remaining quota; split the region.
        regions_partition[i].push_back(Region(
            region.data + current_region_data_idx, shard_data_remain_size));
        current_region_data_idx += shard_data_remain_size;
        shard_data_remain_size = 0;
      }
    }
  }
  DownpourBrpcClosure *closure =
      new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10;
  static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE];  // pad buffer
  // 2. Copy each shard's data and issue the requests.
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&num_per_shard),
                          sizeof(uint32_t));
    auto &region_list = regions_partition[i];
    size_t fill_remain_size = shard_data_size;
    for (auto &region : region_list) {
      fill_remain_size -= region.size;
      request_buffer.append(reinterpret_cast<void *>(region.data), region.size);
    }
    // Keep all shards the same size: pad the remainder with the (static,
    // zero-initialized) scratch buffer.
    while (fill_remain_size > 0) {
      size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE
                            ? REGION_ASSIGN_BUFFER_SIZE
                            : fill_remain_size;
      request_buffer.append(reinterpret_cast<void *>(region_assign_buffer),
                            fill_num);
      fill_remain_size -= fill_num;
    }
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pushes raw sparse gradients for `table_id`. Keys are routed to pservers
// via get_sparse_shard(); `done` is a caller-prepared DownpourBrpcClosure
// that owns the per-shard requests and the completion callback.
std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    size_t num,
    void *done) {
  auto *accessor = GetTableAccessor(table_id);
  // Build and send the RPC requests.
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  size_t request_call_num = _server_channels.size();
  std::vector<std::vector<uint64_t>> ids;
  std::vector<std::vector<const float *>> value_ptrs;
  ids.resize(request_call_num);
  value_ptrs.resize(request_call_num);
  // Look up this table's shard_num in the server config; fall back to the
  // global flag when the table is not found.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  for (size_t i = 0; i < num; ++i) {
    size_t pserver_idx = get_sparse_shard(shard_num, request_call_num, keys[i]);
    ids[pserver_idx].push_back(keys[i]);
    value_ptrs[pserver_idx].push_back(update_values[i]);
  }
  for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
    auto kvs = ids[shard_idx];
    auto value_ptr = value_ptrs[shard_idx];
    size_t kv_size = kvs.size();
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    // Build and send this shard's RPC request.
    auto *push_request = closure->request(shard_idx);
    push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
    push_request->set_table_id(table_id);
    push_request->set_client_id(_client_id);
    push_request->add_params((char *)&kv_size, sizeof(uint32_t));  // NOLINT
    auto *push_data = push_request->mutable_data();
    // Payload layout: [kv_size * uint64 keys][kv_size * gradient blobs].
    push_data->resize(kv_size * (sizeof(uint64_t) + value_size));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
    push_data_ptr += kv_size * sizeof(uint64_t);
    for (size_t i = 0; i < kv_size; ++i) {
      memcpy(push_data_ptr, value_ptr[i], value_size);
      push_data_ptr += value_size;
    }
    PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
    closure->cntl(shard_idx)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    rpc_stub.service(closure->cntl(shard_idx),
                     closure->request(shard_idx),
                     closure->response(shard_idx),
                     closure);
  }
  return fut;
}
// Pushes the dense gradient buffer `total_send_data` for `table_id`,
// slicing it into equal num_per_shard chunks across the pservers. `done`
// is a caller-prepared DownpourBrpcClosure.
std::future<int32_t> BrpcPsClient::PushDenseRawGradient(
    int table_id,
    float *total_send_data,
    size_t total_send_data_size,
    void *done) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  auto *accessor = GetTableAccessor(table_id);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Payload layout: [uint32 num_per_shard][num_per_shard floats], taken
    // from the i-th slice of total_send_data.
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    // closure->cntl(i)->set_request_compress_type(
    //     (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pushes the global training step (a single int64 in `total_send_data`) to
// every pserver's PS_PUSH_GLOBAL_STEP handler. `done` is a caller-prepared
// DownpourBrpcClosure.
std::future<int32_t> BrpcPsClient::PushGlobalStep(int table_id,
                                                  int64_t *total_send_data,
                                                  void *done) {
  size_t request_call_num = _server_channels.size();
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_GLOBAL_STEP);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    // Payload layout: [uint32 count = 1][one int64 step value].
    int32_t num_per_shard = 1;
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data,
           num_per_shard * sizeof(int64_t));
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Pulls sparse values for `num` keys of `table_id`. Keys are bucketed per
// pserver, then sorted and de-duplicated before sending; the server returns
// one value per unique key and duplicates are copied locally.
// `select_values[i]` receives the value for keys[i].
std::future<int32_t> BrpcPsClient::PullSparse(float **select_values,
                                              size_t table_id,
                                              const uint64_t *keys,
                                              size_t num,
                                              bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse");
  auto local_timer =
      std::make_shared<CostTimer>("pserver_client_pull_sparse_local");
  size_t request_call_num = _server_channels.size();
  // shard_sorted_kvs[s] holds the (key, destination buffer) pairs routed to
  // pserver s; shared with the async callback via shared_ptr.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  // Look up this table's shard_num in the server config; fall back to the
  // global flag when the table is not found.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // The kvs were sorted before sending, so duplicate keys are
          // adjacent: the server sent one value per unique key, and repeats
          // reuse the previously copied value.
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data), value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    // Request layout: [bool is_training][unique keys][per-key dup counts].
    request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse adjacent duplicates; record how often the key repeats.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // Nothing routed to this shard: complete this sub-closure directly.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count,  // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(
          closure->cntl(i), closure->request(i), closure->response(i), closure);
    }
  }
  return fut;
}
// for GEO
// GEO-mode sparse pull: for each key in `keys`, fetches `select_size` floats
// from the pservers into the caller-provided row buffers `select_values[i]`.
// Keys are sharded by `key % server_num`, sorted per shard, and duplicate
// keys within a shard reuse the value fetched for the first occurrence.
// Returns a future resolving to 0 on success, -1 on RPC/deserialization
// failure. Caller keeps ownership of all buffers; they must stay alive until
// the future resolves.
std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
                                                   size_t table_id,
                                                   const uint64_t *keys,
                                                   size_t num,
                                                   bool is_training) {
  auto timer = std::make_shared<CostTimer>("pserver_client_pull_sparse_param");
  size_t request_call_num = _server_channels.size();
  // Per-shard (key, destination buffer) pairs; shared with the closure so the
  // response callback can scatter results back to the right buffers.
  auto shard_sorted_kvs = std::make_shared<
      std::vector<std::vector<std::pair<uint64_t, float *>>>>();
  shard_sorted_kvs->resize(request_call_num);
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = keys[i] % request_call_num;
    shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]});
  }
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().select_size;
  // Response callback: runs once all shard sub-requests have completed.
  // The response attachment carries one value per UNIQUE key, in sorted key
  // order; duplicates are filled by copying the previous key's value.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [shard_sorted_kvs, value_size](void *done) {
        int ret = 0;
        auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
        for (size_t i = 0; i < shard_sorted_kvs->size(); ++i) {
          if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
          auto &request_kvs = shard_sorted_kvs->at(i);
          auto &res_io_buffer = closure->cntl(i)->response_attachment();
          butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
          uint64_t last_key = UINT64_MAX;
          float *last_value_data = NULL;
          // can remove sort&unique
          for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) {
            auto *kv_pair = &(request_kvs[kv_idx]);
            if (kv_pair->first == last_key) {
              // Duplicate key: reuse the value already read for it.
              memcpy(reinterpret_cast<void *>(kv_pair->second),
                     reinterpret_cast<void *>(last_value_data),
                     value_size);
            } else {
              last_key = kv_pair->first;
              last_value_data = kv_pair->second;
              // copy_and_forward returns the byte count actually consumed;
              // a short read means the response is truncated or malformed.
              if (value_size !=
                  io_buffer_itr.copy_and_forward(
                      reinterpret_cast<void *>(last_value_data), value_size)) {
                LOG(WARNING) << "res data is lack or not in format";
                ret = -1;
                break;
              }
            }
          }
        }
        closure->set_promise_value(ret);
      });
  closure->add_timer(timer);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Serialize and send one request per shard. Wire layout per shard:
  // [is_training: bool][unique keys: uint64_t...][dup counts: uint32_t...].
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kvs = shard_sorted_kvs->at(i);
    std::sort(sorted_kvs.begin(),
              sorted_kvs.end(),
              [](const std::pair<uint64_t, float *> &k1,
                 const std::pair<uint64_t, float *> &k2) {
                return k1.first < k2.first;
              });
    uint64_t last_key = UINT64_MAX;
    uint32_t kv_request_count = 0;
    size_t sorted_kv_size = sorted_kvs.size();
    auto &request_buffer = closure->cntl(i)->request_attachment();
    request_buffer.append(reinterpret_cast<void *>(&is_training), sizeof(bool));
    std::vector<uint32_t> keys_counter;
    keys_counter.reserve(sorted_kv_size);
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      ++kv_request_count;
      uint32_t keys = 1;
      last_key = sorted_kvs[kv_idx].first;
      request_buffer.append(reinterpret_cast<void *>(&last_key),
                            sizeof(uint64_t));
      // Collapse a run of identical keys into one wire key + a count.
      while (kv_idx < sorted_kv_size - 1 &&
             last_key == sorted_kvs[kv_idx + 1].first) {
        ++kv_idx;
        ++keys;
      }
      keys_counter.push_back(keys);
    }
    request_buffer.append(reinterpret_cast<void *>(keys_counter.data()),
                          sizeof(uint32_t) * keys_counter.size());
    if (kv_request_count == 0) {
      // Empty shard: complete this sub-request locally, no RPC needed.
      closure->Run();
    } else {
      closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE);
      closure->request(i)->set_table_id(table_id);
      closure->request(i)->set_client_id(_client_id);
      closure->request(i)->add_params((char *)&kv_request_count, // NOLINT
                                      sizeof(uint32_t));
      PsService_Stub rpc_stub(GetCmdChannel(i));
      closure->cntl(i)->set_log_id(butil::gettimeofday_ms());
      rpc_stub.service(
          closure->cntl(i), closure->request(i), closure->response(i), closure);
    }
  }
  return fut;
}
// Sends an arbitrary client-to-client message to peer `to_client_id`.
// The peer is expected to answer with cmd_id `msg_type + 1000`, which the
// closure's check_response validates. Returns a future resolving to 0 on
// success, -1 on failure (including an out-of-range target id).
std::future<int32_t> BrpcPsClient::SendClient2ClientMsg(
    int msg_type, int to_client_id, const std::string &msg) {
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> fut = promise->get_future();
  // Reject negative ids as well as ids beyond the channel table. The previous
  // condition (`to_client_id >= 0 && ...`) let negative ids fall through to an
  // out-of-bounds `_client_channels[to_client_id]` access below.
  if (to_client_id < 0 ||
      static_cast<size_t>(to_client_id) >= _client_channels.size()) {
    VLOG(0) << "to_client_id is out of range clients, which size is "
            << _client_channels.size();
    promise->set_value(-1);
    return fut;
  }
  auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) {
    auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
    int32_t ret = closure->check_response(0, msg_type + 1000);
    closure->set_promise_value(ret);
  });
  closure->add_promise(promise);
  closure->request(0)->set_cmd_id(msg_type);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->set_data(msg);
  PsService_Stub rpc_stub(_client_channels[to_client_id].get());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pushes raw sparse gradients for `num` keys to a single pserver
// (`pserver_idx`). `done` must be a caller-prepared DownpourBrpcClosure;
// slot 0 of that closure is used for this request.
// Payload layout: [keys: num * uint64_t][values: num * update_size bytes].
std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
    size_t table_id,
    const uint64_t *keys,
    const float **update_values,
    uint32_t num,
    void *done,
    int pserver_idx) {
  auto *accessor = GetTableAccessor(table_id);
  size_t value_size = accessor->GetAccessorInfo().update_size;
  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Build and send the RPC request.
  auto *push_request = closure->request(0);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  push_request->add_params((char *)&num, sizeof(uint32_t)); // NOLINT
  auto *push_data = push_request->mutable_data();
  push_data->resize(num * (sizeof(uint64_t) + value_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  // Keys first, then the flat value rows.
  memcpy(push_data_ptr, keys, num * sizeof(uint64_t));
  push_data_ptr += num * sizeof(uint64_t);
  for (uint32_t i = 0; i < num; ++i) {
    memcpy(push_data_ptr, update_values[i], value_size);
    push_data_ptr += value_size;
  }
  PsService_Stub rpc_stub(GetSparseChannel(pserver_idx));
  closure->cntl(0)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pulls the full contents of sparse table `table_id` from the pservers and
// serializes it as a LoDTensor stream to `path`/<table_name>.
// Looks up the table's name/rows/dim in the worker config; keys are assumed
// to be the dense range [0, table_num) — NOTE(review): this only holds for
// tables whose ids are contiguous from zero; verify for other layouts.
// Returns 0 on success; enforces (aborts) if the table id is unknown or the
// output file cannot be opened.
int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
                                       const std::string &path) {
  // get var information
  std::string var_name = "";
  int64_t var_num = 0;
  int64_t var_shape = 0;
  std::string table_class;
  const auto &worker_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
    if (worker_param.downpour_table_param(i).table_id() == table_id) {
      var_name = worker_param.downpour_table_param(i).common().table_name();
      var_num = worker_param.downpour_table_param(i).common().table_num();
      var_shape = worker_param.downpour_table_param(i).common().table_dim();
      table_class = worker_param.downpour_table_param(i).table_class();
      break;
    }
  }
  PADDLE_ENFORCE_NE(
      var_name,
      "",
      platform::errors::InvalidArgument(
          "Cannot find table id %d to save variables.", table_id));
  std::string var_store = string::Sprintf("%s", path);
  MkDirRecursively(var_store.c_str());
  // pull sparse from server
  // One contiguous buffer for all rows; save_vec holds per-row pointers into
  // it, which is the interface PullSparse/PullSparseParam expect.
  std::vector<float> save_huge_vec(var_num * var_shape);
  std::vector<uint64_t> save_key(var_num);
  std::vector<float *> save_vec;
  for (size_t i = 0; i < save_key.size(); ++i) {
    save_key[i] = i;
    save_vec.push_back(save_huge_vec.data() + i * var_shape);
  }
  VLOG(2) << "RecvAndSaveTable: table_class: " << table_class;
  // TODO(zhaocaibei123): new GeoBrpcPSClient, move this to its
  // RecvAndSaveTable
  if (table_class == "MemorySparseGeoTable") {
    auto status = PullSparseParam(reinterpret_cast<float **>(save_vec.data()),
                                  table_id,
                                  save_key.data(),
                                  save_key.size(),
                                  true);
    status.wait();
  } else {
    auto status = PullSparse(reinterpret_cast<float **>(save_vec.data()),
                             table_id,
                             save_key.data(),
                             save_key.size(),
                             true);
    status.wait();
  }
  // create lod tensor
  std::shared_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  auto place = platform::CPUPlace();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(place);
  framework::Variable *var = scope->Var(var_name);
  framework::LoDTensor *var_tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int64_t> vec_dim = {var_num, var_shape};
  var_tensor->Resize(phi::make_ddim(vec_dim));
  // copy and save
  float *tensor_data = var_tensor->mutable_data<float>(place);
  memcpy(
      tensor_data, save_huge_vec.data(), var_num * var_shape * sizeof(float));
  std::string file_name = string::Sprintf("%s/%s", var_store, var_name);
  std::ofstream fout(file_name, std::ios::binary);
  PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
                    true,
                    platform::errors::Unavailable(
                        "Cannot open %s to save variables.", file_name));
  framework::SerializeToStream(fout, *var_tensor, dev_ctx);
  fout.close();
  return 0;
}
// Enqueues a sparse gradient push for async sending. The key/value pairs are
// sharded (by table shard hash), copied into a pooled SparseAsyncTask, and
// put on the per-table queue consumed by PushSparseTaskConsume(). Blocks
// (polling every 5ms) while the queue is over FLAGS_pserver_max_async_call_num
// to apply backpressure. The returned future resolves when the merged push
// that includes this task finishes.
std::future<int32_t> BrpcPsClient::PushSparse(size_t table_id,
                                              const uint64_t *keys,
                                              const float **update_values,
                                              size_t num) {
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_sparse");
  CostTimer parse_timer("pserver_client_push_sparse_parse");
  int push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  while (push_sparse_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushSparse Waiting for async_call_num comsume,
    // task_num:"
    //    << push_sparse_async_num
    //    << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_sparse_async_num = _push_sparse_task_queue_map[table_id]->Size();
  }
  auto put_timer = std::make_shared<CostTimer>("client_push_sparse_put");
  // thread_local scratch buffer: reused across calls on the same thread to
  // avoid reallocating the per-shard vectors on every push.
  thread_local std::vector<std::vector<std::pair<uint64_t, const float *>>>
      shard_sorted_kv_list;
  auto *accessor = GetTableAccessor(table_id);
  size_t request_call_num = _server_channels.size();
  shard_sorted_kv_list.resize(request_call_num);
  for (auto &x : shard_sorted_kv_list) {
    x.clear();
  }
  // Shard count comes from the table config when present, otherwise from the
  // global flag; it feeds the shard-hash below.
  const auto &server_param = _config.server_param().downpour_server_param();
  uint64_t shard_num = FLAGS_pserver_sparse_table_shard_num;
  for (int i = 0; i < server_param.downpour_table_param_size(); ++i) {
    const auto &table_param = server_param.downpour_table_param(i);
    if (table_param.table_id() == table_id) {
      shard_num = table_param.shard_num();
      break;
    }
  }
  for (size_t i = 0; i < num; ++i) {
    size_t shard_id = get_sparse_shard(shard_num, request_call_num, keys[i]);
    shard_sorted_kv_list[shard_id].push_back({keys[i], update_values[i]});
  }
  // Deep-copy keys/values into a pooled task so the caller's buffers may be
  // freed as soon as this function returns.
  auto sparse_task_data = _sparse_task_pool.get();
  sparse_task_data->shared_data.resize(request_call_num);
  auto async_task = new SparseAsyncTask(sparse_task_data, table_id, push_timer);
  for (size_t i = 0; i < request_call_num; ++i) {
    auto &sorted_kv_list = shard_sorted_kv_list[i];
    size_t sorted_kv_size = sorted_kv_list.size();
    auto &shard_kv_data = async_task->data()->shared_data[i];
    shard_kv_data.key_list.resize(sorted_kv_size);
    shard_kv_data.value_list.resize(sorted_kv_size);
    if (sorted_kv_size == 0) {
      shard_kv_data.kv_num = 0;
      continue;
    }
    uint32_t value_size = accessor->GetAccessorInfo().update_size;
    for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
      shard_kv_data.key_list[kv_idx] = sorted_kv_list[kv_idx].first;
      shard_kv_data.value_list[kv_idx].assign(
          (const char *)sorted_kv_list[kv_idx].second, value_size);
    }
    shard_kv_data.kv_num = sorted_kv_size;
  }
  std::future<int> fut = async_task->get_future();
  _push_sparse_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background consumer loop for async sparse pushes (runs on
// _async_push_sparse_thread until _running is cleared). Each pass drains every
// table's task queue, merges tasks per shard on a worker pool, and either
// sends the merged gradients via RPC (when the merge count reaches
// FLAGS_pserver_push_sparse_merge_limit or a flush is requested) or re-queues
// the merged result for a later pass.
void BrpcPsClient::PushSparseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_sparse_merge_limit;
  std::vector<std::shared_ptr<SparseAsyncTask>> task_list;
  size_t request_call_num = _server_channels.size();
  ::ThreadPool async_push_sparse_shard_threads(
      FLAGS_pserver_sparse_merge_thread);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    // Process the push tasks of every sparse table.
    for (auto &push_sparse_task_itr : _push_sparse_task_queue_map) {
      auto table_id = push_sparse_task_itr.first;
      auto *accessor = GetTableAccessor(table_id);
      auto &task_queue = push_sparse_task_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      // With merging enabled, wait until there is more than one task unless
      // we are flushing.
      if (merge_size > 0 && (queue_size <= 1 && _flushing == false)) {
        continue;
      }
      ++_async_call_num;
      int merge_count = 0;
      // Return the previous round's pooled payloads before reuse.
      for (size_t i = 0; i < task_list.size(); ++i) {
        if (task_list[i]->data()) {
          _sparse_task_pool.push(task_list[i]->data());
        }
      }
      auto sparse_task_data = _sparse_task_pool.get();
      task_list.clear();
      int cur_meger_size = task_queue->Size();
      // task_list[0] is an empty SparseAsyncTask; the per-shard async merge
      // results are accumulated into this structure.
      sparse_task_data->shared_data.resize(request_call_num);
      auto push_timer =
          std::make_shared<CostTimer>("pserver_client_push_sparse");
      auto async_task =
          new SparseAsyncTask(sparse_task_data, table_id, push_timer);
      task_list.reserve(cur_meger_size + 1);
      task_list.push_back(
          std::move(std::shared_ptr<SparseAsyncTask>(async_task)));
      while (!task_queue->Empty() && merge_count < cur_meger_size) {
        ++merge_count;
        SparseAsyncTask *task;
        task_queue->Get(task);
        task_list.push_back(std::shared_ptr<SparseAsyncTask>(task));
      }
      _push_sparse_merge_count_map[table_id] += merge_count;
      // Send when the accumulated merge count reaches merge_size, or
      // unconditionally while flushing.
      std::vector<int> request_kv_num(request_call_num, 0);
      if (_push_sparse_merge_count_map[table_id] >= merge_size ||
          _flushing == true) {
        DownpourBrpcClosure *closure = new DownpourBrpcClosure(
            request_call_num, [this, request_call_num](void *done) {
              int ret = 0;
              auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
              for (size_t i = 0; i < request_call_num; ++i) {
                if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
                  ret = -1;
                  break;
                }
              }
              closure->set_promise_value(ret);
              --_async_call_num;
            });
        // Propagate every merged task's timer/promise to the shared closure
        // so all callers' futures resolve with this RPC's result.
        for_each(task_list.begin() + 1,
                 task_list.end(),
                 [&request_kv_num, request_call_num, closure](
                     std::shared_ptr<SparseAsyncTask> &task) {
                   closure->add_timer(task->timer());
                   closure->add_promise(task->promise());
                 });
        CostTimer merge_timer("pserver_client_push_sparse_merge");
        auto rpc_timer =
            std::make_shared<CostTimer>("pserver_client_push_sparse_rpc");
        closure->add_timer(rpc_timer);
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardPush,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        closure,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
        _push_sparse_merge_count_map[table_id] = 0;
      } else {  // Below the threshold: only do a multi-way merge, no send.
        std::vector<std::future<int>> merge_status(request_call_num);
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx] = async_push_sparse_shard_threads.enqueue(
              std::bind(&BrpcPsClient::PushSparseAsyncShardMerge,
                        this,
                        task_list,
                        request_kv_num,
                        table_id,
                        shard_idx,
                        accessor));
        }
        for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
          merge_status[shard_idx].wait();
        }
        // Merged into task_list[0]; re-queue a copy of it for a later pass.
        auto async_task = new SparseAsyncTask(*(task_list[0].get()));
        task_queue->Put(std::move(async_task));
        --_async_call_num;
        merge_status.clear();
        std::vector<std::future<int>>().swap(merge_status);
      }
    }
    // Sleep off the remainder of the configured polling interval.
    auto wait_ms = FLAGS_pserver_async_push_sparse_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Merges one gradient row (`another_data`) into `merge_data` in place via the
// accessor's Merge(). The accessor expects arrays of per-column pointers, so
// each float of the flat rows is wrapped in a one-element "shell" pointer.
void sparse_local_merge(ValueAccessor *accessor,
                        float *merge_data,
                        const float *another_data) {
  size_t col_num = accessor->GetAccessorInfo().update_dim;
  // std::vector instead of variable-length arrays: VLAs are a non-standard
  // compiler extension in C++ and risk stack overflow for large update dims.
  std::vector<float *> merge_data_shell(col_num);
  std::vector<const float *> another_data_shell(col_num);
  for (size_t i = 0; i < col_num; ++i) {
    merge_data_shell[i] = merge_data + i;
    another_data_shell[i] = another_data + i;
  }
  accessor->Merge(merge_data_shell.data(), another_data_shell.data(), 1);
}
// Merges shard `shard_idx` of every queued task (task_list[1..]) into
// task_list[0]: gathers all KV pairs, sorts them by key, and collapses
// duplicate keys by accumulating their gradients with sparse_local_merge().
// `request_kv_num` is accepted for signature parity with the push path but is
// not used here. Always returns 0; oversized/short values are skipped with a
// warning.
int BrpcPsClient::PushSparseAsyncShardMerge(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    ValueAccessor *accessor) {
  size_t merged_kv_count = 0;
  uint32_t value_size = accessor->GetAccessorInfo().update_size;
  // thread_local scratch list, reused across invocations on this worker.
  thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
  sorted_kv_list.clear();
  for (size_t i = 1; i < task_list.size(); ++i) {
    size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
    auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
    auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;
    for (size_t j = 0; j < kv_num; ++j) {
      // Defensive: drop entries whose serialized value is shorter than one
      // full update row.
      if (value_list[j].size() < value_size) {
        LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
                     << "is invalid.";
        continue;
      }
      char *task_data_ptr = const_cast<char *>(value_list[j].data());
      sorted_kv_list.push_back(
          {key_list[j], reinterpret_cast<float *>(task_data_ptr)});
    }
  }
  // Sort by key & deduplicate.
  std::sort(sorted_kv_list.begin(),
            sorted_kv_list.end(),
            [](const std::pair<uint64_t, const float *> &k1,
               const std::pair<uint64_t, const float *> &k2) {
              return k1.first < k2.first;
            });
  auto &async_task = task_list[0];
  size_t sorted_kv_size = sorted_kv_list.size();
  auto &shard_kv_data = async_task->data()->shared_data[shard_idx];
  shard_kv_data.key_list.resize(sorted_kv_size);
  shard_kv_data.value_list.resize(sorted_kv_size);
  // Write the deduplicated data into the per-shard packet.
  if (sorted_kv_size == 0) {
    shard_kv_data.kv_num = 0;
    return 0;
  } else if (sorted_kv_size == 1) {
    shard_kv_data.kv_num = 1;
    shard_kv_data.key_list[0] = sorted_kv_list[0].first;
    shard_kv_data.value_list[0].assign((const char *)(sorted_kv_list[0].second),
                                       value_size);
    return 0;
  }
  // Deduplicate with local merge: runs of equal keys are accumulated into
  // merger_buffer; unique keys are copied through unchanged.
  uint64_t last_key = sorted_kv_list[0].first;
  const float *last_value_data = sorted_kv_list[0].second;
  float *last_merge_data = NULL;
  std::shared_ptr<char> merger_buffer(new char[value_size],
                                      array_deleter<char>());
  for (size_t kv_idx = 1; kv_idx < sorted_kv_size; ++kv_idx) {
    while (kv_idx < sorted_kv_size &&
           last_key == sorted_kv_list[kv_idx].first) {
      if (last_merge_data == NULL) {
        last_merge_data = reinterpret_cast<float *>(merger_buffer.get());
        memcpy(last_merge_data, last_value_data, value_size);
      }
      sparse_local_merge(
          accessor, last_merge_data, sorted_kv_list[kv_idx].second);
      ++kv_idx;
    }
    if (last_merge_data != NULL) {
      // The previous key had duplicates: emit the merged row.
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_merge_data, value_size);
      last_merge_data = NULL;
    } else {
      // Unique key: emit its original row.
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)sorted_kv_list[kv_idx - 1].second, value_size);
    }
    shard_kv_data.key_list[merged_kv_count++] = last_key;
    if (kv_idx < sorted_kv_size) {
      last_key = sorted_kv_list[kv_idx].first;
      last_value_data = sorted_kv_list[kv_idx].second;
    }
    // Tail case: the final element starts a new key run of length one.
    if (kv_idx == sorted_kv_size - 1) {
      shard_kv_data.value_list[merged_kv_count].assign(
          (const char *)last_value_data, value_size);
      shard_kv_data.key_list[merged_kv_count++] = last_key;
    }
  }
  shard_kv_data.kv_num = merged_kv_count;
  return 0;
}
// Merges the queued tasks for `shard_idx` (via PushSparseAsyncShardMerge) and
// pushes the merged KV data in task_list[0] to the corresponding pserver
// shard over RPC. Always returns 0; the RPC outcome is reported through
// `closure`. Payload layout:
// [keys: kv_count * uint64_t][values: kv_count * update_size bytes].
int BrpcPsClient::PushSparseAsyncShardPush(
    std::vector<std::shared_ptr<SparseAsyncTask>> &task_list,
    std::vector<int> &request_kv_num,
    int table_id,
    int shard_idx,
    DownpourBrpcClosure *closure,
    ValueAccessor *accessor) {
  PushSparseAsyncShardMerge(
      task_list, request_kv_num, table_id, shard_idx, accessor);
  size_t merged_kv_count = task_list[0]->data()->shared_data[shard_idx].kv_num;
  auto &merged_key_list = task_list[0]->data()->shared_data[shard_idx].key_list;
  auto &merged_value_list =
      task_list[0]->data()->shared_data[shard_idx].value_list;
  // Build and send the RPC request.
  auto *push_request = closure->request(shard_idx);
  push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE);
  push_request->set_table_id(table_id);
  push_request->set_client_id(_client_id);
  // The wire format carries the KV count as a uint32_t. Copy the size_t into
  // a uint32_t first instead of reinterpreting the first sizeof(uint32_t)
  // bytes of a 64-bit value, which only worked on little-endian hosts.
  uint32_t kv_count = static_cast<uint32_t>(merged_kv_count);
  push_request->add_params(reinterpret_cast<char *>(&kv_count),
                           sizeof(uint32_t));
  auto *push_data = push_request->mutable_data();
  int update_size = accessor->GetAccessorInfo().update_size;
  push_data->resize(merged_kv_count * (sizeof(uint64_t) + update_size));
  char *push_data_ptr = const_cast<char *>(push_data->data());
  // Keys first, then the value rows.
  memcpy(push_data_ptr,
         merged_key_list.data(),
         merged_kv_count * sizeof(uint64_t));
  push_data_ptr += merged_kv_count * sizeof(uint64_t);
  for (size_t i = 0; i < merged_kv_count; ++i) {
    // memcpy takes const void*, so no cast of the source pointer is needed.
    memcpy(push_data_ptr, merged_value_list[i].data(), update_size);
    push_data_ptr += update_size;
  }
  PsService_Stub rpc_stub(GetSparseChannel(shard_idx));
  closure->cntl(shard_idx)->set_request_compress_type(
      (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
  rpc_stub.service(closure->cntl(shard_idx),
                   closure->request(shard_idx),
                   closure->response(shard_idx),
                   closure);
  _push_sparse_merge_count_map[table_id] = 0;
  return 0;
}
// Enqueues a dense gradient push: concatenates the float data of `regions`
// into one flat buffer owned by a DenseAsyncTask and puts it on the per-table
// queue consumed by PushDenseTaskConsume(). Blocks (polling every 5ms) while
// the queue exceeds FLAGS_pserver_max_async_call_num, as backpressure.
// The returned future resolves when the merged push containing this task
// completes.
std::future<int32_t> BrpcPsClient::PushDense(const Region *regions,
                                             size_t region_num,
                                             size_t table_id) {
  auto *accessor = GetTableAccessor(table_id);
  int fea_dim = accessor->GetAccessorInfo().fea_dim;
  int update_dim = accessor->GetAccessorInfo().update_dim;
  auto push_timer = std::make_shared<CostTimer>("pserver_client_push_dense");
  auto parse_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_parse");
  int push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  while (push_dense_async_num > FLAGS_pserver_max_async_call_num) {
    // LOG(INFO) << "PushDense Waiting for async_call_num comsume,
    // task_num:"
    //    << push_dense_async_num
    //    << ", max_task_limit:" << FLAGS_pserver_max_async_call_num;
    usleep(5000);  // 5ms
    push_dense_async_num = _push_dense_task_queue_map[table_id]->Size();
  }
  auto push_dense_timer = std::make_shared<CostTimer>("push_dense_put");
  // auto dense_data = _dense_matrix_obj_pool.get();
  auto dense_data = std::make_shared<std::vector<float>>();
  auto async_task = new DenseAsyncTask(dense_data, table_id, push_timer);
  size_t request_call_num = _server_channels.size();
  uint32_t num_per_shard = DenseDimPerShard(fea_dim, request_call_num);
  // Copy the region data into the transposed matrix buffer. The buffer is
  // sized for all shards so PushDenseRawGradient can slice it per server.
  async_task->data()->resize(num_per_shard * request_call_num * update_dim);
  float *data = async_task->data()->data();
  size_t data_size = async_task->data()->size();
  uint32_t pos = 0;
  for (size_t i = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    CHECK(pos + data_num <= data_size)
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << data_num << "] size[" << data_size << "]";
    const float *region_data = (const float *)(regions[i].data);
    memcpy(data + pos, region_data, regions[i].size);
    pos += data_num;
  }
  std::future<int> fut = async_task->get_future();
  _push_dense_task_queue_map[table_id]->Put(std::move(async_task));
  return fut;
}
// Background consumer loop for async dense pushes (runs on
// _async_push_dense_thread until _running is cleared). Each pass drains every
// table's queue once it holds more than FLAGS_pserver_push_dense_merge_limit
// tasks (or a flush is requested), merges the queued gradients into the first
// task's buffer on a worker pool, optionally scales by the merge count, and
// sends the result via PushDenseRawGradient.
void BrpcPsClient::PushDenseTaskConsume() {
  uint64_t merge_size = FLAGS_pserver_push_dense_merge_limit;
  static bool scale_gradient = FLAGS_pserver_scale_gradient_by_merge;
  ::ThreadPool async_merge_dense_threads(10);
  while (_running) {
    auto async_start_time_ms = butil::gettimeofday_ms();
    for (auto &task_queue_itr : _push_dense_task_queue_map) {
      auto &task_queue = task_queue_itr.second;
      auto queue_size = task_queue->Size();
      if (queue_size == 0) {
        continue;
      }
      if (queue_size <= merge_size && _flushing == false) {
        continue;
      }
      ++_async_call_num;
      DenseAsyncTask *task;
      task_queue->Get(task);
      auto *accessor = GetTableAccessor(task->table_id());
      // Set up the request callback: resolves the shared promise once all
      // per-server sub-requests have been checked.
      size_t request_call_num = _server_channels.size();
      DownpourBrpcClosure *closure = new DownpourBrpcClosure(
          request_call_num, [this, request_call_num](void *done) {
            int ret = 0;
            auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
            for (size_t i = 0; i < request_call_num; ++i) {
              if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
                ret = -1;
                break;
              }
            }
            closure->set_promise_value(ret);
            --_async_call_num;
          });
      auto &total_send_data_vec = *(task->data());
      float *total_send_data =
          reinterpret_cast<float *>(total_send_data_vec.data());
      size_t total_send_data_size = total_send_data_vec.size();
      {
        CostTimer merge_timer("pserver_client_push_dense_merge");
        uint32_t merge_count = 0;
        std::vector<std::future<int>> merge_status(merge_size);
        while (!task_queue->Empty() && merge_count < merge_size) {
          // NOTE(review): `Get` appears to overwrite the pointer just
          // allocated with `new DenseAsyncTask()`, which would leak that
          // object each iteration — confirm Channel::Get semantics.
          auto *async_task = new DenseAsyncTask();
          task_queue->Get(async_task);
          closure->add_timer(async_task->timer());
          closure->add_promise(async_task->promise());
          merge_status[merge_count] =
              async_merge_dense_threads.enqueue([closure,
                                                 accessor,
                                                 &total_send_data,
                                                 total_send_data_size,
                                                 async_task]() -> int {
                auto &tmp_task_vec = *(async_task->data());
                const float *merge_data = tmp_task_vec.data();
                accessor->Merge(
                    &total_send_data, &merge_data, total_send_data_size);
#pragma optimize("", off)
                delete async_task;
#pragma optimize("", on)
                return 0;
              });
          ++merge_count;
        }
        for (size_t i = 0; i < merge_count; ++i) {
          merge_status[i].wait();
        }
        VLOG(3) << "BrpcPsClient::PushDenseTaskConsume before merge "
                   "total_send_data[0]"
                << total_send_data[0] << " total_send_data[-2]"
                << total_send_data[total_send_data_size - 2]
                << total_send_data[0] << " total_send_data[-1]"
                << total_send_data[total_send_data_size - 1];
        // Average the merged gradient over the number of contributors.
        if (scale_gradient && merge_count > 1) {
          Eigen::Map<Eigen::MatrixXf> mat(
              total_send_data, 1, total_send_data_size);
          mat *= (1.0 / (merge_count + 1));
        }
        VLOG(3) << "BrpcPsClient::PushDenseTaskConsume after merge "
                   "total_send_data[0]"
                << total_send_data[0] << " total_send_data[-2]"
                << total_send_data[total_send_data_size - 2]
                << " total_send_data[-1]"
                << total_send_data[total_send_data_size - 1] << " merge_count "
                << merge_count;
      }
      std::shared_ptr<DenseAsyncTask> task_ptr(task);
      PushDenseRawGradient(
          task_ptr, total_send_data, total_send_data_size, closure);
    }
    // Sleep off the remainder of the configured polling interval.
    auto wait_ms = FLAGS_pserver_async_push_dense_interval_ms -
                   (butil::gettimeofday_ms() - async_start_time_ms);
    if (wait_ms > 0) {
      usleep(wait_ms * 1000);
    }
  }
}
// Sends the merged dense gradient to every pserver. `total_send_data` holds
// request_call_num contiguous slices of num_per_shard floats; server i
// receives slice i prefixed with the per-shard count.
// Payload layout per server: [num_per_shard: uint32_t][floats...].
// Completion is reported through `closure`.
void BrpcPsClient::PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task,
                                        float *total_send_data,
                                        size_t total_send_data_size,
                                        DownpourBrpcClosure *closure) {
  auto *accessor = GetTableAccessor(task->table_id());
  size_t request_call_num = _server_channels.size();
  // Copy the data into each request's buffer.
  auto timer = std::make_shared<CostTimer>("pserver_client_push_dense_rpc");
  closure->add_timer(timer);
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, request_call_num);
  auto send_timer =
      std::make_shared<CostTimer>("pserver_client_push_dense_send");
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE);
    closure->request(i)->set_table_id(task->table_id());
    closure->request(i)->set_client_id(_client_id);
    auto *push_data = closure->request(i)->mutable_data();
    push_data->clear();
    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float));
    char *push_data_ptr = const_cast<char *>(push_data->data());
    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
    memcpy(push_data_ptr + sizeof(uint32_t),
           total_send_data + i * num_per_shard,
           num_per_shard * sizeof(float));
    closure->cntl(i)->set_request_compress_type(
        (brpc::CompressType)FLAGS_pserver_communicate_compress_type);
    PsService_Stub rpc_stub(GetDenseChannel(i));
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace brpc {
class Channel;
class Controller;
} // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
struct Region;
// Client-side brpc service endpoint: allows this process's PSClient to
// receive RPCs sent by other clients (client-to-client messaging).
class DownpourPsClientService : public PsService {
 public:
  DownpourPsClientService() {}
  virtual ~DownpourPsClientService() {}
  // Binds this service to its owning client and records the client's rank.
  virtual int32_t Configure(PSClient *client, size_t rank_id) {
    _client = client;
    _rank = rank_id;
    return 0;
  }
  // brpc entry point for RPCs addressed to this client (defined in the .cc).
  void service(::google::protobuf::RpcController *controller,
               const PsRequestMessage *request,
               PsResponseMessage *response,
               ::google::protobuf::Closure *done) override;
 protected:
  size_t _rank;       // this client's rank id, set by Configure()
  PSClient *_client;  // non-owning back-pointer to the owning client
};
// Closure shared by the `num` parallel RPC sub-requests of one logical PS
// request. brpc calls Run() once per finished sub-request; the last finisher
// invokes the user callback and then deletes the closure, so the object must
// be heap-allocated and is self-owning after the requests are issued.
class DownpourBrpcClosure : public PSClientClosure {
 public:
  DownpourBrpcClosure(size_t num, PSClientCallBack callback)
      : PSClientClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourBrpcClosure() {}
  // fetch_sub returns the PREVIOUS value, so the thread that observes 1 is
  // the last finisher and owns the callback + cleanup.
  void Run() override {
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request's messages and brpc controller.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // Response validation helpers (defined in the .cc).
  int check_response(size_t request_idx, int cmd_id);
  int check_save_response(size_t request_idx, int cmd_id);
  std::string get_response(size_t request_idx, int cmd_id);
 private:
  std::atomic<int32_t> _waiting_num;  // sub-requests still outstanding
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// One shard's worth of sparse push data: parallel key/value lists plus the
// number of valid leading entries.
struct SharedSparsePushData {
  SharedSparsePushData() {}
  ~SharedSparsePushData() noexcept {}
  // Number of valid KV pairs. Zero-initialized so a freshly constructed (or
  // freshly pooled) object never reports a garbage count; previously this
  // member was left uninitialized.
  size_t kv_num = 0;
  std::vector<uint64_t> key_list;
  std::vector<std::string> value_list;
};
struct SparsePushTaskData {
  std::vector<SharedSparsePushData> shared_data;  // sparse data sharded by key hash
};
// Object pool for push-sparse payloads, avoiding repeated allocation of the
// per-shard vectors on every push. Thread-safe.
struct SparseTaskPool {
  // Returns a pooled payload if one is available, otherwise allocates anew.
  std::shared_ptr<SparsePushTaskData> get() {
    std::lock_guard<std::mutex> lock(_mutex);
    if (_pool.empty()) {
      return std::make_shared<SparsePushTaskData>();
    } else {
      auto ret = _pool.back();
      _pool.pop_back();
      return ret;
    }
  }
  // Returns a payload to the pool for later reuse.
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> lock(_mutex);
    _pool.push_back(std::move(data));
  }
  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;
};
// Deleter functor for buffers allocated with new[], used e.g. as the custom
// deleter of a std::shared_ptr that owns a raw char buffer.
template <class T>
struct array_deleter {
  void operator()(T *&ptr) const { delete[] ptr; }  // NOLINT
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
if (_running) {
Flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
_async_push_dense_thread.join();
}
if (_async_push_sparse_thread.joinable()) {
_async_push_sparse_thread.join();
}
if (_server_started) {
_server.Stop(1000);
_server.Join();
_server_started = false;
}
}
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Clear() override;
std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> StopServer() override;
std::future<int32_t> StartProfiler() override;
std::future<int32_t> StopProfiler() override;
void FinalizeWorker() override;
virtual std::future<int32_t> PullDense(Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id);
void PushDenseTaskConsume();
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training);
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training);
virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done);
virtual std::future<int32_t> Flush();
std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) override;
// for local save sparse
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);
std::future<int32_t> CacheShuffle(
uint32_t table_id,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold) override;
std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold);
std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) override;
std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) override;
void PrintQueueSize();
void PrintQueueSizeThread();
protected:
  // Number of pserver endpoints this client holds channel sets for.
  virtual size_t GetServerNums() { return _server_channels.size(); }
  // Channel reserved for sparse pull/push traffic to server `server_id`.
  inline brpc::Channel *GetSparseChannel(size_t server_id) {
    return _server_channels[server_id][0].get();
  }
  // Channel reserved for dense pull/push traffic to server `server_id`.
  inline brpc::Channel *GetDenseChannel(size_t server_id) {
    return _server_channels[server_id][1].get();
  }
  // Channel reserved for control commands (save/load/clear/...) to `server_id`.
  inline brpc::Channel *GetCmdChannel(size_t server_id) {
    return _server_channels[server_id][2].get();
  }
int32_t Initialize() override;
private:
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> SendCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> SendSaveCmd(uint32_t table_id,
int cmd_id,
const std::vector<std::string> &param);
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; // 异步请求计数
// 异步push dense task
std::thread _async_push_dense_thread;
typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
_push_dense_task_queue_map;
// 异步push sparse task
std::thread _async_push_sparse_thread;
typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
_push_sparse_task_queue_map;
std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
std::thread _print_thread;
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
int table_id,
int shard_idx, // NOLINT
ValueAccessor *accessor);
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
int table_id,
int shard_idx, // NOLINT
DownpourBrpcClosure *closure,
ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool;
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
std::future<int32_t> PushSparseRawGradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> PushSparseRawGradientPartial(size_t table_id,
const uint64_t *keys,
const float **update_values,
uint32_t num,
void *done,
int pserver_idx) override;
std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> PushSparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) override;
void PushSparseTaskConsume();
private:
int32_t StartClientService();
void PushDenseRawGradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "butil/object_pool.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
// Tunables for the pserver-to-pserver (s2s) channels opened by StartS2S();
// these are used for cache-shuffle traffic between servers.
DEFINE_int32(pserver_timeout_ms_s2s,
             10000,
             "pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s,
             10000,
             "pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s,
              "pooled",
              "pserver connection_type[pooled:single]");
namespace paddle {
namespace distributed {
// Builds the RPC service object for this pserver from the config: create the
// registered PsBaseService subclass named in service_param, configure and
// initialize it, then attach it to the (not yet started) brpc server.
// Returns 0 on success, -1 on any configuration/registration failure.
int32_t BrpcPsServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service takes ownership; brpc is told not to delete it on shutdown.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Starts the brpc server on ip:port, retrying once with a numeric-IP
// endpoint if the first bind fails. On success the pserver is registered in
// the environment and this call BLOCKS until Stop() sets stoped_ and
// notifies cv_, then returns this server's rank.
// NOTE(review): on double bind failure this returns 0, which is not
// distinguishable from a legitimate rank 0 — confirm callers handle this.
uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(0) << "running server with rank id: " << _rank
          << ", endpoint: " << ip_port;
  brpc::ServerOptions options;
  // Give brpc at least one worker thread per trainer, but never fewer than
  // the hardware concurrency.
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
            << " , Try Again.";
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    if (_server.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
      return 0;
    }
  }
  _environment->RegistePsServer(ip, port, _rank);
  // Park the caller here until Stop() is invoked.
  cv_.wait(lock, [&] { return stoped_; });
  PSHost host;
  host.ip = ip;
  host.port = port;
  host.rank = _rank;
  return host.rank;
}
int32_t BrpcPsServer::StartS2S() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
options.connection_type = FLAGS_pserver_connection_type_s2s;
options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
options.max_retry = 3;
std::vector<PSHost> pserver_list = _environment->GetPsServers();
_pserver_channels.resize(pserver_list.size());
VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size();
std::ostringstream os;
std::string server_ip_port;
for (size_t i = 0; i < pserver_list.size(); ++i) {
server_ip_port.assign(pserver_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(pserver_list[i].port));
_pserver_channels[i].reset(new brpc::Channel());
if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
<< " Failed!";
}
os << server_ip_port << ",";
}
LOG(INFO) << "pserver connect success: " << os.str();
return 0;
}
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
if (static_cast<size_t>(to_pserver_id) >= _pserver_channels.size()) {
LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
<< _pserver_channels.size();
promise->set_value(-1);
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done;
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
closure->add_promise(promise);
closure->request(0)->set_cmd_id(101);
closure->request(0)->set_client_id(_rank);
closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(
closure->cntl(0), closure->request(0), closure->response(0), closure);
return fut;
}
// Consumes an s2s payload: deserializes it into (key, serialized-value)
// pairs and appends them to the shared _shuffled_ins channel. An empty
// message or empty archive is treated as an end-of-stream marker, not an
// error. Always returns 0.
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type,
                                         int pserver_id,
                                         const std::string &msg) {
  if (msg.length() == 0) {
    LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
    return 0;
  }
  paddle::framework::BinaryArchive ar;
  // Read directly from the message buffer; the archive takes no ownership.
  ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
  if (ar.Cursor() == ar.Finish()) {
    LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
    return 0;
  }
  std::vector<std::pair<uint64_t, std::string>> data;
  while (ar.Cursor() < ar.Finish()) {
    data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
  }
  // The payload must parse exactly; a partial record indicates corruption.
  CHECK(ar.Cursor() == ar.Finish());
  this->_shuffled_ins->Write(std::move(data));
  return 0;
}
// Port the embedded brpc server is actually listening on.
int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
// Registers the cmd_id -> member-function dispatch table consumed by
// service(), sets up cost profilers for the four hot RPC paths, and
// propagates shard info to the tables. Always returns 0.
int32_t BrpcPsService::Initialize() {
  _is_initialize_shard_info = false;
  _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::StopServer;
  _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::PullDense;
  _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::PushDense;
  _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::PullSparse;
  _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::PushSparse;
  _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::SaveOneTable;
  _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::SaveAllTable;
  _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::ShrinkTable;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::LoadAllTable;
  _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::ClearOneTable;
  _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::ClearAllTable;
  _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::PushDenseParam;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::PrintTableStat;
  _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::PullGeoParam;
  _service_handler_map[PS_PUSH_SPARSE_PARAM] = &BrpcPsService::PushSparseParam;
  _service_handler_map[PS_BARRIER] = &BrpcPsService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
  // for save cache
  _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
      &BrpcPsService::SaveCacheTable;
  _service_handler_map[PS_GET_CACHE_THRESHOLD] =
      &BrpcPsService::GetCacheThreshold;
  _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
  auto &profiler = CostProfiler::instance();
  profiler.register_profiler("pserver_server_pull_dense");
  profiler.register_profiler("pserver_server_push_dense");
  profiler.register_profiler("pserver_server_pull_sparse");
  profiler.register_profiler("pserver_server_push_sparse");
  // Shard init: the server_list's shard info can only be fetched from the
  // env after the server has started.
  InitializeShardInfo();
  return 0;
}
// Shared guard for RPC handlers: when the table lookup failed, record an
// error code/message on `response` and return -1 from the enclosing handler.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Lazily pushes (rank, shard_num) into every table exactly once.
// The check -> lock -> re-check sequence keeps concurrent RPC handlers from
// running the initialization twice.
int32_t BrpcPsService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
    size_t shard_num = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Iterate by reference: the original by-value loop copied each map
    // entry (including a shared_ptr refcount bump) per table.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, shard_num);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// Top-level RPC entry point. cmd_id < 100 dispatches to the handler map
// built in Initialize(); cmd_id >= 100 is routed as a pserver2pserver
// message. Errors are reported via response err_code/err_msg, never thrown.
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
                            const PsRequestMessage *request,
                            PsResponseMessage *response,
                            google::protobuf::Closure *done) {
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    // Fixed typo in the client-visible error text ("tabel_id").
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }
  response->set_err_code(0);
  response->set_err_msg("");
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  if (request->cmd_id() < 100) {
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      set_response_code(*response, -1, err_msg.c_str());
      return;
    }
    serviceHandlerFunc handler_func = itr->second;
    int service_ret = (this->*handler_func)(table, *request, *response, cntl);
    // Any nonzero handler return is surfaced to the client as an error.
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  } else {
    int service_ret = _server->HandlePServer2PServerMsg(
        request->cmd_id(), request->client_id(), request->data());
    if (service_ret != 0) {
      response->set_err_code(-1);
      response->set_err_msg("handle_pserver2pserver_msg failed");
    }
  }
}
// PS_PULL_DENSE_TABLE handler.
// params(0) carries a raw uint32 `num` (count of dense values to pull);
// the pulled floats are appended to the response attachment as raw bytes.
int32_t BrpcPsService::PullDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1 for num of dense");
    return 0;
  }
  CostTimer timer("pserver_server_pull_dense");
  uint32_t num = *(const uint32_t *)request.params(0).c_str();
  // Pooled scratch vector; returned to the object pool after its contents
  // have been copied into the response attachment.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * table->ValueAccesor()->GetAccessorInfo().select_size /
                   sizeof(float));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = res_data->data();
  table_context.num = num;
  table->Pull(table_context);
  // table->PullDense(res_data->data(), num);
  cntl->response_attachment().append((char *)(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// PS_PUSH_DENSE_PARAM handler.
// Request attachment layout: |--uint32 num--|--num float values--|.
// Values are pushed as parameters (is_param = true), not gradients.
int32_t BrpcPsService::PushDenseParam(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDenseParam", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  // Per-thread scratch buffer reused across requests to avoid reallocations.
  thread_local std::string push_buffer;
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  push_buffer.resize(0);
  push_buffer.reserve(req_buffer_size);
  // fetch() yields a contiguous view of the attachment, copying into the
  // supplied buffer only when needed.
  // NOTE(review): this writes into reserved-but-unsized string storage;
  // confirm this pattern against the brpc IOBuf contract.
  const char *data = (const char *)cntl->request_attachment().fetch(
      const_cast<char *>(push_buffer.data()), req_buffer_size);
  uint32_t num = *(const uint32_t *)data;
  const float *values = (const float *)(data + sizeof(uint32_t));
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  // if (table->PushDenseParam(values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushDenseParam failed");
  }
  return 0;
}
// PS_PUSH_DENSE_TABLE handler.
// request.data() layout: |--uint32 num--|--num float gradients--|.
// An empty payload is silently accepted as a no-op.
int32_t BrpcPsService::PushDense(Table *table,
                                 const PsRequestMessage &request,
                                 PsResponseMessage &response,
                                 brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushDense", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    // set_response_code(response, 0, "push dense data is empty");
    return 0;
  }
  CostTimer timer("pserver_server_push_dense");
  /*
  Push Content:
  |--num--|---valuesData---|
  |--4B---|----------------|
  */
  uint32_t num = *(const uint32_t *)(request.data().data());
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values =
      (const float *)(request.data().data() + sizeof(uint32_t));
  table_context.num = num;
  // const float *values = (const float *)(request.data().data() +
  // sizeof(uint32_t));
  if (table->Push(table_context) != 0) {
    // if (table->PushDense(values, num) != 0) {
    set_response_code(response, -1, "PushDense failed");
  }
  return 0;
}
// PS_BARRIER handler: forwards (trainer_id, barrier_type) to the table.
// params(0) carries the barrier type as a string.
int32_t BrpcPsService::Barrier(Table *table,
                               const PsRequestMessage &request,
                               PsResponseMessage &response,
                               brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  auto trainer_id = request.client_id();
  auto barrier_type = request.params(0);
  table->Barrier(trainer_id, barrier_type);
  return 0;
}
// PS_PUSH_SPARSE_PARAM handler.
// params(0) = raw uint32 num of keys; data() = |num uint64 keys|values|.
// Values are pushed as parameters (is_param = true); empty data is a no-op.
int32_t BrpcPsService::PushSparseParam(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::RecordEvent record_event("PsService->PushSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  const uint64_t *keys = (const uint64_t *)push_data.data();
  const float *values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.values = values;
  table_context.push_context.is_param = true;
  table_context.num = num;
  // if (table->PushSparseParam(keys, values, num) != 0) {
  if (table->Push(table_context) != 0) {
    set_response_code(response, -1, "PushSparseParam error");
  }
  return 0;
}
// PS_PULL_GEO_PARAM handler (geo mode): pulls the update accumulated for
// this trainer and replies |uint32 num|num uint64 ids|float values| packed
// into the response attachment.
int32_t BrpcPsService::PullGeoParam(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->pull_geo_param", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  thread_local std::string push_sparse_request_buffer;
  auto trainer_id = request.client_id();
  std::vector<float> values;
  std::vector<uint64_t> ids;
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.geo_pull_keys = &ids;
  table_context.pull_context.geo_pull_values = &values;
  table_context.trainer_id = trainer_id;
  table->Pull(table_context);
  // table->PullGeoParam(trainer_id, &values, &ids);
  uint32_t num = ids.size();
  cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
  cntl->response_attachment().append((char *)ids.data(),
                                     ids.size() * sizeof(uint64_t));
  cntl->response_attachment().append((char *)values.data(),
                                     values.size() * sizeof(float));
  return 0;
}
// PS_PULL_SPARSE_TABLE handler.
// params(0) = raw uint32 num of keys; the serialized PullSparseValue lives
// in the request attachment. Replies num * select_dim floats appended to
// the response attachment.
int32_t BrpcPsService::PullSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PullSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &req_io_buffer = cntl->request_attachment();
  auto req_buffer_size = req_io_buffer.size();
  if (req_buffer_size < 1) {
    set_response_code(response, -1, "req attachment is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_pull_sparse");
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
  // Per-thread scratch used as fetch() destination, reused across requests.
  thread_local std::string req_buffer;
  req_buffer.reserve(req_buffer_size);
  const void *data = cntl->request_attachment().fetch(
      const_cast<char *>(req_buffer.data()), req_buffer_size);
  auto value = PullSparseValue(num, dim);
  value.DeserializeFromBytes(const_cast<void *>(data));
  // Pooled result buffer, returned to the pool after the copy below.
  auto res_data = butil::get_object<std::vector<float>>();
  res_data->resize(num * dim);
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.pull_value = value;
  table_context.pull_context.values = res_data->data();
  table->Pull(table_context);
  // table->PullSparse(res_data->data(), value);
  cntl->response_attachment().append((char *)(res_data->data()),
                                     res_data->size() * sizeof(float));
  butil::return_object(res_data);
  return 0;
}
// PS_PUSH_SPARSE_TABLE handler.
// params(0) = raw uint32 num of keys; data() = |num uint64 keys|values|.
// Empty data is a no-op.
int32_t BrpcPsService::PushSparse(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  platform::RecordEvent record_event(
      "PsService->PushSparse", platform::TracerEventType::Communication, 1);
  CHECK_TABLE_EXIST(table, request, response)
  auto &push_data = request.data();
  if (push_data.size() < 1) {
    // set_response_code(response, 0, "push sparse data is empty");
    return 0;
  }
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at "
                      "least 1 for num of sparse_key");
    return 0;
  }
  CostTimer timer("pserver_server_push_sparse");
  uint32_t num = *(uint32_t *)(request.params(0).c_str());
  /*
  Push Content:
  |---keysData---|---valuesData---|
  |---8*{num}B---|----------------|
  */
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = (const uint64_t *)push_data.data();
  table_context.push_context.values =
      (const float *)(push_data.data() + sizeof(uint64_t) * num);
  table_context.num = num;
  // const uint64_t *keys = (const uint64_t *)push_data.data();
  // const float *values = (const float *)(push_data.data() + sizeof(uint64_t) *
  // num);
  if (table->Push(table_context) != 0) {
    // if (table->PushSparse(keys, values, num) != 0) {
    set_response_code(response, -1, "PushSparse error");
  }
  return 0;
}
// PS_PRINT_TABLE_STAT handler: serializes the table's (first, second) stat
// pair into a BinaryArchive and returns it in response.data().
int32_t BrpcPsService::PrintTableStat(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  std::pair<int64_t, int64_t> ret = table->PrintTableStat();
  paddle::framework::BinaryArchive ar;
  ar << ret.first << ret.second;
  std::string table_info(ar.Buffer(), ar.Length());
  response.set_data(table_info);
  return 0;
}
// PS_LOAD_ONE_TABLE handler: params(0) = path, params(1) = load mode.
// Unlike most handlers here, failures return -1, which service() then
// reports to the client as err_code.
int32_t BrpcPsService::LoadOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  if (table->Load(request.params(0), request.params(1)) != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
// PS_LOAD_ALL_TABLE handler: applies LoadOneTable to every registered
// table with the same request args; stops at the first failure.
int32_t BrpcPsService::LoadAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto it = table_map.begin(); it != table_map.end(); ++it) {
    if (LoadOneTable(it->second.get(), request, response, cntl) != 0) {
      LOG(ERROR) << "load table[" << it->first << "] failed";
      return -1;
    }
  }
  return 0;
}
// PS_SAVE_ONE_TABLE handler: flushes, then saves the table.
// params(0) = path, params(1) = mode. Returns the saved feasign count
// (>= 0) on success, -1 on failure.
// NOTE(review): a positive feasign count propagates to service(), which
// treats any nonzero return as err_code — confirm clients expect this.
int32_t BrpcPsService::SaveOneTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  // Drain pending updates before writing to disk.
  table->Flush();
  int32_t feasign_size = 0;
  VLOG(3) << "save table " << request.params(0) << " " << request.params(1);
  feasign_size = table->Save(request.params(0), request.params(1));
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
// PS_SAVE_ALL_TABLE handler: saves every registered table with the same
// path/mode; stops at the first failing table.
int32_t BrpcPsService::SaveAllTable(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  int32_t feasign_size = 0;
  for (auto it = table_map.begin(); it != table_map.end(); ++it) {
    feasign_size = SaveOneTable(it->second.get(), request, response, cntl);
    if (feasign_size < 0) {
      LOG(ERROR) << "save table[" << it->first << "] failed";
      return -1;
    }
  }
  return 0;
}
// PS_SAVE_ONE_CACHE_TABLE handler: flushes the table, then writes the cache
// entries collected in _server->_shuffled_ins to params(0)=path with
// params(1)=mode. Returns the saved feasign count, or -1 on failure.
int32_t BrpcPsService::SaveCacheTable(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    // The error text previously claimed "at least 3" while the check and
    // the code below use exactly two params (path & mode).
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2, path&mode");
    return -1;
  }
  table->Flush();
  int32_t feasign_size = 0;
  // if (_server->_shuffled_ins->size() <= 0) {
  //  LOG(WARNING) << "shuffled ins size <= 0";
  //}
  feasign_size = table->SaveCache(
      request.params(0), request.params(1), _server->_shuffled_ins);
  if (feasign_size < 0) {
    set_response_code(response, -1, "table save failed");
    return -1;
  }
  return feasign_size;
}
// PS_CACHE_SHUFFLE handler.
// params: 0 = path, 1 = mode, 2 = cache_threshold, [3..] = optional extra
// table ids to shuffle together. Opens the s2s channels and hands the table
// a send function so it can redistribute entries across pservers; received
// entries arrive via ReceiveFromPServer into _server->_shuffled_ins.
int32_t BrpcPsService::CacheShuffle(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  // start cache shuffle
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.datas is requeired at least 3, "
                      "path&mode&cache_threshold");
    return -1;
  }
  table->Flush();
  double cache_threshold = std::stod(request.params(2));
  LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
  // auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
  // std::string>>();
  // shuffled_ins->set_block_size(80000);
  _server->StartS2S();
  // Bridge for the table to send shuffled entries to peer pservers.
  std::function<std::future<int32_t>(
      int msg_type, int to_pserver_id, const std::string &msg)>
      send_msg_func = [this](int msg_type,
                             int to_pserver_id,
                             const std::string &msg) -> std::future<int32_t> {
    return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };
  std::vector<Table *> table_ptrs;
  for (int i = 3; i < request.params_size(); ++i) {
    int table_id = std::stoi(request.params(i));
    Table *table_ptr = _server->GetTable(table_id);
    table_ptrs.push_back(table_ptr);
  }
  // Fall back to shuffling only the addressed table.
  if (table_ptrs.empty()) {
    table_ptrs.push_back(table);
  }
  table->CacheShuffle(request.params(0),
                      request.params(1),
                      cache_threshold,
                      send_msg_func,
                      _server->_shuffled_ins,
                      table_ptrs);
  return 0;
}
// PS_GET_CACHE_THRESHOLD handler: flushes, then returns the table's cache
// threshold serialized as a decimal string (15 significant digits) in
// response.data(). A negative threshold is warned about but not failed.
int32_t BrpcPsService::GetCacheThreshold(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  double cache_threshold = 0.0;
  cache_threshold = table->GetCacheThreshold();
  if (cache_threshold < 0) {
    LOG(WARNING) << "wrong threshold: " << cache_threshold;
  }
  std::stringstream ss;
  ss << std::setprecision(15) << cache_threshold;
  std::string cache_threshold_str = ss.str();
  response.set_data(cache_threshold_str);
  return 0;
}
// PS_SHRINK_TABLE handler: params(0) = shrink threshold (string, passed
// through to the table). Flushes first; returns -1 when the table rejects
// the shrink.
int32_t BrpcPsService::ShrinkTable(Table *table,
                                   const PsRequestMessage &request,
                                   PsResponseMessage &response,
                                   brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 1, threshold");
    return -1;
  }
  table->Flush();
  if (table->Shrink(request.params(0)) != 0) {
    set_response_code(response, -1, "table shrink failed");
    return -1;
  }
  VLOG(3) << "Pserver Shrink Finished";
  return 0;
}
// PS_CLEAR_ONE_TABLE handler: drain pending updates, then wipe the table.
int32_t BrpcPsService::ClearOneTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  table->Flush();
  table->Clear();
  return 0;
}
// PS_CLEAR_ALL_TABLE handler: clears every registered table; bails out on
// the first failure.
int32_t BrpcPsService::ClearAllTable(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  auto &table_map = *(_server->GetTable());
  for (auto it = table_map.begin(); it != table_map.end(); ++it) {
    if (ClearOneTable(it->second.get(), request, response, cntl) != 0) {
      return -1;
    }
  }
  return 0;
}
// PS_STOP_SERVER handler: stops the server from a detached thread so this
// RPC can be answered before the brpc server begins tearing itself down.
int32_t BrpcPsService::StopServer(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  auto *p_server = _server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    VLOG(3) << "Server Stoped";
  });
  t_stop.detach();
  return 0;
}
// PS_STOP_PROFILER handler: stops profiling and dumps results to
// "server_<rank>_profile".
// NOTE(review): "%s" is used with the integer _rank; this appears to rely
// on paddle's Sprintf accepting any streamable type for %s — confirm, or
// use "%d" to state intent.
int32_t BrpcPsService::StopProfiler(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl) {
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// PS_START_PROFILER handler: enables CPU-state profiling on this server.
int32_t BrpcPsService::StartProfiler(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
// PS_PUSH_GLOBAL_STEP handler: forwards the trainer's int64 step counters
// to the table. data() layout: a 4-byte header is skipped before the int64
// values; the header itself is not read here.
int32_t BrpcPsService::PushGlobalStep(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response);
  auto req_buffer_size = request.data().size();
  if (req_buffer_size < 1) {
    set_response_code(response, 0, "run_program data is empty");
    return 0;
  }
  const int64_t *values =
      (const int64_t *)(request.data().data() + sizeof(uint32_t));
  auto trainer_id = request.client_id();
  TableContext context;
  context.trainer_id = trainer_id;
  context.push_context.push_steps = values;
  // if (table->PushDense(values, trainer_id) != 0) {
  if (table->Push(context) != 0) {
    set_response_code(response, -1, "run_program failed");
  }
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace brpc {
class Controller;
} // namespace brpc
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
class PsRequestMessage;
class PsResponseMessage;
class Table;
// brpc-based parameter-server implementation.
// Start() blocks until Stop() is called (coordinated via mutex_/cv_).
// _pserver_channels are lazily-created server<->server links (StartS2S)
// used for cache-shuffle traffic.
class BrpcPsServer : public PSServer {
 public:
  BrpcPsServer() {}
  virtual ~BrpcPsServer() {}
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Wakes the thread blocked in Start(), then shuts the brpc server down.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    stoped_ = true;
    cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  int32_t Port();
  // --- server<->server messaging (used by cache shuffle) ---
  virtual int32_t StartS2S() override;
  virtual ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string &msg) override;
  virtual int32_t ReceiveFromPServer(int msg_type,
                                     int pserver_id,
                                     const std::string &msg) override;
 private:
  virtual int32_t Initialize();
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool stoped_ = false;  // set by Stop(); releases the wait in Start()
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;  // s2s links
};
class BrpcPsService;
// Pointer-to-member type for BrpcPsService request handlers; each PS command
// id is dispatched to one of these (see _service_handler_map below).
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl);
// The RPC service exposed by BrpcPsServer. service() decodes the command id
// from each PsRequestMessage and dispatches to one of the private handlers
// below via _service_handler_map.
class BrpcPsService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;
  // brpc entry point: routes a request to the matching handler and fills
  // the response; `done` is invoked when handling completes.
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;
 private:
  int32_t InitializeShardInfo();
  // --- dense / sparse pull & push handlers ---
  int32_t PullDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,
                    brpc::Controller *cntl);
  int32_t PushDense(Table *table,
                    const PsRequestMessage &request,
                    PsResponseMessage &response,
                    brpc::Controller *cntl);
  int32_t PushDenseParam(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t PushSparseParam(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,
                          brpc::Controller *cntl);
  int32_t PullSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  int32_t PullGeoParam(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,
                  brpc::Controller *cntl);
  int32_t PushSparse(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  // --- table persistence / maintenance handlers ---
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t SaveOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t SaveAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t ShrinkTable(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,
                      brpc::Controller *cntl);
  int32_t ClearOneTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  int32_t ClearAllTable(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  // --- server control / profiling handlers ---
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t PushGlobalStep(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  // --- cache handlers ---
  int32_t CacheShuffle(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t SaveCacheTable(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t GetCacheThreshold(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;  // guards InitializeShardInfo()
  // cmd id -> handler; populated in Initialize().
  std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
};
// Closure that aggregates `num` parallel brpc calls issued by a pserver.
// Each completed call invokes Run(); the last one to finish fires the
// user callback and then deletes the closure (self-owning, heap-allocated).
class DownpourPServerBrpcClosure : public PServerClosure {
 public:
  DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
      : PServerClosure(callback) {
    _waiting_num = num;
    _cntls.resize(num);
    _requests.resize(num);
    _responses.resize(num);
    for (size_t i = 0; i < num; ++i) {
      _cntls[i].reset(new brpc::Controller());
    }
  }
  virtual ~DownpourPServerBrpcClosure() {}
  virtual void Run() override {
    // fetch_sub returns the previous value, so exactly one caller — the
    // final completion — sees 1, runs the callback, and frees `this`.
    if (_waiting_num.fetch_sub(1) == 1) {
      _callback(this);
      delete this;
    }
  }
  // Accessors for the i-th sub-request's message/response/controller.
  PsRequestMessage *request(size_t i) { return &_requests[i]; }
  PsResponseMessage *response(size_t i) { return &_responses[i]; }
  brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
  // NOTE(review): both checks unconditionally return 1 — presumably
  // placeholders; verify callers treat nonzero as "not an error" here.
  int check_response(size_t request_idx, int cmd_id) { return 1; }
  int check_save_response(size_t request_idx, int cmd_id) { return 1; }
 private:
  std::atomic<int32_t> _waiting_num;  // outstanding sub-calls
  std::vector<PsRequestMessage> _requests;
  std::vector<PsResponseMessage> _responses;
  std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include <arpa/inet.h>
#include <netdb.h>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class Variable;
} // namespace framework
} // namespace paddle
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace distributed {
// Translate the wire-format dtype enum into the framework's proto dtype.
// Throws InvalidArgument for any type not listed below.
framework::proto::VarType::Type VarMessageToVarType(
    VariableMessage::Type type) {
  if (type == VariableMessage::FP32) {
    return framework::proto::VarType::FP32;
  }
  if (type == VariableMessage::FP64) {
    return framework::proto::VarType::FP64;
  }
  if (type == VariableMessage::INT32) {
    return framework::proto::VarType::INT32;
  }
  if (type == VariableMessage::INT64) {
    return framework::proto::VarType::INT64;
  }
  if (type == VariableMessage::BOOL) {
    return framework::proto::VarType::BOOL;
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "VarMessageToVarType:Unsupported type %d", type));
}
// Build a MultiVarMsg request: message name, send/recv variable name lists,
// then one VarMessage (metadata) per send-variable, with the raw tensor
// payloads concatenated into `iobuf` in the same order.
void SerializeToMultiVarMsgAndIOBuf(
    const std::string& message_name,
    const std::vector<std::string>& send_var_name_val,
    const std::vector<std::string>& recv_var_name_val,
    const platform::DeviceContext& ctx,
    const framework::Scope* scope,
    MultiVarMsg* request,
    butil::IOBuf* iobuf) {
  // 1. message_name
  request->set_message_name(message_name);
  // 2. var_names
  for (const auto& name : send_var_name_val) {
    request->add_send_var_names(name);
  }
  for (const auto& name : recv_var_name_val) {
    request->add_recv_var_names(name);
  }
  // 3. one VarMessage + payload per variable to send
  for (const auto& name : send_var_name_val) {
    auto* msg = request->add_var_messages();
    msg->set_varname(name);
    framework::Variable* var = scope->FindVar(name);
    butil::IOBuf var_payload;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, msg, &var_payload);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, msg, &var_payload);
    }
    iobuf->append(var_payload);
  }
}
// Serialize a LoDTensor variable: metadata (type, lod, dtype, dims) goes into
// `var_msg`; the raw buffer is framed into `iobuf` as an 8-byte length
// followed by the tensor bytes. GPU tensors are staged through a host copy.
void SerializeLodTensor(framework::Variable* var,
                        const platform::DeviceContext& ctx,
                        VarMsg* var_msg,
                        butil::IOBuf* iobuf) {
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  var_msg->set_type(::paddle::distributed::LOD_TENSOR);
  const framework::LoD lod = tensor->lod();
  if (lod.size() > 0) {
    var_msg->set_lod_level(lod.size());
    for (auto& each : lod) {
      VarMsg::LodData* lod_inner = var_msg->add_lod();
      for (auto& d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto& dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // Fixed 8-byte length prefix; the deserializer reads exactly 8 bytes.
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: copy to a temporary host buffer first, then frame it.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Serialize a SelectedRows variable: the row-id list is packed into the
// message's `data` field; the value tensor is framed into `iobuf` as an
// 8-byte length followed by the raw bytes (host-staged for GPU tensors).
void SerializeSelectedRows(framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           VarMsg* var_msg,
                           butil::IOBuf* iobuf) {
  phi::SelectedRows* slr = var->GetMutable<phi::SelectedRows>();
  auto* tensor = slr->mutable_value();
  auto* rows = slr->mutable_rows();
  var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
  var_msg->set_slr_height(slr->height());
  // Row ids travel inside the proto message, not the IOBuf payload.
  auto* var_data = var_msg->mutable_data();
  var_data->clear();
  var_data->resize(rows->size() * sizeof(int64_t));
  char* data_ptr = const_cast<char*>(var_data->data());
  memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
  var_msg->set_data_type(static_cast<VarMsg::Type>(
      framework::TransToProtoVarType(tensor->dtype())));
  for (auto& dim : phi::vectorize(tensor->dims())) {
    var_msg->add_dims(dim);
  }
  // IO Buffer
  if (platform::is_cpu_place(tensor->place())) {
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    // Fixed 8-byte length prefix; matched by DeserializeSelectedRows.
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(tensor->data()), data_len);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Device tensor: stage through a temporary host buffer.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(
        platform::CPUPlace(),
        temp_ptr,
        tensor->place(),
        tensor->data(),
        tensor->numel() * framework::SizeOfType(
                              framework::TransToProtoVarType(tensor->dtype())),
        stream);
    auto data_len = tensor->numel() * framework::DataTypeSize(tensor->dtype());
    iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
    iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
    delete[] temp_ptr;
#endif
  }
}
// Server-side deserialization: creates each incoming variable in `scope`
// and decodes its payload from the shared IOBuf cursor in message order.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        framework::Scope* scope) {
  butil::IOBufBytesIterator payload_itr(*iobuf);
  const int num_vars = multi_msg.send_var_names_size();
  for (int idx = 0; idx < num_vars; ++idx) {
    const auto& var_msg = multi_msg.var_messages(idx);
    auto* var = scope->Var(var_msg.varname());
    if (var_msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, var_msg, payload_itr, ctx);
    } else if (var_msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, var_msg, payload_itr, ctx);
    }
  }
}
// Client-side deserialization: looks up pre-existing variables in the
// (const) scope — it must not create them — and decodes payloads in order.
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        const framework::Scope* scope) {
  butil::IOBufBytesIterator payload_itr(*iobuf);
  const int num_vars = multi_msg.send_var_names_size();
  for (int idx = 0; idx < num_vars; ++idx) {
    const auto& var_msg = multi_msg.var_messages(idx);
    auto* var = scope->FindVar(var_msg.varname());
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable %s in scope.", var_msg.varname()));
    if (var_msg.type() == ::paddle::distributed::LOD_TENSOR) {
      DeserializeLodTensor(var, var_msg, payload_itr, ctx);
    } else if (var_msg.type() == ::paddle::distributed::SELECTED_ROWS) {
      DeserializeSelectedRows(var, var_msg, payload_itr, ctx);
    }
  }
}
// Rebuild a LoDTensor from its VarMsg metadata (dims, lod, dtype) and the
// payload stream: an 8-byte length prefix followed by the raw tensor bytes.
// GPU targets stage the bytes through a temporary host buffer.
void DeserializeLodTensor(framework::Variable* var,
                          const VarMsg& msg,
                          butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
                          const platform::DeviceContext& ctx) {
  const auto place = ctx.GetPlace();
  framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<int> vec_dim;
  for (auto& x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  framework::LoD lod;
  for (int i = 0; i < msg.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
      v.push_back(msg.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);
  void* tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    // NOTE(review): the serializer writes an 8-byte length, but `unsigned
    // long` is only 4 bytes on some ABIs (e.g. 32-bit / Windows LLP64) —
    // presumably only LP64 Linux is supported here; confirm.
    unsigned long data_len;                                 // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    // GPU target: read bytes into a host buffer, then async-copy to device.
    unsigned long data_len;  // NOLINT
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);   // NOLINT
    io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);  // NOLINT
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 (void*)temp_ptr,  // NOLINT
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
// Rebuild a SelectedRows from its VarMsg: row ids come from the message's
// `data` field (dims[0] int64 entries), value-tensor bytes come from the
// 8-byte-length-prefixed payload stream (host-staged for GPU targets).
void DeserializeSelectedRows(
    framework::Variable* var,
    const VarMsg& msg,
    butil::IOBufBytesIterator& io_buffer_itr,  // NOLINT
    const platform::DeviceContext& ctx) {
  const auto place = ctx.GetPlace();
  auto* slr = var->GetMutable<phi::SelectedRows>();
  framework::Tensor* tensor = slr->mutable_value();
  slr->set_height(msg.slr_height());
  // dims[0] is the row count; the serializer packed exactly that many ids.
  std::vector<int64_t> tmp_rows(msg.dims()[0]);
  memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t));
  slr->set_rows(tmp_rows);
  std::vector<int> vec_dim;
  for (auto& x : msg.dims()) {
    vec_dim.push_back(x);
  }
  tensor->Resize(phi::make_ddim(vec_dim));
  void* tensor_data = tensor->mutable_data(
      place,
      framework::TransToPhiDataType(VarMessageToVarType(msg.data_type())));
  // IO Buffer
  if (platform::is_cpu_place(place)) {
    // NOTE(review): 8 bytes are read into `unsigned long`, which is 4 bytes
    // on some ABIs — presumably only LP64 platforms are supported; confirm.
    unsigned long data_len;                                 // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);  // NOLINT
    io_buffer_itr.copy_and_forward(tensor_data, data_len);
  } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    // GPU target: read into a host buffer, then async-copy to the device.
    char* temp_ptr =
        new char[tensor->numel() *
                 framework::DataTypeSize(tensor->dtype())];  // NOLINT
    unsigned long data_len;                                  // NOLINT
    io_buffer_itr.copy_and_forward((void*)(&data_len), 8);   // NOLINT
    io_buffer_itr.copy_and_forward(temp_ptr, data_len);
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(place,
                 tensor_data,
                 platform::CPUPlace(),
                 temp_ptr,
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    delete[] temp_ptr;
#endif
  }
}
// Resolve `ip` (hostname or dotted-quad) to a numeric "a.b.c.d:port" string.
// Falls back to the original "ip:port" if resolution fails or yields no
// address, instead of crashing.
// @param ip   hostname or IP string to resolve
// @param port port number appended to the resolved address
// @return numeric endpoint string, or the original "ip:port" on failure
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port) {
  // There are usually two forms of IP address: ip(int) / ip (hostname)
  // If there're some problem with DNS, or ip triggers the bug of Brpc
  // We will try to get the IP address of the domain name manually again
  std::string ip_port = ip + ":" + std::to_string(port);
  struct hostent* hp = gethostbyname(ip.c_str());
  if (NULL == hp) {
    LOG(ERROR) << "Brpc Start failed, ip_port= " << ip_port
               << " , Error infomation: " << hstrerror(h_errno);
    // BUG FIX: the original fell through and dereferenced the null `hp`
    // below, crashing on any DNS failure. Return the caller's endpoint.
    return ip_port;
  }
  char* int_ip = NULL;
  for (int i = 0; hp->h_addr_list[i] != NULL; ++i) {
    int_ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
    VLOG(3) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
    break;  // only the first address is used
  }
  if (int_ip == NULL) {
    // Empty address list: constructing std::string from NULL is UB, so
    // keep the caller-supplied endpoint instead.
    return ip_port;
  }
  std::string str_ip = int_ip;
  std::string int_ip_port = str_ip + ":" + std::to_string(port);
  return int_ip_port;
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <netdb.h>
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/phi/backends/dynload/port.h"
namespace butil {
class IOBuf;
class IOBufBytesIterator;
} // namespace butil
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
// Serialize the named variables from `scope` into a MultiVarMsg (metadata)
// plus an IOBuf carrying the raw tensor payloads.
void SerializeToMultiVarMsgAndIOBuf(
    const std::string& message_name,
    const std::vector<std::string>& send_var_name_val,
    const std::vector<std::string>& recv_var_name_val,
    const platform::DeviceContext& ctx,
    const framework::Scope* scope,
    MultiVarMsg* var_msg,
    butil::IOBuf* iobuf);
// Serialize a single LoDTensor variable (metadata -> var_msg, bytes -> iobuf).
void SerializeLodTensor(framework::Variable* var,
                        const platform::DeviceContext& ctx,
                        VarMsg* var_msg,
                        butil::IOBuf* iobuf);
// Serialize a single SelectedRows variable (rows -> message, bytes -> iobuf).
void SerializeSelectedRows(framework::Variable* var,
                           const platform::DeviceContext& ctx,
                           VarMsg* request,
                           butil::IOBuf* iobuf);
// Deserialize for Server
// (non-const scope: variables are created on demand)
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        framework::Scope* scope);
// Deserialize for Client
// (const scope: variables must already exist; missing ones raise an error)
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
                                        const butil::IOBuf* iobuf,
                                        const platform::DeviceContext& ctx,
                                        const framework::Scope* scope);
// Decode one LoDTensor payload from the shared IOBuf cursor.
void DeserializeLodTensor(framework::Variable* var,
                          const VarMsg& msg,
                          butil::IOBufBytesIterator& iobuf,  // NOLINT
                          const platform::DeviceContext& ctx);
// Decode one SelectedRows payload from the shared IOBuf cursor.
void DeserializeSelectedRows(framework::Variable* var,
                             const VarMsg& msg,
                             butil::IOBufBytesIterator& iobuf,  // NOLINT
                             const platform::DeviceContext& ctx);
// Resolve a hostname endpoint to its numeric "ip:port" form.
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port);
} // namespace distributed
} // namespace paddle
# Build rules for the communicator library.
# RPC_DEPS is populated by the ps/service CMake files.
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
# communicator.cc needs the distributed-specific warning/define flags.
set_source_files_properties(
  communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
  communicator
  SRCS communicator.cc
  DEPS scope
       client
       boost
       table
       math_function
       selected_rows_functor
       ${RPC_DEPS})
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/wrapper/fleet.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using phi::SelectedRows;
const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
// Current wall-clock time in microseconds since the epoch, as a double
// (seconds * 1e6 + microseconds).
inline double GetCurrentUS() {
  struct timeval tv;
  gettimeofday(&tv, nullptr);
  return static_cast<double>(tv.tv_sec) * 1e+6 +
         static_cast<double>(tv.tv_usec);
}
Communicator::Communicator() {}
// Parse a space-separated gflags string and apply it to the process-wide
// gflags state. If no flags are given, a default set of brpc tuning flags
// is used. A dummy program name ("exe default") is prepended because
// ParseCommandLineFlags expects argv[0] to be the executable name.
// @param gflags space-separated flag string, e.g. "-max_body_size=..."
void Communicator::InitGFlag(const std::string &gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.size() < 1) {
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-bthread_concurrency=40");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  auto it = flags.begin();
  flags.insert(it, "exe default");
  // BUG FIX: the original used a variable-length array (`char
  // *flags_ptr[flags.size()]`), a non-standard GCC extension; use a
  // std::vector instead. The c_str() pointers stay valid because `flags`
  // outlives the ParseCommandLineFlags call.
  std::vector<char *> flags_ptr;
  flags_ptr.reserve(flags.size());
  for (auto &flag : flags) {
    flags_ptr.push_back(const_cast<char *>(flag.c_str()));
  }
  int params_cnt = static_cast<int>(flags_ptr.size());
  char **params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// Singleton bookkeeping: init_flag_ guards one-time initialization;
// communicator_ holds the process-wide Communicator instance.
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
// Adopt the PS client owned by the FleetWrapper singleton as this
// communicator's worker, if one has not been set already.
// NOTE: dist_desc / host_sign_list are currently unused; the fleet wrapper
// is assumed to have been initialized elsewhere.
void Communicator::InitBrpcClient(
    const std::string &dist_desc,
    const std::vector<std::string> &host_sign_list) {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();
  if (!_worker_ptr) {
    _worker_ptr = fleet->worker_ptr_;
  }
}
// Fetch the registered client signatures from the PS environment,
// logging each one at verbosity 2.
std::vector<uint64_t> Communicator::GetClientInfo() {
  std::vector<uint64_t> infos = _ps_env.GetClientInfo();
  for (const auto &info : infos) {
    VLOG(2) << "Communicator::GetClientInfo " << info;
  }
  return infos;
}
// Register every client signature in `host_sign_list` with the PS
// environment; returns the environment's status code.
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
  const int node_num = static_cast<int>(host_sign_list.size());
  return _ps_env.SetPsClients(host_sign_list.data(), node_num);
}
// Pull the dense parameters in `varnames` from server table `table_id` into
// `scope`. GPU-resident tensors are pulled into host staging tensors in
// xpu_temp_scope_ first, then copied to the device.
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
                                int table_id,
                                Scope *scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  std::vector<paddle::distributed::Region> regions;
  regions.reserve(varnames.size());
  // Build one Region per variable, pointing at the buffer PullDense fills.
  for (auto &t : varnames) {
    Variable *var = scope->Var(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Stage into a same-shaped CPU tensor; device copy happens below.
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(tensor->place());
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
    }
  }
  // Blocking pull: wait until all regions are filled.
  auto status =
      _worker_ptr->PullDense(regions.data(), regions.size(), table_id);
  status.wait();
  // Copy staged host data onto the original device tensors.
  for (auto &t : varnames) {
    Variable *var = scope->FindVar(t);
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
            << platform::is_gpu_place(tensor->place());
    float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
    VLOG(3) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
            << table_id << " Temp_data[0] " << temp_recv_data[0]
            << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      LoDTensor *temp_tensor =
          xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
      framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
              << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    }
  }
  return;
}
// Push the dense parameters in `varnames` (read from `scope`) to server
// table `table_id`. GPU-resident tensors are copied into host staging
// tensors in xpu_temp_scope_ before the push. Blocks until the push is done.
void Communicator::RpcSendDenseParam(const std::vector<std::string> &varnames,
                                     int table_id,
                                     const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDenseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  auto place = platform::CPUPlace();
  std::vector<paddle::distributed::Region> regions;
  for (auto &t : varnames) {
    Variable *var = scope.FindVar(t);
    CHECK(var != nullptr) << "var[" << t << "] not found";
    LoDTensor *tensor = var->GetMutable<LoDTensor>();
    if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
      // Stage a host copy of the device tensor for the RPC layer.
      Variable *temp_var = xpu_temp_scope_->Var(t);
      LoDTensor *temp_tensor = temp_var->GetMutable<LoDTensor>();
      temp_tensor->Resize(tensor->dims());
      float *temp_data = temp_tensor->mutable_data<float>(platform::CPUPlace());
      framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor);
      paddle::distributed::Region reg(temp_data, tensor->numel());
      regions.emplace_back(std::move(reg));
      VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
              << " table_id " << table_id << " Temp_data[0] " << temp_data[0]
              << " Temp_data[-1] " << temp_data[tensor->numel() - 1];
#endif
    } else {
      float *w = tensor->mutable_data<float>(place);
      paddle::distributed::Region reg(w, tensor->numel());
      regions.emplace_back(std::move(reg));
      // BUG FIX: log said " talbe_id " (typo), inconsistent with every
      // sibling log line; corrected to " table_id ".
      VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t
              << " table_id " << table_id << " Temp_data[0] " << w[0]
              << " Temp_data[-1] " << w[tensor->numel() - 1];
    }
  }
  // Blocking push of all regions to the dense table.
  auto status =
      _worker_ptr->PushDenseParam(regions.data(), regions.size(), table_id);
  status.wait();
  VLOG(4) << "RPC Send Dense Param " << table_id << " done!";
  return;
}
// Push the dense gradients described by `ctx` (concatenating all origin
// variables into one flat buffer) to the servers. The closure frees itself
// after the last server response; _async_call_num tracks in-flight pushes.
void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = ctx.origin_varnames;
  auto &table_id = ctx.table_id;
  auto dense_data = std::make_shared<std::vector<float>>();
  size_t request_call_num = _worker_ptr->GetServerNums();
  uint32_t num_per_shard =
      DenseDimPerShard(ctx.height_sections[0], request_call_num);
  dense_data->resize(num_per_shard *
                     request_call_num);  // accessor->update_dim() = 1
  float *data = dense_data->data();
  uint32_t pos = 0;
  // Flatten all gradient tensors back-to-back into dense_data.
  for (size_t i = 0; i < var_names.size(); ++i) {
    const LoDTensor tensor = scope.FindVar(var_names[i])->Get<LoDTensor>();
    size_t count = static_cast<size_t>(tensor.numel());
    const float *g = tensor.data<float>();
    CHECK(pos + count <= dense_data->size())
        << "invalid dense size, cur pos[" << pos << "]"
        << " data_num[" << count << "] size[" << dense_data->size() << "]";
    memcpy(data + pos, g, count * sizeof(float));
    pos += count;
  }
  ++_async_call_num;
  // Self-deleting closure: runs once all servers have responded and
  // decrements the in-flight counter.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status = _worker_ptr->PushDenseRawGradient(
      table_id, data, dense_data->size(), closure);
  status.wait();
  return;
}
// Push a dense embedding table stored in `varname` to the sparse table
// `table_id`, treating row index i as sparse key i (keys 0..rows-1).
void Communicator::RpcSendSparseParam(const std::string &varname,
                                      int table_id,
                                      const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparseParam",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(varname);
  auto *tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are simply 0..sparse_num-1 (row index == feature id).
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  push_g_vec.reserve(sparse_num);
  // One value pointer per row of the embedding matrix.
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  // Self-deleting closure: resolves once every server has responded.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushSparseParam(table_id,
                                             sparse_push_keys.data(),
                                             (const float **)push_g_vec.data(),
                                             sparse_push_keys.size(),
                                             closure);
  status.wait();
  return;
}
// Push sparse gradients stored as SelectedRows in `var_name` to table
// `table_id`: the row ids become the sparse keys and each value row is the
// corresponding gradient. Blocks until the push future resolves.
void Communicator::RpcSendSparse(const std::string &var_name,
                                 int table_id,
                                 const Scope &scope) {
  platform::RecordEvent record_event("Communicator->RpcSendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t request_call_num = _worker_ptr->GetServerNums();
  std::vector<uint64_t> sparse_push_keys;
  std::vector<float *> push_g_vec;
  auto *send_var = scope.FindVar(var_name);
  auto *tensor = send_var->GetMutable<phi::SelectedRows>();
  auto dim = tensor->value().dims()[1];
  // Row ids (int64) are reinterpreted as uint64 sparse keys.
  std::transform(tensor->rows().begin(),
                 tensor->rows().end(),
                 std::back_inserter(sparse_push_keys),
                 [&](int64_t id) { return static_cast<uint64_t>(id); });
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
  }
  // TODO(wangguanqun): padding_idx is not ignored, this is a bug.
  // if padding_idx == padding in datareader, the server will core.
  /*
  for (size_t i = 0; i < tensor->rows().size(); ++i) {
    uint64_t real_id = static_cast<uint64_t>(tensor->rows()[i]);
    if (real_id != 0) {
      sparse_push_keys.push_back(real_id);
      push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
    }
  }
  */
  ++_async_call_num;
  // Self-deleting closure: resolves after all servers respond and
  // decrements the in-flight push counter.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
        --_async_call_num;
      });
  auto status =
      _worker_ptr->PushSparseRawGradient(table_id,
                                         sparse_push_keys.data(),
                                         (const float **)push_g_vec.data(),
                                         sparse_push_keys.size(),
                                         closure);
  status.wait();
  return;
}
// Pull a full sparse table `table_id` into the dense tensor `varname`:
// key i fills row i of the tensor (keys 0..rows-1). Blocks until done.
void Communicator::RpcRecvSparse(const std::string &varname,
                                 int table_id,
                                 Scope *scope) {
  platform::RecordEvent record_event("Communicator->RpcRecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  auto *send_var = scope->Var(varname);
  auto *tensor = send_var->GetMutable<framework::LoDTensor>();
  auto dim = tensor->dims()[1];
  uint64_t sparse_num = static_cast<uint64_t>(tensor->dims()[0]);
  // Keys are simply 0..sparse_num-1 (row index == feature id).
  std::vector<uint64_t> sparse_push_keys(sparse_num);
  std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0);
  // Destination pointer per row; PullSparseParam writes into these.
  std::vector<float *> push_g_vec;
  for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
    push_g_vec.push_back(tensor->data<float>() + i * dim);
  }
  bool training = true;
  auto status =
      _worker_ptr->PullSparseParam(static_cast<float **>(push_g_vec.data()),
                                   table_id,
                                   sparse_push_keys.data(),
                                   sparse_push_keys.size(),
                                   training);
  status.wait();
  return;
}
// Push the initial dense parameters to the servers. Only trainer 0 does
// this; all other trainers are no-ops.
void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  if (trainer_id_ != 0) {
    return;
  }
  for (const auto &entry : recv_varname_to_ctx) {
    const auto &table_id = entry.first;
    const auto &varnames = entry.second;
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  }
}
// Pull the dense parameters of every table in `recv_varname_to_ctx`
// into recv_scope_.
void Communicator::PullDense(const RecvCtxMap &recv_varname_to_ctx) {
  for (const auto &entry : recv_varname_to_ctx) {
    const auto &table_id = entry.first;
    const auto &varnames = entry.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param to table " << table_id
            << " from 0' trainer done";
  }
}
// Mirror the local profiler state to the parameter servers: start the
// server-side profiler when profiling turns on, stop it when it turns off.
// Only trainer 0 issues these RPCs.
void Communicator::RpcProfilerControl() {
  if (trainer_id_ != 0) {
    return;
  }
  const bool profiling = platform::IsProfileEnabled();
  if (profiling && !do_server_profiler_) {
    // send profiler start flag
    do_server_profiler_ = true;
    auto start_status = _worker_ptr->StartProfiler();
    start_status.wait();
  } else if (!profiling && do_server_profiler_) {
    // send profiler end flag
    auto stop_status = _worker_ptr->StopProfiler();
    stop_status.wait();
    do_server_profiler_ = false;
  }
}
// Push this trainer's merged batch counter (STEP_COUNTER) to every pserver.
// The count is staged in a 1-element int64 CPU tensor inside `send_scope`.
// No-op when `batches` is 0.
void Communicator::SendGlobalStep(const CommContext &ctx,
                                  int batches,
                                  Scope *send_scope) {
  if (batches == 0) {
    return;
  }
  platform::RecordEvent record_event("Communicator->SendGlobalStep",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = ctx.table_id;
  size_t request_call_num = _worker_ptr->GetServerNums();
  auto &var_name = STEP_COUNTER;
  auto *out_var = send_scope->Var(var_name);
  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
  data[0] = static_cast<int64_t>(batches);
  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
  // Completion callback: resolve the promise with -1 if any of the
  // per-server responses failed, 0 otherwise. NOTE(review): the closure is
  // allocated with `new` here — presumably released by the RPC layer after
  // the callback runs; confirm against DownpourBrpcClosure's contract.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [this, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
        for (size_t i = 0; i < request_call_num; ++i) {
          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
            ret = -1;
            break;
          }
        }
        closure->set_promise_value(ret);
      });
  auto status = _worker_ptr->PushGlobalStep(table_id, data, closure);
  status.wait();
  return;
}
void AsyncCommunicator::RecvThread() {
if (!independent_recv_) return;
VLOG(3) << "Independent RecvThread Start and Wait";
while (running_) {
int grad_num = grad_num_.load();
if (grad_num > min_send_grad_num_before_recv_) {
RecvByCommunicator();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
VLOG(1) << "communicator stopped, independent recv thread exit";
}
// Pull the latest parameters from the pservers unless the communicator
// has already been stopped.
void AsyncCommunicator::RecvByCommunicator() {
  if (running_) {
    RecvNoBarrier();
    VLOG(3) << "run recv graph end";
  }
}
// Pull the latest dense parameters for every receive table, then (CUDA
// builds only) copy vars whose tensors live on GPU back onto the device
// from the CPU staging scope.
void AsyncCommunicator::RecvNoBarrier() {
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    RpcRecvDense(varnames, table_id, recv_scope_);
  }
  // NOTE(review): this assumes RpcRecvDense landed GPU-resident vars into
  // xpu_temp_scope_ on the CPU first — confirm against RpcRecvDense.
  for (auto &iter : recv_varname_to_ctx_) {
    auto var_names = iter.second;
    for (auto &t : var_names) {
      Variable *var = recv_scope_->FindVar(t);
      LoDTensor *tensor = var->GetMutable<LoDTensor>();
      VLOG(3) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
              << platform::is_gpu_place(tensor->place());
      if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
        LoDTensor *temp_tensor =
            xpu_temp_scope_->FindVar(t)->GetMutable<LoDTensor>();
        framework::TensorCopy(*temp_tensor, tensor->place(), tensor);
#endif
      }
    }
  }
  return;
}
// For every send context: drain up to max_merge_var_num_ queued gradient
// snapshots per variable, merge them, and push the merged result to the
// matching pserver table (sparse, dense, or tensor-table/global-step).
// One threadpool task per context; this call blocks until all finish.
void AsyncCommunicator::SendByCommunicator() {
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // `ctx` is captured by reference: it refers to a map node, which stays
    // valid while the task runs because we wait on all tasks below.
    auto send_recv_task = [this, &ctx] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      // The first variable's queue is used as the "pace" queue: one entry
      // per queued step, so its size tells how many steps can be merged.
      auto &check_queue = send_varname_to_queue_[varnames[0]];
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      int merged_var_num = 0;
      int wait_times = 0;
      while (merged_var_num < max_merge_var_num_) {
        if (check_queue->Size() == 0) {
          VLOG(4) << "wait_times -> " << wait_times;
          if (wait_times >= send_wait_times_) {
            break;
          }
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
          wait_times++;
          continue;
        } else {
          wait_times = 0;
          // Pop one snapshot of every variable belonging to this step.
          for (size_t i = 0; i < var_nums; i++) {
            auto &var_name = varnames[i];
            auto &var_queue = send_varname_to_queue_[var_name];
            vars[i].push_back(var_queue->Pop());
          }
          merged_var_num++;
        }
      }
      if (merged_var_num == 0) return;
      // Merge the collected snapshots into send_scope_ (sum/avg handled
      // inside MergeVars; STEP_COUNTER is an int64 counter).
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        if (var_name == STEP_COUNTER) {
          MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
        } else {
          MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
        }
      }
      if (ctx.is_tensor_table) {
        SendGlobalStep(ctx, merged_var_num, send_scope_.get());
      } else if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
        // Without an independent recv thread, pull the refreshed dense
        // parameters right after pushing the gradients.
        if (!independent_recv_ &&
            recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) {
          auto recv_varnames = recv_varname_to_ctx_.at(table_id);
          RpcRecvDense(recv_varnames, table_id, recv_scope_);
        }
      }
      // Count this push so the independent recv thread can decide when to
      // pull (see RecvThread).
      if (independent_recv_) {
        grad_num_.fetch_add(1, std::memory_order_relaxed);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Record one completed dense push so the independent recv thread knows
// when enough gradients have accumulated to trigger a parameter pull.
void AsyncCommunicator::PushDensePostProcessing() {
  if (!independent_recv_) {
    return;
  }
  grad_num_.fetch_add(1, std::memory_order_relaxed);
}
// Send-thread main loop: idle until the first Send() clears `waiting_`,
// then repeatedly merge-and-push gradients (and sync profiler state) until
// the communicator is stopped.
void AsyncCommunicator::MainThread() {
  VLOG(3) << "AsyncCommunicator MainThread start and wait";
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    SendByCommunicator();
    RpcProfilerControl();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Synchronously pull sparse embedding rows for every id in `inputs` into
// the corresponding positions of `outputs`. Ids equal to `padding_id` get
// a zero vector instead of a lookup. Output tensors are filled back to
// back: when one fills up, writing continues in the next one.
void AsyncCommunicator::PullSparseToTensorSync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    bool is_training,
    std::vector<const LoDTensor *> *inputs,
    std::vector<LoDTensor *> *outputs) {
  std::vector<uint64_t> fea_keys;
  std::vector<float *> pull_result_ptr;
  fea_keys.reserve(MAX_FEASIGN_NUM / 100);
  pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<float> init_value(fea_dim, 0);
  framework::LoDTensor *output = nullptr;
  float *output_data = nullptr;
  // Starts at SIZE_MAX on purpose: the first ++output_index wraps to 0.
  size_t output_index = -1;
  size_t output_len = 0;
  for (size_t index = 0; index < inputs->size(); ++index) {
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
      // Advance to the next output tensor when the current one is full
      // (or on the very first iteration).
      if (!output || output_len == size_t(output->numel())) {
        ++output_index;
        CHECK(output_index < outputs->size());  // NOLINT
        output = outputs->at(output_index);
        output->set_lod(tensor->lod());
        output_data = output->mutable_data<float>(place);
        output_len = 0;
        CHECK(output->numel() % fea_dim == 0);  // NOLINT
        CHECK(output_data != nullptr);          // NOLINT
      }
      uint64_t real_id = static_cast<uint64_t>(ids[i]);
      if (real_id == padding_id) {
        // Padding ids are not looked up; their slot is zero-filled.
        memcpy(output_data + output_len,
               init_value.data(),
               sizeof(float) * fea_dim);
        continue;
      }
      fea_keys.push_back(real_id);
      // PullSparse writes each key's values directly at these addresses.
      pull_result_ptr.push_back(output_data + output_len);
    }
  }
  auto status = _worker_ptr->PullSparse(pull_result_ptr.data(),
                                        table_id,
                                        fea_keys.data(),
                                        fea_keys.size(),
                                        is_training);
  status.wait();
  auto ret = status.get();
  if (ret != 0) {
    // Best-effort error handling: log and pause before the caller
    // continues (matches the surrounding code's failure convention).
    LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
    sleep(sleep_seconds_before_fail_exit_);
  }
}
// Asynchronously push sparse gradients taken from `outputs` for the ids in
// `inputs`. Each pushed value is laid out as [slot, grad...] (fea_dim + 1
// floats), matching CtrCommonPushValue in ctr_accessor.h. Padding ids are
// skipped. The push RPC is fire-and-forget (status is not waited on).
void AsyncCommunicator::PushSparseFromTensorAsync(
    const uint64_t table_id,
    int fea_dim,
    uint64_t padding_id,
    platform::Place place,
    std::vector<const framework::LoDTensor *> *inputs,
    const framework::LoDTensor *shows,
    const framework::LoDTensor *clks,
    std::vector<framework::LoDTensor *> *outputs) {
  // Derive the batch size from the inputs; rows come from LoD when
  // present, otherwise from dims()[0]. Inconsistent inputs disable the
  // batch-size gradient rescaling below.
  int batch_size = -1;
  bool batch_size_consist = true;
  for (auto *input : *inputs) {
    int cur_batch_size =
        input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
    if (batch_size == -1) {
      batch_size = cur_batch_size;
    } else if (batch_size != cur_batch_size) {
      // CHECK(batch_size == cur_batch_size);  // NOLINT
      batch_size_consist = false;
      break;
    }
  }
  CHECK(batch_size > 0);  // NOLINT
  int show_size =
      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
  CHECK(show_size == batch_size || show_size == 1);
  int clk_size =
      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
  CHECK(clk_size == batch_size || clk_size == 1);
  CHECK(outputs->size() == inputs->size());
  std::vector<uint64_t> push_keys;
  push_keys.reserve(MAX_FEASIGN_NUM / 100);
  std::vector<std::vector<float>> push_values;
  push_values.reserve(MAX_FEASIGN_NUM / 100);
  size_t output_len = 0;
  size_t input_idx = 0;
  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim << " batch_size: " << batch_size
          << " batch_size_consist: " << batch_size_consist;
  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
  // const long int* show_tensor = shows->data<int64_t>();
  // const long int* clk_tensor = clks->data<int64_t>();
  for (size_t index = 0; index < inputs->size(); ++index) {
    framework::LoDTensor *g_tensor = outputs->at(index);
    float *g = g_tensor->data<float>();
    if (batch_size_consist) {  // TODO(zhaocaibei123): add config
                               // scale_sparse_gradient_with_batch_size_
      // Undo the mean over the batch for all but the first two columns
      // (show/clk), which cvm_grad leaves unscaled.
      Eigen::Map<
          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          g_mat(g, g_tensor->numel() / fea_dim, fea_dim);
      g_mat.rightCols(fea_dim - 2) *=
          batch_size;  // hard code here, because of cvm_grad op
    }
    const framework::LoDTensor *tensor = inputs->at(index);
    const int64_t *ids = tensor->data<int64_t>();
    size_t len = tensor->numel();
    output_len = 0;
    // Two traversal modes: LoD-indexed (variable-length sequences) vs a
    // flat dense id tensor. Both build identical push entries.
    if (tensor->lod().size() > 0) {
      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
        for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
             ++j, output_len += fea_dim) {
          uint64_t real_id = static_cast<uint64_t>(ids[j]);
          if (real_id == padding_id) {
            continue;
          }
          push_keys.emplace_back(real_id);
          push_values.emplace_back(fea_dim + 1);
          // slot show clk grad... consistent with CtrCommonPushValue defined in
          // ctr_accessor.h
          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
          // push_values.back()[1] =
          //    (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
          // push_values.back()[2] =
          //    (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
          float *data = push_values.back().data() + 1;  // hard code here
          memcpy(data, g + output_len, sizeof(float) * fea_dim);
          ++input_idx;
        }
      }
    } else {
      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
        uint64_t real_id = static_cast<uint64_t>(ids[i]);
        if (real_id == padding_id) {
          continue;
        }
        push_keys.emplace_back(real_id);
        push_values.emplace_back(fea_dim + 1);
        // slot show clk grad... consistent with CtrCommonPushValue defined in
        // ctr_accessor.h
        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
        // push_values.back()[1] =
        //    (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
        // push_values.back()[2] =
        //    (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
        float *data = push_values.back().data() + 1;
        memcpy(data, g + output_len, sizeof(float) * fea_dim);
        ++input_idx;
      }
    }
    CHECK(static_cast<int64_t>(output_len) == g_tensor->numel());
  }
  // NOTE(review): input_idx should equal push_keys.size() here (both count
  // non-padding ids) — the loop below relies on that.
  std::vector<float *> push_g_vec(input_idx, nullptr);
  for (auto i = 0u; i < push_keys.size(); ++i) {
    push_g_vec[i] = push_values.at(i).data();
  }
  PADDLE_ENFORCE_EQ(
      this->Check(table_id),
      true,
      platform::errors::InvalidArgument(
          "can not find table: %s, please check your config", table_id));
  auto status = _worker_ptr->PushSparse(table_id,
                                        push_keys.data(),
                                        (const float **)push_g_vec.data(),
                                        push_keys.size());
}
// Half-async send loop: each round sends merged gradients, runs the
// send/recv table barriers, pulls parameters, then wakes every worker
// blocked in Barrier() so the next step can begin.
void HalfAsyncCommunicator::MainThread() {
  VLOG(3) << "HalfAsyncCommunicator MainThread start and wait";
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    SendByCommunicator();
    BarrierSend();
    RecvByCommunicator();
    BarrierRecv();
    BarrierWeakUp();
  }
  VLOG(1) << "communicator stopped, send thread exit";
}
// Initialize the async communicator: store the send/recv contexts, create
// the internal scopes, one blocking queue per send variable, and the send
// threadpool.
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                                 const RecvCtxMap &recv_varname_to_ctx,
                                 Scope *recv_scope) {
  // NOTE(review): the parameters are const references, so these
  // std::move calls actually copy — harmless, but the moves are no-ops.
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);
  send_scope_.reset(new Scope());
  xpu_temp_scope_.reset(new Scope());
  // One bounded queue per original send variable; Send() pushes gradient
  // snapshots into these and SendByCommunicator() drains them.
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    auto &varnames = ctx.origin_varnames;
    for (auto &var_name : varnames) {
      send_varname_to_queue_[var_name] =
          std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
              send_queue_size_);
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
// Stop the loops and join both background threads before destruction.
AsyncCommunicator::~AsyncCommunicator() {
  running_ = false;
  if (main_thread_) main_thread_->join();
  if (recv_thread_) recv_thread_->join();
}
void AsyncCommunicator::Start() {
VLOG(1) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
VLOG(1) << "start send thread and recv thread";
waiting_ = true;
running_ = true;
// flushing_ = false;
BarrierTriggerReset(max_merge_var_num_);
// start send and recv thread
main_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::MainThread, this)));
if (independent_recv_) {
recv_thread_.reset(
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
}
}
}
// Stop the communicator: clear the running flag first so both loops exit,
// then join and release the background threads.
void AsyncCommunicator::Stop() {
  VLOG(1) << "Communicator stop begin";
  running_ = false;
  if (!communicator_) {
    VLOG(0) << "Communicator is not inited, do nothing";
  } else {
    // _worker_ptr->FinalizeWorker();
    VLOG(1) << "client finalize_worker done";
    if (recv_thread_) {
      VLOG(1) << "stop recv thread";
      recv_thread_->join();
      recv_thread_.reset(nullptr);
    }
    if (main_thread_) {
      VLOG(1) << "stop main thread";
      main_thread_->join();
      main_thread_.reset(nullptr);
    }
  }
  VLOG(1) << "Communicator stop done";
}
// Check whether the single table name in `var_tables` has a send context.
// Side effect: if the name is STEP_COUNTER, a 1-element tensor holding the
// value 1 is pushed into its queue (a step tick), not just checked.
bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
  PADDLE_ENFORCE_EQ(
      var_tables.size(),
      1,
      platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
  auto table_name = var_tables[0];
  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
    return false;
  }
  if (table_name == STEP_COUNTER) {
    VLOG(3) << "send step_counter into queue";
    auto tmp_var = std::make_shared<Variable>();
    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
    tensor->Resize(phi::make_ddim({1}));
    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
    out_d[0] = 1;
    send_varname_to_queue_[table_name]->Push(tmp_var);
  }
  return true;
}
// Return true iff some registered send context targets `table_id`.
bool AsyncCommunicator::Check(const int table_id) {
  for (auto it = send_varname_to_ctx_.begin();
       it != send_varname_to_ctx_.end();
       ++it) {
    if (it->second.table_id == table_id) {
      return true;
    }
  }
  return false;
}
// Snapshot each named variable from `scope` and enqueue the copies for
// the send thread. The first call releases MainThread from its wait loop.
void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
                             const framework::Scope &scope) {
  waiting_ = false;
  for (const auto &name : var_names) {
    // Deep-copy so the trainer may overwrite the original while the
    // send thread consumes this snapshot.
    auto snapshot = std::make_shared<Variable>();
    framework::CopyVariable(*scope.FindVar(name), snapshot.get());
    send_varname_to_queue_[name]->Push(snapshot);
  }
}
// Drain every pending gradient queue, discarding the queued snapshots.
void HalfAsyncCommunicator::Clean() {
  for (auto &entry : send_varname_to_queue_) {
    auto &queue = entry.second;
    while (queue->Size() > 0) {
      queue->Pop();
    }
    VLOG(3) << "clean var: " << entry.first << " done";
  }
}
// One worker left the step: lower the number of barrier participants.
void HalfAsyncCommunicator::BarrierTriggerDecrement() {
  barrier_trigger_.fetch_sub(1);
  VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to "
          << barrier_trigger_.load();
}
// Re-arm the barrier with the number of workers expected next round.
void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) {
  barrier_trigger_ = initial_val;
  VLOG(3) << "BarrierTriggerReset reset barrier trigger to "
          << barrier_trigger_.load();
}
// Worker-side barrier: register arrival by incrementing the counter, then
// block until BarrierWeakUp() resets the counter to zero. Returns
// immediately (after counting) when the communicator is not running.
void HalfAsyncCommunicator::Barrier() {
  barrier_counter_++;
  if (!running_) {
    VLOG(3) << "Communicator is not running, release barrier";
    return;
  }
  {
    std::unique_lock<std::mutex> lk(barrier_mutex_);
    barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); });
  }
}
// Block until every expected worker has reached the barrier (counter has
// caught up with a non-zero trigger) or the communicator stops; return the
// number of batches accumulated at the barrier.
int HalfAsyncCommunicator::BatchesCounter() {
  while (running_) {
    const int arrived = barrier_counter_.load();
    const int expected = barrier_trigger_.load();
    if (expected != 0 && arrived >= expected) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
  return barrier_counter_.load();
}
// Half-async send: wait until every worker has queued `batches` steps,
// then for each send context pop exactly `batches` snapshots per variable,
// merge them, and push sparse or dense as appropriate. One threadpool task
// per context; blocks until all tasks finish.
void HalfAsyncCommunicator::SendByCommunicator() {
  int batches = BatchesCounter();
  VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches;
  if (batches <= 0) return;
  std::vector<std::future<void>> tasks;
  tasks.reserve(send_varname_to_ctx_.size());
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    // `ctx` is a reference into the map; safe because we wait below.
    auto send_recv_task = [this, &ctx, batches] {
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      size_t var_nums = varnames.size();
      std::vector<std::vector<std::shared_ptr<Variable>>> vars;
      vars.resize(var_nums);
      for (size_t i = 0; i < var_nums; i++) {
        auto &var_name = varnames[i];
        auto &var_queue = send_varname_to_queue_[var_name];
        // Unlike the async path, exactly `batches` snapshots are popped.
        for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop());
        MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
      }
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        RpcSendSparse(varnames[0], table_id, *send_scope_);
      } else {
        RpcSendDense(ctx, *send_scope_);
      }
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  return;
}
// Release all workers blocked in Barrier(): reset the arrival counter to
// zero (the wait predicate) and notify everyone.
// NOTE(review): "WeakUp" looks like a typo for "WakeUp", but the name is
// part of the public interface and cannot be changed here.
void HalfAsyncCommunicator::BarrierWeakUp() {
  barrier_counter_.store(0);
  barrier_cond_.notify_all();
}
// Server-side barrier (slot 0) so all trainers finish sending together.
void SyncCommunicator::BarrierSend() {
  if (running_) {
    BarrierWithTable(0);
    VLOG(4) << "BarrierSend with SyncCommunicator";
  }
}
// Server-side barrier (slot 1) so all trainers finish receiving together.
void SyncCommunicator::BarrierRecv() {
  if (running_) {
    BarrierWithTable(1);
    VLOG(4) << "BarrierRecv with SyncCommunicator";
  }
}
// Geo-mode Send: record which sparse ids were touched this step. The rows
// of the SelectedRows gradient are bucketed by pserver shard
// (row % splited_var_nums), de-duplicated per shard, and pushed into the
// per-shard id channels consumed by MergeSparseIds/SendSparse.
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                           const framework::Scope &scope) {
  platform::RecordEvent record_event(
      "GeoCommunicator->Send", platform::TracerEventType::Communication, 1);
  waiting_ = false;
  auto before_send = GetCurrentUS();
  auto table_name = var_names[0];
  size_t splited_var_nums =
      send_varname_to_ctx_[table_name].splited_varnames.size();
  // One id set per pserver shard, keyed by the splited variable name.
  std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
  for (size_t j = 0; j < splited_var_nums; j++) {
    ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
        send_varname_to_ctx_[table_name].splited_varnames[j],
        std::unordered_set<int64_t>()));
  }
  auto *var = scope.FindVar(table_name);
  PADDLE_ENFORCE_EQ(var->IsType<phi::SelectedRows>(),
                    true,
                    platform::errors::InvalidArgument(
                        "Only need to send Sparse Grad in Geo mode."));
  auto &rows = var->Get<phi::SelectedRows>().rows();
  // insert ids which has not been record
  for (size_t j = 0; j < rows.size(); j++) {
    auto ep_idx = rows[j] % splited_var_nums;
    ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
        .insert(rows[j]);
  }
  for (auto &iter : ids_table) {
    auto &key = iter.first;
    auto &sparse_ids_set = iter.second;
    auto sparse_ids_vec = std::make_shared<std::vector<int64_t>>();
    sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end());
    sparse_id_queues_.at(key)->Put(sparse_ids_vec);
    VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key
            << "'s queue";
  }
  auto after_send = GetCurrentUS();
  VLOG(2) << "run send op finish. use time " << (after_send - before_send);
}
// Initialize geo mode: store the contexts, count parallel tasks (one per
// dense context, one per sparse shard), create a bounded id channel per
// sparse shard, the send threadpool, and the delta/old/pserver scopes used
// for delta computation.
void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
                               const RecvCtxMap &recv_varname_to_ctx,
                               Scope *recv_scope) {
  // NOTE(review): the parameters are const references, so std::move here
  // actually copies — which is also why the ENFORCE below can still read
  // `send_varname_to_ctx.size()` after the "move".
  send_varname_to_ctx_ = std::move(send_varname_to_ctx);
  recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
  recv_scope_ = std::move(recv_scope);
  PADDLE_ENFORCE_GT(
      send_varname_to_ctx.size(),
      0,
      platform::errors::InvalidArgument("send var contexts can not be zero"));
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) {
      // Dense context: a single send/recv task in MainThread.
      parallel_task_nums_ += 1;
      continue;
    }
    auto &varnames = ctx.origin_varnames;
    PADDLE_ENFORCE_EQ(
        varnames.size(),
        1,
        platform::errors::InvalidArgument(
            "sparse variables can only be merged by one variables"));
    // Sparse context: one task and one id channel per pserver shard.
    for (auto &splited_var : ctx.splited_varnames) {
      parallel_task_nums_ += 1;
      sparse_id_queues_.insert(
          std::pair<std::string,
                    paddle::framework::Channel<
                        std::shared_ptr<std::vector<int64_t>>>>(
              splited_var,
              paddle::framework::MakeChannel<
                  std::shared_ptr<std::vector<int64_t>>>(send_queue_size_)));
    }
  }
  send_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
  delta_scope_.reset(new Scope());
  old_scope_.reset(new Scope());
  pserver_scope_.reset(new Scope());
}
// Initialize all parameters for geo mode: dense tables in parallel via the
// threadpool, then sparse tables sequentially.
void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
  std::vector<std::future<void>> tasks;
  tasks.reserve(recv_varname_to_ctx_.size());
  for (auto &iter : recv_varname_to_ctx_) {
    auto &table_id = iter.first;
    auto &varnames = iter.second;
    // Captures are references into map nodes: stable for the task's
    // lifetime because we wait on every task below.
    auto recv_task = [this, &table_id, &varnames] {
      InitDense(varnames, table_id);
    };
    tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : tasks) {
    task.wait();
  }
  for (auto &iter : send_varname_to_ctx_) {
    auto &ctx = iter.second;
    if (!ctx.is_sparse) continue;
    auto &varname = ctx.origin_varnames[0];
    auto &table_id = ctx.table_id;
    // Strip the "@GRAD" suffix (5 chars) to get the parameter name.
    auto param = varname.substr(0, varname.size() - 5);
    InitSparse(param, table_id);
  }
  return;
}
// Initialize one dense table: trainer 0 seeds the pservers then hits the
// barrier; the others wait on the barrier and pull the seeded values.
// Afterwards every trainer snapshots the parameters into old_scope_ and
// pserver_scope_ as baselines for geo delta computation.
void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
                                int table_id) {
  if (trainer_id_ == 0) {
    RpcSendDenseParam(varnames, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push dense param to table " << table_id
            << " from 0' trainer done";
  } else {
    BarrierWithTable(1);
    RpcRecvDense(varnames, table_id, recv_scope_);
    VLOG(1) << "pull dense param to table " << table_id
            << " from 0' trainer done";
  }
  // copy to old_scope
  for (auto &t : varnames) {
    auto *global_var = recv_scope_->FindVar(t);
    global_var->GetMutable<framework::LoDTensor>();
    auto *old_var = old_scope_->Var(t);
    old_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, old_var);
    // init pserver_scope_
    auto *pserver_var = pserver_scope_->Var(t);
    pserver_var->GetMutable<framework::LoDTensor>();
    framework::CopyVariable(*global_var, pserver_var);
  }
  VLOG(1) << "init dense table " << table_id << " done";
}
// Geo dense send: for each variable compute
//   delta = (latest - old) / trainers
// stage it in delta_scope_, advance the local baseline (old += delta),
// and push the deltas to the pservers.
void GeoCommunicator::SendDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->SendDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &var_names = send_ctx.origin_varnames;
  auto &table_id = send_ctx.table_id;
  for (auto &varname : var_names) {
    auto param_name = GradToParam(varname);
    auto *var_latest = recv_scope_->FindVar(param_name);
    auto *var_timestamp = old_scope_->FindVar(param_name);
    PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(),
                      true,
                      platform::errors::Unavailable(
                          "%s is not initialized, please check", param_name));
    auto &t_latest = var_latest->Get<framework::LoDTensor>();
    auto t_timestamp = var_timestamp->GetMutable<framework::LoDTensor>();
    phi::CPUContext cpu_ctx;
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest.dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    // delta = latest - old
    blas.VSUB(t_latest.numel(),
              t_latest.data<float>(),
              t_timestamp->data<float>(),
              t_delta->data<float>());
    // delta /= trainers (each trainer contributes its share)
    float coefficient = 1.0 / static_cast<float>(trainers_);
    blas.SCAL(t_latest.numel(), coefficient, t_delta->data<float>());
    // old += delta : move the local baseline forward by what we send
    blas.VADD(t_latest.numel(),
              t_timestamp->data<float>(),
              t_delta->data<float>(),
              t_timestamp->data<float>());
  }
  RpcSendDense(send_ctx, *delta_scope_);
  VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id;
  return;
}
// Geo dense recv: pull current pserver values, fold the remote progress
// into the local parameters, and refresh the baseline:
//   delta  = pserver - old
//   latest = latest + delta
//   old    = pserver
void GeoCommunicator::RecvDense(const CommContext &send_ctx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvDense",
                                     platform::TracerEventType::Communication,
                                     1);
  auto &table_id = send_ctx.table_id;
  auto &varnames = recv_varname_to_ctx_.at(table_id);
  // 1. recv from pserver
  RpcRecvDense(varnames, table_id, pserver_scope_.get());
  // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver
  phi::CPUContext cpu_ctx;
  for (auto &varname : varnames) {
    auto *var_latest = recv_scope_->FindVar(varname);
    auto t_latest = var_latest->GetMutable<framework::LoDTensor>();
    auto *var_old = old_scope_->FindVar(varname);
    auto t_old = var_old->GetMutable<framework::LoDTensor>();
    auto *var_pserver = pserver_scope_->FindVar(varname);
    auto t_pserver = var_pserver->Get<framework::LoDTensor>();
    auto *var_delta = delta_scope_->Var(varname);
    auto *t_delta = var_delta->GetMutable<framework::LoDTensor>();
    t_delta->mutable_data<float>(t_latest->dims(), cpu_ctx.GetPlace());
    auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
    blas.VSUB(t_latest->numel(),
              t_pserver.data<float>(),
              t_old->data<float>(),
              t_delta->data<float>());
    blas.VADD(t_latest->numel(),
              t_latest->data<float>(),
              t_delta->data<float>(),
              t_latest->data<float>());
    blas.VCOPY(
        t_latest->numel(), t_pserver.data<float>(), t_old->data<float>());
  }
  VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id;
  return;
}
// Initialize one sparse table: rank 0 seeds the server table and then
// releases the other trainers via the barrier; the others wait and pull.
// Finally every trainer snapshots the parameter into old_scope_ as the
// baseline for later delta computation.
void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " begin.";
  const bool is_rank0 = (trainer_id_ == 0);
  if (is_rank0) {
    RpcSendSparseParam(var_name, table_id, *recv_scope_);
    BarrierWithTable(1);
    VLOG(1) << "push sparse param to table " << table_id
            << " from 0' trainer done";
  } else {
    BarrierWithTable(1);
    RpcRecvSparse(var_name, table_id, recv_scope_);
    VLOG(1) << "pull sparse param to table " << table_id
            << " from 0' trainer done";
  }
  VLOG(1) << "Init Sparse " << var_name << " : table " << table_id << " done.";
  auto *latest = recv_scope_->FindVar(var_name);
  auto *baseline = old_scope_->Var(var_name);
  framework::CopyVariable(*latest, baseline);
}
// Drain up to max_merge_var_num_ id batches from `send_varname`'s channel,
// de-duplicating the ids. Gives up after send_wait_times_ empty polls.
// Returns the merged unique ids (order unspecified: from a hash set).
std::vector<int64_t> GeoCommunicator::MergeSparseIds(
    const std::string &send_varname) {
  platform::RecordEvent record_event("GeoCommunicator->MergeSparseIds",
                                     platform::TracerEventType::Communication,
                                     1);
  size_t merge_num = 0, wait_times = 0;
  std::unordered_set<int64_t> sparse_ids;
  while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
    VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num;
    if (sparse_id_queues_.at(send_varname)->Size() > 0) {
      wait_times = 0;
      std::shared_ptr<std::vector<int64_t>> pop_ids = nullptr;
      sparse_id_queues_.at(send_varname)->Get(pop_ids);
      for (size_t j = 0; j < pop_ids->size(); j++) {
        sparse_ids.insert(pop_ids->at(j));
      }
      merge_num += 1;
      VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed";
    } else if (sparse_id_queues_.at(send_varname)->Size() == 0) {
      VLOG(3) << "wait_times -> " << wait_times;
      if (wait_times >= static_cast<size_t>(send_wait_times_)) {
        break;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      wait_times++;
      continue;
    }
  }
  std::vector<int64_t> res;
  res.assign(sparse_ids.begin(), sparse_ids.end());
  return res;
}
// Geo sparse send for one pserver shard: for each touched id compute
//   delta_row = (latest_row - old_row) / trainers
// advance the baseline (old_row += delta_row), and push the delta rows as
// a partial sparse gradient to shard `ep_idx`. No-op on an empty id list.
void GeoCommunicator::SendSparse(const std::string &varname,
                                 std::vector<int64_t> &sparse_ids,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->SendSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  if (sparse_ids.size() == 0) {
    return;
  }
  std::string param_name = SplitedGradToParam(varname);
  VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id
          << ", ep_idx: " << ep_idx;
  auto *var_latest = recv_scope_->FindVar(param_name);
  auto *var_old = old_scope_->FindVar(param_name);
  PADDLE_ENFORCE_EQ(var_latest->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  PADDLE_ENFORCE_EQ(var_old->IsInitialized(),
                    true,
                    platform::errors::Unavailable(
                        "%s is not initialized, please check", param_name));
  auto &t_latest = var_latest->Get<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  auto dims1 = t_latest.dims()[1];
  phi::CPUContext cpu_ctx;
  // Stage the deltas as a SelectedRows var in delta_scope_, one row per id.
  auto *var_delta = delta_scope_->Var(varname);
  auto *t_delta = var_delta->GetMutable<phi::SelectedRows>();
  auto *var_t_value = t_delta->mutable_value();
  var_t_value->Resize({static_cast<int64_t>(sparse_ids.size()), dims1});
  auto *t_value = var_t_value->mutable_data<float>(cpu_ctx.GetPlace());
  t_delta->set_rows(sparse_ids);
  t_delta->set_height(t_latest.dims()[0]);
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  float coefficient = 1.0 / static_cast<float>(trainers_);
  std::vector<float *> push_g_vec;
  for (auto j = 0; j < static_cast<int>(sparse_ids.size()); ++j) {
    // delta = (latest - old) / trainers ; old += delta
    blas.VSUB(dims1,
              t_latest.data<float>() + sparse_ids[j] * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1);
    blas.SCAL(dims1, coefficient, t_value + j * dims1);
    blas.VADD(dims1,
              t_old->data<float>() + sparse_ids[j] * dims1,
              t_value + j * dims1,
              t_old->data<float>() + sparse_ids[j] * dims1);
    push_g_vec.push_back(t_value + j * dims1);
    VLOG(5) << "DEBUG GeoCommunicator::SendSparse send sparse key "
            << sparse_ids[j] << " value[0] " << push_g_vec[j][0]
            << " value[-1] " << push_g_vec[j][dims1 - 1];
  }
  // Track in-flight pushes; the callback decrements when the RPC resolves.
  ++_async_call_num;
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;  // NOLINT
    if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) {
      ret = -1;
    }
    closure->set_promise_value(ret);
    --_async_call_num;
  });
  auto status = _worker_ptr->PushSparseRawGradientPartial(
      table_id,
      (const uint64_t *)sparse_ids.data(),
      (const float **)push_g_vec.data(),
      sparse_ids.size(),
      closure,
      ep_idx);
  status.wait();
  VLOG(1) << "Finish Send Sparse " << varname
          << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id;
  return;
}
// Geo sparse recv for one pserver shard: pull the shard's changed rows
// (keys + values), then for each key fold the remote progress in:
//   delta  = pserver_row - old_row
//   latest = latest + delta
//   old    = pserver_row
void GeoCommunicator::RecvSparse(const std::string &varname,
                                 int table_id,
                                 int ep_idx) {
  platform::RecordEvent record_event("GeoCommunicator->RecvSparse",
                                     platform::TracerEventType::Communication,
                                     1);
  // 1. recv from pserver
  std::vector<uint64_t> keys;
  std::vector<float> values;
  auto status = _worker_ptr->PullGeoParam(table_id, &values, &keys, ep_idx);
  status.wait();
  std::string param = SplitedGradToParam(varname);
  VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", "
          << table_id << "; ids Size: " << keys.size()
          << "; values size: " << values.size();
  auto *var_latest = recv_scope_->FindVar(param);
  auto *var_old = old_scope_->FindVar(param);
  auto *t_latest = var_latest->GetMutable<framework::LoDTensor>();
  auto *t_old = var_old->GetMutable<framework::LoDTensor>();
  auto dims1 = t_latest->dims()[1];
  auto numel = keys.size() * dims1;
  std::vector<float> v_delta;
  v_delta.resize(numel);
  phi::CPUContext cpu_ctx;
  auto blas = phi::funcs::GetBlas<phi::CPUContext, float>(cpu_ctx);
  for (auto j = 0; j < static_cast<int>(keys.size()); ++j) {
    VLOG(5) << "DEBUG GeoCommunicator::RecvSparse recv sparse key" << keys[j]
            << "value[0] " << values[j * dims1] << " value[-1] "
            << values[j * dims1 + dims1 - 1];
    float *latest_data = t_latest->data<float>() + keys[j] * dims1;
    float *old_data = t_old->data<float>() + keys[j] * dims1;
    // pserver - old => delta
    blas.VSUB(
        dims1, values.data() + j * dims1, old_data, v_delta.data() + j * dims1);
    // latest + delta => latest
    blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data);
    // pserver => old
    blas.VCOPY(dims1, values.data() + j * dims1, old_data);
  }
  VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id;
}
// Geo main loop: each round enqueues one task per dense context and one
// task per sparse pserver shard (merge ids -> SendSparse -> RecvSparse, or
// SendDense -> RecvDense), then waits for all tasks before the next round.
void GeoCommunicator::MainThread() {
  VLOG(3) << "MainThread start and wait";
  while (waiting_ && running_) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    VLOG(3) << "wait for running";
  }
  while (running_) {
    std::vector<std::future<void>> tasks;
    tasks.reserve(parallel_task_nums_);
    for (auto &iter : send_varname_to_ctx_) {
      auto &ctx = iter.second;
      auto &varnames = ctx.origin_varnames;
      auto &table_id = ctx.table_id;
      if (ctx.is_sparse) {
        PADDLE_ENFORCE_EQ(
            varnames.size(),
            1,
            platform::errors::InvalidArgument(
                "sparse variables can only be merged by one variables"));
        int pserver_num = static_cast<int>(ctx.epmap.size());
        for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
          // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0
          // table_id/ep_idx are captured by value; &ctx refers to a map
          // node that outlives the task (we wait below).
          auto send_recv_task = [this, table_id, ep_idx, &ctx] {
            auto splited_varname = ctx.splited_varnames[ep_idx];
            auto sparse_ids = MergeSparseIds(splited_varname);
            SendSparse(splited_varname, sparse_ids, table_id, ep_idx);
            RecvSparse(splited_varname, table_id, ep_idx);
          };
          tasks.emplace_back(
              send_threadpool_->enqueue(std::move(send_recv_task)));
        }
      } else {
        auto send_recv_task = [this, &ctx] {
          SendDense(ctx);
          RecvDense(ctx);
        };
        tasks.emplace_back(
            send_threadpool_->enqueue(std::move(send_recv_task)));
      }
    }
    for (auto &task : tasks) {
      task.wait();
    }
  }
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace distributed {
class PSClient;
struct CommContext;
} // namespace distributed
} // namespace paddle
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
// A fixed-capacity, thread-safe FIFO queue.
//
// Producers block inside Push() while the queue is full; consumers block
// inside Pop() while it is empty. Hand-off uses two condition variables --
// one per direction -- together with waiter counters, so a notify is only
// issued when somebody is actually asleep on the other side.
template <typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_,
                      0,
                      platform::errors::InvalidArgument(
                          "The capacity must be greater than 0."));
  }

  // Copy `elem` into the queue; blocks while the queue is full.
  bool Push(const T &elem) {
    std::unique_lock<std::mutex> guard(mtx_);
    WaitForWrite(guard);
    items_.push_back(elem);
    Notify();
    return true;
  }

  // Move `elem` into the queue; blocks while the queue is full.
  bool Push(T &&elem) {
    std::unique_lock<std::mutex> guard(mtx_);
    WaitForWrite(guard);
    items_.emplace_back(std::move(elem));
    Notify();
    return true;
  }

  // Remove and return the oldest element; blocks while the queue is empty.
  T Pop() {
    std::unique_lock<std::mutex> guard(mtx_);
    WaitForRead(guard);
    T front_item(std::move(items_.front()));
    items_.pop_front();
    Notify();
    return front_item;
  }

  // Block until there is room to write. `lock` must already hold mtx_.
  bool WaitForWrite(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (FullUnlocked()) {
      // The queue is full, so a sleeping reader can make progress --
      // wake one before we go to sleep ourselves.
      if (n_waiting_readers_ != 0) {
        readable_cv_.notify_one();
      }
      ++n_waiting_writers_;
      writable_cv_.wait(lock);
      --n_waiting_writers_;
    }
    return true;
  }

  // Block until there is something to read. `lock` must already hold mtx_.
  bool WaitForRead(std::unique_lock<std::mutex> &lock) {  // NOLINT
    while (EmptyUnlocked()) {
      // Symmetric to WaitForWrite: give a sleeping writer a chance first.
      if (n_waiting_writers_ != 0) {
        writable_cv_.notify_one();
      }
      ++n_waiting_readers_;
      readable_cv_.wait(lock);
      --n_waiting_readers_;
    }
    return true;
  }

  // The *Unlocked predicates assume the caller already holds mtx_.
  bool EmptyUnlocked() { return items_.empty(); }

  bool FullUnlocked() { return items_.size() >= capacity_; }

  // Wake one waiter on each side whose wait condition no longer holds.
  // Caller must hold mtx_.
  void Notify() {
    if (n_waiting_readers_ != 0 && (!EmptyUnlocked())) {
      readable_cv_.notify_one();
    }
    if (n_waiting_writers_ != 0 && (!FullUnlocked())) {
      writable_cv_.notify_one();
    }
  }

  size_t Cap() const {
    std::lock_guard<std::mutex> guard(mtx_);
    return capacity_;
  }

  size_t Size() const {
    std::lock_guard<std::mutex> guard(mtx_);
    return items_.size();
  }

 private:
  int n_waiting_readers_ = 0;  // consumers asleep on readable_cv_
  int n_waiting_writers_ = 0;  // producers asleep on writable_cv_
  std::condition_variable readable_cv_;
  std::condition_variable writable_cv_;
  const size_t capacity_;
  std::deque<T> items_;
  mutable std::mutex mtx_;
};
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Merge the homogeneous variables in `vars` into a new variable named
// `var_name` inside `scope`.
//
// All inputs must hold the same type: either LoDTensor (dense, must agree
// on dims) or SelectedRows (sparse). With merge_add == true the result is
// the element-wise sum of the inputs; otherwise it is their element-wise
// average.
//
// @param var_name  name of the output variable created in `scope`.
// @param vars      non-empty inputs to merge; vars[0] decides the type.
// @param scope     destination scope; its `var_name` entry is overwritten.
// @param merge_add true -> sum, false -> average.
// @throws InvalidArgument if `vars` is empty, dense inputs disagree on
//         dims, or the variable type is neither LoDTensor nor SelectedRows.
template <typename T>
inline void MergeVars(const std::string &var_name,
                      const std::vector<std::shared_ptr<Variable>> &vars,
                      Scope *scope,
                      bool merge_add = true) {
  PADDLE_ENFORCE_NE(
      vars.empty(),
      true,
      platform::errors::InvalidArgument("vector vars are empty."));
  auto cpu_place = platform::CPUPlace();
  auto &var0 = vars[0];
  auto *out_var = scope->Var(var_name);
  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
            << "; merge add: " << merge_add;
    // init output tensor
    auto *out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<T>(dims, cpu_place);
    // check the input dims
    for (auto &var : vars) {
      auto &var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(
          var_t.dims(),
          dims,
          platform::errors::InvalidArgument("vars should have the same dims."));
    }
    // set output tensor to 0.
    phi::CPUContext cpu_ctx;
    phi::funcs::SetConstant<phi::CPUContext, T> constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<T>(0));
    // sum all vars to out
    auto result = EigenVector<T>::Flatten(*out_t);
    for (auto &var : vars) {
      auto &in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<T>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
    if (!merge_add) {
      result.device(*cpu_ctx.eigen_device()) =
          result / static_cast<T>(vars.size());
    }
  } else if (var0->IsType<phi::SelectedRows>()) {
    auto &slr0 = var0->Get<phi::SelectedRows>();
    auto *out_slr = out_var->GetMutable<phi::SelectedRows>();
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
    std::vector<const phi::SelectedRows *> inputs;
    inputs.reserve(vars.size());
    for (auto &var : vars) {
      inputs.push_back(&var->Get<phi::SelectedRows>());
    }
    phi::CPUContext dev_ctx;
    if (merge_add) {
      // NOTE: functor renamed from `merge_add` so it no longer shadows the
      // bool parameter of the same name that selects this branch.
      paddle::operators::math::scatter::MergeAdd<phi::CPUContext, T>
          merge_add_functor;
      merge_add_functor(dev_ctx, inputs, out_slr);
    } else {
      paddle::operators::math::scatter::MergeAverage<phi::CPUContext, T>
          merge_average;
      merge_average(dev_ctx, inputs, out_slr);
    }
    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
                                                   var0->Type()));
  }
}
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
// Base class of all parameter-server communicators. A Communicator owns the
// PSClient (`_worker_ptr`) used to talk to the pserver fleet and exposes the
// RPC primitives (dense/sparse push & pull, barrier, global step) that the
// Async/HalfAsync/Sync/Geo subclasses orchestrate from background threads.
// A single process-wide instance lives in the static `communicator_` and is
// created through InitInstance<T>().
class Communicator {
 public:
  Communicator();
  // Build from the python-side env map: copies every entry into `envs` and
  // eagerly parses the three keys every mode needs ("barrier_table_id",
  // "trainer_id", "trainers").
  explicit Communicator(const std::map<std::string, std::string> &envs_) {
    VLOG(3) << "Communicator Init Envs";
    for (auto &iter : envs_) {
      envs[iter.first] = iter.second;
      VLOG(3) << iter.first << ": " << iter.second;
    }
    barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
    trainer_id_ = std::stoi(envs.at("trainer_id"));
    trainers_ = std::stoi(envs.at("trainers"));
  }

  virtual void InitBrpcClient(const std::string &dist_desc,
                              const std::vector<std::string> &host_sign_list);

  virtual std::vector<uint64_t> GetClientInfo();
  virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

  // 1. recv dense param
  virtual void RpcRecvDense(const std::vector<std::string> &varnames,
                            int table_id,
                            Scope *scope);
  // 2. send dense param
  virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
                                 int table_id,
                                 const Scope &scope);
  // 3. send dense grad
  virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
  // 4. send sparse grad
  virtual void RpcSendSparse(const std::string &var_name,
                             int table_id,
                             const Scope &scope);
  // 5. send sparse param
  virtual void RpcSendSparseParam(const std::string &varname,
                                  int table_id,
                                  const Scope &scope);
  // 6. recv sparse param
  virtual void RpcRecvSparse(const std::string &varname,
                             int table_id,
                             Scope *scope);
  // 7. send global step
  virtual void SendGlobalStep(const CommContext &ctx,
                              int batches,
                              Scope *send_scope);

  virtual ~Communicator() {}
  virtual void RpcProfilerControl();

  virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

  // note: only for pull dense param first before training
  virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

  virtual void Start() = 0;
  virtual void Stop() = 0;
  virtual bool IsRunning() { return running_; }
  virtual void Clean() {}

  virtual bool Check(const int table_id) = 0;
  virtual bool Check(const std::vector<std::string> &var_tables) = 0;

  virtual void Send(const std::vector<std::string> &var_names,
                    const framework::Scope &scope) = 0;
  virtual void RecvNoBarrier() {}
  virtual void Barrier() {}

  // Blocking barrier across trainers implemented through the pserver-side
  // barrier table; throws if the RPC reports a non-zero status.
  virtual void BarrierWithTable(uint32_t barrier_type) {
    auto rets = _worker_ptr->Barrier(barrier_table_id_, barrier_type);
    rets.wait();
    int status = rets.get();
    PADDLE_ENFORCE_EQ(status,
                      0,
                      platform::errors::InvalidArgument(
                          "The ret status must be 0 when barrier with table"));
  }

  virtual void CreateC2CConnection(int pserver_timeout_ms,
                                   int pserver_connect_timeout_ms,
                                   int max_retry) {
    _worker_ptr->CreateClient2ClientConnection(
        pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
  }

  virtual void BarrierTriggerDecrement() {}
  virtual void BarrierTriggerReset(int init_counter) {}

  virtual void InitEnvs() = 0;
  virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                        const RecvCtxMap &recv_varname_to_ctx,
                        Scope *recv_scope) {}

  static Communicator *GetInstance() { return communicator_.get(); }

  static std::shared_ptr<Communicator> GetInstantcePtr() {
    return communicator_;
  }

  // Thread-safe (std::call_once) creation of the process-wide singleton;
  // returns the existing instance on subsequent calls.
  template <typename T>
  static Communicator *InitInstance(
      const RpcCtxMap &send_ctx,
      const RecvCtxMap &recv_ctx,
      const std::string &dist_desc,
      const std::vector<std::string> &host_sign_list,
      Scope *recv_scope,
      const std::map<std::string, std::string> &envs) {
    std::call_once(init_flag_,
                   &Communicator::InitWithRpcCtx<T>,
                   send_ctx,
                   recv_ctx,
                   dist_desc,
                   host_sign_list,
                   recv_scope,
                   std::ref(envs));
    return communicator_.get();
  }

  // Init is called by InitInstance.
  template <typename T>
  static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
                             const RecvCtxMap &recv_ctx,
                             const std::string &dist_desc,
                             const std::vector<std::string> &host_sign_list,
                             Scope *recv_scope,
                             const std::map<std::string, std::string> &envs) {
    if (communicator_.get() == nullptr) {
      communicator_.reset(new T(std::ref(envs)));
      communicator_->InitEnvs();
      communicator_->InitBrpcClient(dist_desc, host_sign_list);
      communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
    }
  }

  PSClient *GetPsClient() { return _worker_ptr.get(); }

  // NOTE(review): despite the "Get" name, this moves `_worker_ptr` out of
  // the communicator, leaving the member empty afterwards -- it transfers
  // ownership to the caller. Confirm callers depend on that before renaming
  // or changing it.
  std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
    return std::move(_worker_ptr);
  }

  RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }

  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker

 protected:
  bool running_ = false;
  bool waiting_ = true;
  bool flushing_ = false;
  bool do_server_profiler_ = false;
  static std::shared_ptr<Communicator> communicator_;
  static std::once_flag init_flag_;
  std::unordered_map<std::string, std::string> envs;

  // Compute how much dense storage each shard holds (total divided across
  // shards, rounded up by one).
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }

  void InitGFlag(const std::string &gflags);

  paddle::distributed::PSParameter _ps_param;
  paddle::distributed::PaddlePSEnvironment _ps_env;
  int servers_ = 0;
  int trainers_;
  int trainer_id_ = 0;
  int barrier_table_id_ = 0;
  RpcCtxMap send_varname_to_ctx_;
  RecvCtxMap recv_varname_to_ctx_;
  Scope *recv_scope_;  // should be global scope
  std::unique_ptr<Scope> xpu_temp_scope_;
  std::atomic<uint32_t> _async_call_num{0};
};
// Fully-asynchronous communicator: trainer threads hand gradients over via
// Send(), which queues them per variable; MainThread() drains and pushes
// them, while (when `independent_recv_` is on) a separate RecvThread()
// pulls fresh parameters concurrently.
class AsyncCommunicator : public Communicator {
 public:
  AsyncCommunicator() : Communicator() {}
  explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
      : Communicator(envs) {}

  ~AsyncCommunicator();

  // Parse the async-mode knobs from `envs`; every key is required (at()
  // throws on a missing entry).
  void InitEnvs() {
    independent_recv_ = static_cast<bool>(
        std::stoi(envs.at("communicator_independent_recv_thread")));
    min_send_grad_num_before_recv_ =
        std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
  }

  void Start() override;
  void Stop() override;

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  virtual void MainThread();
  virtual void RecvThread();

  virtual bool Check(const int table_id);
  virtual bool Check(const std::vector<std::string> &var_tables);

  // Enqueue the named variables from `scope` onto the per-var queues.
  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  virtual void SendByCommunicator();
  virtual void RecvByCommunicator();
  virtual void RecvNoBarrier();
  virtual int BatchesCounter() { return 1; }
  virtual void BarrierSend() {}
  virtual void BarrierRecv() {}
  virtual void BarrierWeakUp() {}

  void PushDensePostProcessing();

  void PullSparseToTensorSync(
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      bool is_training,
      std::vector<const framework::LoDTensor *> *inputs,  // NOLINT
      std::vector<framework::LoDTensor *> *outputs);      // NOLINT

  void PushSparseFromTensorAsync(
      const uint64_t table_id,
      int fea_dim,
      uint64_t padding_id,
      platform::Place place,
      std::vector<const framework::LoDTensor *> *inputs,
      const framework::LoDTensor *shows,
      const framework::LoDTensor *clicks,
      std::vector<framework::LoDTensor *> *outputs);

 protected:
  // Per-variable gradient queues filled by Send(); bounded by
  // send_queue_size_.
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
  int min_send_grad_num_before_recv_;
  int thread_pool_size_;
  int max_merge_var_num_;
  int send_wait_times_;
  int send_queue_size_;
  bool need_global_step_ = false;
  bool independent_recv_ = true;
  int parallel_task_nums_ = 0;
  int32_t sleep_seconds_before_fail_exit_;

  std::unique_ptr<std::thread> main_thread_{nullptr};
  std::unique_ptr<std::thread> recv_thread_{nullptr};

  std::unique_ptr<Scope> send_scope_;  // an independent scope
  std::atomic_uint grad_num_{0};  // the num of gradient sent since last recv
};
// Half-async communicator: same send pipeline as AsyncCommunicator, but the
// independent recv thread is disabled (recv always happens after send) and
// trainers rendezvous with the communicator through the barrier trigger /
// counter pair below.
class HalfAsyncCommunicator : public AsyncCommunicator {
 public:
  HalfAsyncCommunicator() {}
  explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitEnvs() {
    // enforce to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
    VLOG(1) << "HalfAsyncCommunicator Initialized";
  }

  void MainThread() override;
  void SendByCommunicator() override;

  void Clean() override;
  void Barrier() override;
  void BarrierTriggerDecrement() override;
  void BarrierTriggerReset(int initial_val) override;

  int BatchesCounter();
  void BarrierWeakUp();

 protected:
  // mutex for Wait for barrier
  std::mutex barrier_mutex_;
  std::condition_variable barrier_cond_;
  std::atomic<int64_t> barrier_trigger_{0};
  std::atomic<int64_t> barrier_counter_{0};
};
// Fully-synchronous communicator: HalfAsyncCommunicator plus explicit
// BarrierSend()/BarrierRecv() steps against the pserver endpoints.
class SyncCommunicator : public HalfAsyncCommunicator {
 public:
  SyncCommunicator() : HalfAsyncCommunicator() {}
  explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
      : HalfAsyncCommunicator(envs) {}

  void InitEnvs() {
    // enforce to recv after send
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
    need_global_step_ =
        static_cast<bool>(std::stoi(envs.at("need_global_step")));
    VLOG(1) << "SyncCommunicator Initialized";
  }

  void BarrierSend();
  void BarrierRecv();

 private:
  std::vector<std::string> pserver_endpoints_{};
};
// GEO-SGD communicator: rather than streaming raw gradients, each round
// sends dense deltas (SendDense) and the sparse ids touched since the last
// round (MergeSparseIds/SendSparse), then pulls the fresh pserver params
// back (RecvDense/RecvSparse) -- see MainThread(). The three scopes below
// keep the snapshots needed to compute those deltas.
class GeoCommunicator : public AsyncCommunicator {
 public:
  GeoCommunicator() : AsyncCommunicator() {}
  explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
      : AsyncCommunicator(envs) {}

  void InitImpl(const RpcCtxMap &send_varname_to_ctx,
                const RecvCtxMap &recv_varname_to_ctx,
                Scope *recv_scope) override;

  void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
  void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
  void InitSparse(const std::string &var_name, int table_id);

  void SendDense(const CommContext &send_ctx);
  void RecvDense(const CommContext &send_ctx);

  std::vector<int64_t> MergeSparseIds(const std::string &varname);
  void SendSparse(const std::string &varname,
                  std::vector<int64_t> &sparse_ids,  // NOLINT
                  int table_id,
                  int ep_idx);
  void RecvSparse(const std::string &varname, int table_id, int ep_idx);

  void MainThread() override;

  void InitEnvs() {
    independent_recv_ = false;
    min_send_grad_num_before_recv_ = 0;
    send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
    thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
    // id_queue's size
    max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
    send_queue_size_ = max_merge_var_num_;
    VLOG(1) << "GeoCommunicator Initialized";
  }

  void Send(const std::vector<std::string> &var_names,
            const framework::Scope &scope) override;

  void SendByCommunicator() { return; }
  void RecvByCommunicator() override { return; }

  // Strip the gradient suffix from a var name: "emb@GRAD" -> "emb".
  // NOTE(review): assumes `var_name` ends with the 5-char "@GRAD" suffix;
  // confirm callers never pass a shorter/unsuffixed name.
  // (Now takes const& instead of a by-value const std::string to avoid a
  // needless copy per call; callers are unaffected.)
  inline std::string GradToParam(const std::string &var_name) {
    std::string param_name = var_name.substr(0, var_name.size() - 5);
    return param_name;
  }

  // Strip a ".block<i>" split suffix: "emb.block0" -> "emb". If ".block"
  // is absent, find() returns npos and the whole name is returned.
  inline std::string SplitedGradToParam(const std::string &delta_name) {
    // delta_name: emb.delta0
    auto pos = delta_name.find(".block");
    std::string param_name = delta_name.substr(0, pos);
    return param_name;
  }

 private:
  // parameter for delta calc and send
  std::shared_ptr<Scope> delta_scope_;
  // parameter for storage the pserver param after last recv
  std::shared_ptr<Scope> old_scope_;
  // parameter on pserver
  std::shared_ptr<Scope> pserver_scope_;

  // per split-var channels holding the sparse id batches staged by Send()
  std::unordered_map<
      std::string,
      paddle::framework::Channel<std::shared_ptr<std::vector<int64_t>>>>
      sparse_id_queues_;
};
} // namespace distributed
} // namespace paddle
// NOTE: the lines below are web-page residue from the code-hosting UI that
// was captured when this file was exported; they are not part of the source
// and are fenced off as a comment so the header stays parseable.
// Markdown is supported
// 0% or .
// You are about to add 0 people to the discussion. Proceed with caution.
// Finish editing this message first!
// Please register or to comment