Commit d2d32668 authored by yuguo960516yuguo's avatar yuguo960516yuguo
Browse files

2.3.0-dtk-22.04.2

parent ad08b8ce
Pipeline #226 failed with stages
in 0 seconds
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
namespace paddle {
namespace distributed {
// Describes how one trainer-side variable is split across parameter
// servers and how it is communicated (sparse/dense, merge mode, owning
// table/program, etc.).
struct CommContext {
  CommContext() = default;

  // All bool/int parameters default to the same values as the member
  // initializers below, so a partially-specified context is consistent.
  CommContext(const std::string &name,
              const std::vector<std::string> &names,
              const std::vector<std::string> &emap,
              const std::vector<int64_t> &sections,
              const std::vector<std::string> &origin_names,
              int id,
              bool merge_add_ = true,
              bool is_sparse_ = true,
              bool is_distributed_ = false,
              int table_id_ = -1,
              bool is_tensor_table_ = false,
              bool is_datanorm_table_ = false,
              int64_t program_id_ = -1)
      : var_name(name),
        splited_varnames(names),
        epmap(emap),
        height_sections(sections),
        origin_varnames(origin_names),
        trainer_id(id),
        merge_add(merge_add_),
        is_sparse(is_sparse_),
        is_distributed(is_distributed_),
        table_id(table_id_),
        program_id(program_id_),
        is_tensor_table(is_tensor_table_),
        is_datanorm_table(is_datanorm_table_) {}

  // Every member is a copyable value, so the compiler-generated copy is
  // identical to the former member-by-member hand-written one (Rule of
  // Zero); defaulting it also re-enables implicit move operations.
  CommContext(const CommContext &ctx) = default;

  // Human-readable dump for logging/debugging.
  // NOTE(review): assumes epmap and height_sections are at least as long
  // as splited_varnames -- confirm against callers before relying on it.
  std::string print() const {
    std::stringstream ss;
    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
    ss << " table_id: " << table_id;
    for (size_t i = 0; i < splited_varnames.size(); i++) {
      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
         << " section: " << height_sections[i] << " ";
    }
    ss << "origin varnames: ";
    for (size_t i = 0; i < origin_varnames.size(); i++) {
      ss << origin_varnames[i] << " ";
    }
    ss << " aggregation->add: " << merge_add;
    ss << " is_sparse: " << is_sparse;
    ss << " is_distributed: " << is_distributed << "\n";
    ss << " table_id: " << table_id << "\n";
    ss << " program_id: " << program_id << "\n";
    ss << " is_tensor_table: " << is_tensor_table << "\n";
    ss << " is_datanorm_table: " << is_datanorm_table << "\n";
    return ss.str();
  }

  std::string var_name;                       // trainer-side variable name
  std::vector<std::string> splited_varnames;  // per-shard slice names
  std::vector<std::string> epmap;             // endpoint per slice
  std::vector<int64_t> height_sections;       // rows per slice
  std::vector<std::string> origin_varnames;   // pre-split source names
  // Default member initializers: previously these scalars were left
  // indeterminate by the default constructor; now they match the
  // parameterized constructor's defaults.
  int trainer_id = 0;
  bool merge_add = true;
  bool is_sparse = true;
  bool is_distributed = false;
  int table_id = -1;
  int64_t program_id = -1;
  bool is_tensor_table = false;
  bool is_datanorm_table = false;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/env.h"
namespace paddle {
namespace distributed {} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <arpa/inet.h>
#include <glog/logging.h>
#include <netinet/in.h>
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "gflags/gflags.h"
namespace paddle {
namespace distributed {
// Identity of one parameter-server process: IPv4 address, port and its
// rank within the server (or client) list. Can round-trip through either
// a packed uint64 or an "ip:port:rank" string.
struct PSHost {
  std::string ip;
  // Zero-initialized so a default-constructed PSHost is well-defined
  // (previously these were indeterminate).
  uint32_t port = 0;
  uint32_t rank = 0;
  PSHost() = default;
  PSHost(const std::string ip, uint32_t port, uint32_t rank)
      : ip(ip), port(port), rank(rank) {}

  // Packs the host into a single uint64:
  // |---ip---|---port---|--rank--|
  // |-32bit--|--20bit---|--12bit-|
  // NOTE: a port >= 2^20 or rank >= 2^12 would silently corrupt the
  // neighboring field; callers are expected to stay in range.
  uint64_t SerializeToUint64() {
    uint64_t host_label = 0;
    host_label = inet_addr(ip.c_str());  // IPv4 only, network byte order
    host_label = host_label << 32;
    host_label += (port << 12);
    host_label += rank;
    return host_label;
  }

  // Inverse of SerializeToUint64().
  void ParseFromUint64(uint64_t host_label) {
    static uint64_t rank_label_mask = (1L << 12) - 1;
    static uint64_t port_label_mask = (1L << 20) - 1;
    rank = host_label & rank_label_mask;
    port = (host_label >> 12) & port_label_mask;
    uint32_t ip_addr = (host_label >> 32);
    ip = inet_ntoa(*(in_addr *)&ip_addr);  // NOLINT
  }

  // Debug-friendly rendering; includes the packed uint64 form.
  std::string ToString() {
    std::stringstream s;
    s << "host: " << ip;
    s << " port: " << port;
    s << " rank: " << rank;
    s << " uint: " << SerializeToUint64();
    return s.str();
  }

  // for open source parameter server
  // Renders "ip:port:rank"; inverse of ParseFromString().
  std::string SerializeToString() {
    std::stringstream s;
    s << ip << ":";
    s << port << ":";
    s << rank;
    return s.str();
  }

  // Parses an "ip:port:rank" endpoint string.
  // Throws std::out_of_range on a malformed endpoint with fewer than 3
  // fields (previously this was undefined behavior), and the usual
  // std::stoi exceptions on non-numeric port/rank.
  void ParseFromString(std::string endpoint) {
    std::vector<std::string> endpoint_info;
    StringSplit(endpoint, ':', &endpoint_info);
    ip = endpoint_info.at(0);
    port = std::stoi(endpoint_info.at(1));
    rank = std::stoi(endpoint_info.at(2));
  }

  // Splits `str` on `sep` into `pieces`. Empty trailing pieces are always
  // dropped; an entirely empty input yields one empty piece only when
  // ignore_null is false.
  void StringSplit(const std::string &str,
                   char sep,
                   std::vector<std::string> *pieces,
                   bool ignore_null = true) {
    pieces->clear();
    if (str.empty()) {
      if (!ignore_null) {
        pieces->push_back(str);
      }
      return;
    }
    size_t pos = 0;
    size_t next = str.find(sep, pos);
    while (next != std::string::npos) {
      pieces->push_back(str.substr(pos, next - pos));
      pos = next + 1;
      next = str.find(sep, pos);
    }
    if (!str.substr(pos).empty()) {
      pieces->push_back(str.substr(pos));
    }
  }
};
class PSEnvironment {
public:
explicit PSEnvironment() {} // NOLINT
virtual ~PSEnvironment() {}
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t SetPsServers(
const std::vector<std::string> *host_endpoint_list, int node_num) {
return 0;
}
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
return 0;
}
virtual int32_t SetPsClients(std::string *host_endpoint_list, int node_num) {
return 0;
}
virtual uint64_t GetLocalHostSign() { return 0; }
virtual std::vector<PSHost> GetPsServers() const { return _ps_server_list; }
virtual int32_t RegistePsServer(const std::string &ip,
uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_server_list, _ps_server_sign_set);
}
virtual std::vector<PSHost> GetPsClients() const { return _ps_client_list; }
virtual int32_t RegistePsClient(const std::string &ip,
uint32_t port,
int32_t rank) {
return RegistePsHost(ip, port, rank, _ps_client_list, _ps_client_sign_set);
}
virtual std::vector<uint64_t> GetClientInfo() {
std::vector<uint64_t> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.SerializeToUint64());
}
return client_info;
}
virtual std::vector<std::string> GetClientInfo(bool use_string_endpoint) {
if (use_string_endpoint) {
std::vector<std::string> client_info;
for (auto &i : _ps_client_list) {
client_info.push_back(i.SerializeToString());
}
return client_info;
}
return {};
}
virtual void SetTrainers(int trainers) { trainers_ = trainers; }
virtual int GetTrainers() { return trainers_; }
protected:
//注册一个host // NOLINT
virtual int32_t RegistePsHost(
const std::string &ip,
uint32_t port,
int32_t rank,
std::vector<PSHost> &host_list, // NOLINT
std::unordered_set<uint64_t> &sign_set) { // NOLINT
PSHost host;
host.ip = ip;
host.port = port;
host.rank = rank;
if (sign_set.count(rank) == 0) {
host_list.push_back(host);
sign_set.insert(rank);
}
return 0;
}
int trainers_ = 0;
std::vector<PSHost> _ps_client_list;
std::unordered_set<uint64_t> _ps_client_sign_set; // for unique filter
std::vector<PSHost> _ps_server_list;
std::unordered_set<uint64_t> _ps_server_sign_set; // for unique filter
};
class PaddlePSEnvironment : public PSEnvironment {
public:
explicit PaddlePSEnvironment() {} // NOLINT
virtual ~PaddlePSEnvironment() {}
virtual int32_t SetPsServers(uint64_t *host_sign_list, int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.ParseFromUint64(host_sign_list[i]);
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
_ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t SetPsServers(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_server_list.clear();
_ps_server_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.ParseFromString(host_sign_list->at(i));
_ps_server_list.push_back(host);
_ps_server_sign_set.insert(host.rank);
}
}
std::sort(
_ps_server_list.begin(),
_ps_server_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t SetPsClients(uint64_t *host_sign_list, int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list[i] > 0) {
PSHost host;
host.ParseFromUint64(host_sign_list[i]);
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.SerializeToUint64());
}
}
std::sort(
_ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
return 0;
}
virtual int32_t SetPsClients(const std::vector<std::string> *host_sign_list,
int node_num) {
_ps_client_list.clear();
_ps_client_sign_set.clear();
for (int i = 0; i < node_num; ++i) {
if (host_sign_list->at(i) != "") {
PSHost host;
host.ParseFromString(host_sign_list->at(i));
_ps_client_list.push_back(host);
_ps_client_sign_set.insert(host.rank);
}
}
std::sort(
_ps_client_list.begin(),
_ps_client_list.end(),
[](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
VLOG(1) << "env.set_ps_clients done\n";
return 0;
}
virtual uint64_t GetLocalHostSign() {
if (_ps_client_list.size() > 0) {
return _ps_client_list[0].SerializeToUint64();
} else {
return 0;
}
}
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Dispatches one PS RPC. When the target channel is this process's local
// channel and an in-process graph service exists, the request is handled
// locally on the stub's task pool (skipping the network stack); otherwise
// it is forwarded through the normal brpc PsService stub.
void GraphPsService_Stub::service(
    ::google::protobuf::RpcController *controller,
    const ::paddle::distributed::PsRequestMessage *request,
    ::paddle::distributed::PsResponseMessage *response,
    ::google::protobuf::Closure *done) {
  if (graph_service != NULL && local_channel == channel()) {
    // VLOG(0)<<"use local";
    // The future returned by enqueue is intentionally discarded; completion
    // is signalled through the `done` closure, as with a remote call.
    task_pool->enqueue([this, controller, request, response, done]() -> int {
      this->graph_service->service(controller, request, response, done);
      return 0;
    });
  } else {
    // VLOG(0)<<"use server";
    PsService_Stub::service(controller, request, response, done);
  }
}
// Maps a node id to the server owning its shard. Shards are spread over
// servers in contiguous runs of ceil(shard_num / server_size).
// NOTE: a negative id yields a negative shard (C++ truncating %), which
// callers are presumed to avoid.
int GraphBrpcClient::get_server_index_by_id(int64_t id) {
  const int shard_num = get_shard_num();
  int shards_per_server = shard_num % server_size == 0
                              ? shard_num / server_size
                              : shard_num / server_size + 1;
  const int shard = id % shard_num;
  return shard / shards_per_server;
}
// Fetches, for every id in `node_ids`, the string value of each feature in
// `feature_names`, writing results into res[feat_idx][input_position].
// NOTE(review): `res` is only assigned into, never resized here — callers
// presumably pre-size it to [feature_names.size()][node_ids.size()]; confirm.
// The returned future resolves to 0 on success, -1 if all sub-requests fail.
std::future<int32_t> GraphBrpcClient::get_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    std::vector<std::vector<std::string>> &res) {
  // Determine the distinct set of servers that own the queried ids.
  // request2server: request slot -> server index; server2request: inverse.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Bucket ids per request, remembering each id's original input position.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback: runs once all sub-requests have replied.
  // NOTE(review): captures `res` and `feature_names` by reference — safe
  // only while the caller keeps them alive until the future resolves.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
              0) {
            ++fail_num;
          } else {
            // Copy the response attachment into a flat buffer, then parse
            // it feature-major: for each feature, for each queried node,
            // a size_t length followed by that many bytes of value.
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
            for (size_t feat_idx = 0; feat_idx < feature_names.size();
                 ++feat_idx) {
              for (size_t node_idx = 0;
                   node_idx < query_idx_buckets.at(request_idx).size();
                   ++node_idx) {
                int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
                size_t feat_len = *(size_t *)(buffer);
                buffer += sizeof(size_t);
                auto feature = std::string(buffer, feat_len);
                res[feat_idx][query_idx] = feature;
                buffer += feat_len;
              }
            }
          }
          // Only -1 when every sub-request failed (check is inside the loop
          // but can only be true on the final iteration).
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one RPC per involved server. Request params, in order:
  // idx_, the id bucket, then the '\t'-joined feature names.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Asks every server to clear its graph data for (type_id, idx_) in table
// `table_id`. The returned future resolves to 0 only if every server
// succeeds, -1 otherwise.
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
                                                  int type_id,
                                                  int idx_) {
  // One sub-request per server; the callback fires after all have replied.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      server_size, [&, server_size = this->server_size](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_CLEAR) != 0) {
            ++fail_num;
            break;  // one failure already fails the whole call
          }
        }
        ret = fail_num == 0 ? 0 : -1;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Broadcast the clear command; params are type_id then idx_.
  for (size_t i = 0; i < server_size; i++) {
    int server_index = i;
    closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
    closure->request(server_index)->set_table_id(table_id);
    closure->request(server_index)->set_client_id(_client_id);
    closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
    closure->request(server_index)->add_params((char *)&idx_, sizeof(int));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(server_index),
                     closure->request(server_index),
                     closure->response(server_index),
                     closure);
  }
  return fut;
}
// Adds `node_id_list` (optionally flagged weighted via `is_weighted_list`)
// to graph table `table_id` at index `idx_`. Ids are bucketed per owning
// server; the future resolves to -1 only if every sub-request fails.
std::future<int32_t> GraphBrpcClient::add_graph_node(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> &node_id_list,
    std::vector<bool> &is_weighted_list) {
  std::vector<std::vector<int64_t>> request_bucket;
  std::vector<std::vector<bool>> is_weighted_bucket;
  bool add_weight = is_weighted_list.size() > 0;
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  // Bucket ids (and their weight flags) by the server that shards them;
  // ids past the end of is_weighted_list default to unweighted.
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
      if (add_weight) is_weighted_bucket.push_back(std::vector<bool>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
    if (add_weight)
      is_weighted_bucket[index_mapping[server_index]].push_back(
          query_idx < is_weighted_list.size() ? is_weighted_list[query_idx]
                                              : false);
  }
  size_t request_call_num = request_bucket.size();
  // Completion callback: counts failures; succeeds unless all requests fail.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_ADD_GRAPH_NODE) !=
              0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one RPC per involved server; params: idx_, id bucket, and
  // (optionally) the raw bool array of weight flags.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_ADD_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    if (add_weight) {
      // Unpack the bit-packed std::vector<bool> into a contiguous bool
      // buffer before shipping the raw bytes. A heap buffer replaces the
      // previous variable-length array, which is not standard C++.
      const auto &weight_bucket = is_weighted_bucket[request_idx];
      std::unique_ptr<bool[]> weighted(new bool[weight_bucket.size() + 1]);
      for (size_t j = 0; j < weight_bucket.size(); j++)
        weighted[j] = weight_bucket[j];
      closure->request(request_idx)
          ->add_params((char *)weighted.get(),
                       sizeof(bool) * weight_bucket.size());
    }
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Removes the given node ids from graph table `table_id` at index `idx_`.
// Ids are bucketed per owning server; the future resolves to -1 only if
// every sub-request fails.
std::future<int32_t> GraphBrpcClient::remove_graph_node(
    uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
  // Bucket ids by the server that shards them; index_mapping maps server
  // index -> request slot, server_index_arr the reverse.
  std::vector<std::vector<int64_t>> request_bucket;
  std::vector<int> server_index_arr;
  std::vector<int> index_mapping(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_id_list.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_id_list[query_idx]);
    if (index_mapping[server_index] == -1) {
      index_mapping[server_index] = request_bucket.size();
      server_index_arr.push_back(server_index);
      request_bucket.push_back(std::vector<int64_t>());
    }
    request_bucket[index_mapping[server_index]].push_back(
        node_id_list[query_idx]);
  }
  size_t request_call_num = request_bucket.size();
  // Completion callback: counts failures; succeeds unless all fail.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num, [&, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx,
                                      PS_GRAPH_REMOVE_GRAPH_NODE) != 0) {
            ++fail_num;
          }
        }
        ret = fail_num == request_call_num ? -1 : 0;
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one RPC per involved server; params are idx_ then the id bucket.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = server_index_arr[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_REMOVE_GRAPH_NODE);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = request_bucket[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)request_bucket[request_idx].data(),
                     sizeof(int64_t) * node_num);
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// char* &buffer,int &actual_size
// Samples up to `sample_size` neighbors for each id in `node_ids` from
// graph table `table_id` at index `idx_`. Neighbor ids go into `res` (one
// vector per input id, same order) and, when `need_weight` is set, edge
// weights go into `res_weight`. With server_index == -1 ids are routed to
// their owning servers; otherwise a single server answers for all ids.
// Response wire layout (both paths): size_t node_num, node_num ints of
// per-node byte counts, then per node a packed run of
// [int64 id (, float weight)] records.
// NOTE(review): the completion lambdas capture `res`, `res_weight` and the
// by-value parameter `need_weight` via [&]; `need_weight` lives on this
// function's stack, so the callback must run before the caller's wait on
// the returned future completes — confirm the closure cannot outlive it.
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
    uint32_t table_id,
    int idx_,
    std::vector<int64_t> node_ids,
    int sample_size,
    // std::vector<std::vector<std::pair<int64_t, float>>> &res,
    std::vector<std::vector<int64_t>> &res,
    std::vector<std::vector<float>> &res_weight,
    bool need_weight,
    int server_index) {
  // --- Single-server path: send all ids to one designated server. ---
  if (server_index != -1) {
    res.resize(node_ids.size());
    if (need_weight) {
      res_weight.resize(node_ids.size());
    }
    DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
      int ret = 0;
      auto *closure = (DownpourBrpcClosure *)done;
      if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER) !=
          0) {
        ret = -1;
      } else {
        // Copy the attachment out, then decode per-node records (see the
        // wire-layout comment above).
        auto &res_io_buffer = closure->cntl(0)->response_attachment();
        butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
        size_t bytes_size = io_buffer_itr.bytes_left();
        std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
        char *buffer = buffer_wrapper.get();
        io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
        size_t node_num = *(size_t *)buffer;
        int *actual_sizes = (int *)(buffer + sizeof(size_t));
        char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
        int offset = 0;
        for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
          int actual_size = actual_sizes[node_idx];
          int start = 0;
          while (start < actual_size) {
            res[node_idx].emplace_back(
                *(int64_t *)(node_buffer + offset + start));
            start += GraphNode::id_size;
            if (need_weight) {
              res_weight[node_idx].emplace_back(
                  *(float *)(node_buffer + offset + start));
              start += GraphNode::weight_size;
            }
          }
          offset += actual_size;
        }
      }
      closure->set_promise_value(ret);
    });
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure->add_promise(promise);
    std::future<int> fut = promise->get_future();
    ;
    // Params: idx_, all node ids, sample_size, need_weight.
    closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
    closure->request(0)->set_table_id(table_id);
    closure->request(0)->set_client_id(_client_id);
    closure->request(0)->add_params((char *)&idx_, sizeof(int));
    closure->request(0)->add_params((char *)node_ids.data(),
                                    sizeof(int64_t) * node_ids.size());
    closure->request(0)->add_params((char *)&sample_size, sizeof(int));
    closure->request(0)->add_params((char *)&need_weight, sizeof(bool));
    ;
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(
        closure->cntl(0), closure->request(0), closure->response(0), closure);
    return fut;
  }
  // --- Sharded path: route each id to the server owning its shard. ---
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  res.clear();
  res_weight.clear();
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
    // res.push_back(std::vector<std::pair<int64_t, float>>());
    res.push_back({});
    if (need_weight) {
      res_weight.push_back({});
    }
  }
  size_t request_call_num = request2server.size();
  // Bucket ids per request, remembering each id's original input position.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
  }
  // Completion callback: decodes each server's reply into the original
  // input positions; -1 only if every sub-request fails.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
              0) {
            ++fail_num;
          } else {
            auto &res_io_buffer =
                closure->cntl(request_idx)->response_attachment();
            butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
            size_t bytes_size = io_buffer_itr.bytes_left();
            std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
            char *buffer = buffer_wrapper.get();
            io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
            size_t node_num = *(size_t *)buffer;
            int *actual_sizes = (int *)(buffer + sizeof(size_t));
            char *node_buffer =
                buffer + sizeof(size_t) + sizeof(int) * node_num;
            int offset = 0;
            for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
              int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
              int actual_size = actual_sizes[node_idx];
              int start = 0;
              while (start < actual_size) {
                res[query_idx].emplace_back(
                    *(int64_t *)(node_buffer + offset + start));
                start += GraphNode::id_size;
                if (need_weight) {
                  res_weight[query_idx].emplace_back(
                      *(float *)(node_buffer + offset + start));
                  start += GraphNode::weight_size;
                }
              }
              offset += actual_size;
            }
          }
          // Only -1 when every sub-request failed (only true on the last
          // iteration, if at all).
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Issue one RPC per involved server; params: idx_, id bucket,
  // sample_size, need_weight.
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    // PsService_Stub rpc_stub(GetCmdChannel(server_index));
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Asks one server for `sample_size` random node ids of (type_id, idx_) in
// table `table_id`, appending the ids to `ids`. The future resolves to 0
// on success, -1 on RPC failure.
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int sample_size,
    std::vector<int64_t> &ids) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // RAII buffer: no leak if push_back throws mid-parse.
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      // BUG FIX: the response bytes were never copied out of the IOBuf, so
      // the loop below used to parse an uninitialized buffer. Mirror the
      // copy done by pull_graph_list / get_node_feat.
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      size_t index = 0;
      while (index < bytes_size) {
        ids.push_back(*(int64_t *)(buffer + index));
        index += GraphNode::id_size;
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Params: type_id, idx_, sample_size.
  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&sample_size, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Pulls a page of graph nodes of (type_id, idx_) from one server: up to
// `size` nodes starting at `start`, taking every `step`-th node. Decoded
// FeatureNodes are appended to `res`. Future resolves to 0 on success,
// -1 on RPC failure.
std::future<int32_t> GraphBrpcClient::pull_graph_list(
    uint32_t table_id,
    int type_id,
    int idx_,
    int server_index,
    int start,
    int size,
    int step,
    std::vector<FeatureNode> &res) {
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
    int ret = 0;
    auto *closure = (DownpourBrpcClosure *)done;
    if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      size_t bytes_size = io_buffer_itr.bytes_left();
      // RAII buffer instead of raw new/delete: no leak if decoding throws;
      // consistent with the buffer_wrapper pattern used elsewhere in this
      // file (e.g. get_node_feat).
      std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
      char *buffer = buffer_wrapper.get();
      io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
      // Decode consecutive serialized nodes until the buffer is consumed.
      size_t index = 0;
      while (index < bytes_size) {
        FeatureNode node;
        node.recover_from_buffer(buffer + index);
        index += node.get_size(false);
        res.push_back(node);
      }
    }
    closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Params, in order: type_id, idx_, start, size, step.
  closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  closure->request(0)->add_params((char *)&type_id, sizeof(int));
  closure->request(0)->add_params((char *)&idx_, sizeof(int));
  closure->request(0)->add_params((char *)&start, sizeof(int));
  closure->request(0)->add_params((char *)&size, sizeof(int));
  closure->request(0)->add_params((char *)&step, sizeof(int));
  // PsService_Stub rpc_stub(GetCmdChannel(server_index));
  GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(
      closure->cntl(0), closure->request(0), closure->response(0), closure);
  return fut;
}
// Writes string features for a batch of nodes spread across servers.
// `features[feat_idx][query_idx]` holds the value of
// `feature_names[feat_idx]` for `node_ids[query_idx]`. One
// PS_GRAPH_SET_NODE_FEAT request is issued per server that owns at least one
// of the ids; the future resolves to -1 only if every sub-request fails.
std::future<int32_t> GraphBrpcClient::set_node_feat(
    const uint32_t &table_id,
    int idx_,
    const std::vector<int64_t> &node_ids,
    const std::vector<std::string> &feature_names,
    const std::vector<std::vector<std::string>> &features) {
  // First pass: map each involved server to a dense request slot.
  // server2request[s] is -1 when server s owns none of the requested ids.
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  size_t request_call_num = request2server.size();
  // Second pass: bucket node ids, original query positions, and the
  // per-feature values by target request.
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
      request_call_num);
  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
    int server_index = get_server_index_by_id(node_ids[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_ids[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    if (features_idx_buckets[request_idx].size() == 0) {
      features_idx_buckets[request_idx].resize(feature_names.size());
    }
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      features_idx_buckets[request_idx][feat_idx].push_back(
          features[feat_idx][query_idx]);
    }
  }
  // Completion closure: counts failed sub-requests; only an all-fail batch
  // is reported as an error.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
      request_call_num,
      [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
        int ret = 0;
        auto *closure = (DownpourBrpcClosure *)done;
        size_t fail_num = 0;
        for (size_t request_idx = 0; request_idx < request_call_num;
             ++request_idx) {
          if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
              0) {
            ++fail_num;
          }
          if (fail_num == request_call_num) {
            ret = -1;
          }
        }
        closure->set_promise_value(ret);
      });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    // Request layout mirrors GraphBrpcService::graph_set_node_feat:
    // params(0)=idx, params(1)=packed int64 ids,
    // params(2)=feature names joined by '\t',
    // params(3)=length-prefixed values, feature-major then node-major.
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
    closure->request(request_idx)->set_table_id(table_id);
    closure->request(request_idx)->set_client_id(_client_id);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    std::string joint_feature_name =
        paddle::string::join_strings(feature_names, '\t');
    closure->request(request_idx)
        ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
    // set features: each value is serialized as [length (size_t)][bytes].
    std::string set_feature = "";
    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
      for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
        size_t feat_len =
            features_idx_buckets[request_idx][feat_idx][node_idx].size();
        set_feature.append((char *)&feat_len, sizeof(size_t));
        set_feature.append(
            features_idx_buckets[request_idx][feat_idx][node_idx].data(),
            feat_len);
      }
    }
    closure->request(request_idx)
        ->add_params(set_feature.c_str(), set_feature.size());
    GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  return fut;
}
// Initializes the graph client on top of the generic BrpcPsClient: runs the
// base-class initialization, caches the number of servers, and clears the
// optional local short-circuit state (populated later through
// set_local_channel / set_local_graph_service).
// NOTE(review): the base Initialize() return value is ignored, matching the
// original behavior — confirm whether failures should propagate.
int32_t GraphBrpcClient::Initialize() {
  // set_shard_num(_config.shard_num());
  BrpcPsClient::Initialize();
  server_size = GetServerNums();
  // nullptr instead of NULL: same semantics, modern idiom.
  graph_service = nullptr;
  local_channel = nullptr;
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
namespace paddle {
namespace distributed {
// PsService_Stub specialization used by GraphBrpcClient. Besides the remote
// `channel`, it carries an optional `local_channel` / `graph_service` pair;
// the out-of-line service() override (defined in the .cc file, not shown
// here) presumably uses them to serve requests addressed to this process
// without a network round trip — confirm in the implementation.
class GraphPsService_Stub : public PsService_Stub {
 public:
  GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
                      ::google::protobuf::RpcChannel* local_channel = NULL,
                      GraphBrpcService* service = NULL,
                      int thread_num = 1)
      : PsService_Stub(channel) {
    this->local_channel = local_channel;
    this->graph_service = service;
    // Thread pool sized by `thread_num`; consumed by service().
    task_pool.reset(new ::ThreadPool(thread_num));
  }
  virtual ~GraphPsService_Stub() {}
  // implements PsService ------------------------------------------
  GraphBrpcService* graph_service;  // non-owning; may be NULL
  std::shared_ptr<::ThreadPool> task_pool;
  ::google::protobuf::RpcChannel* local_channel;  // non-owning; may be NULL
  GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
  // Dispatches one PS request/response pair; shadows the generated stub's
  // method so local requests can be intercepted.
  void service(::google::protobuf::RpcController* controller,
               const ::paddle::distributed::PsRequestMessage* request,
               ::paddle::distributed::PsResponseMessage* response,
               ::google::protobuf::Closure* done);
};
// Client-side API of the distributed graph table. Requests are routed to the
// server owning each node id (get_server_index_by_id) and issued
// asynchronously over brpc; every call returns a future that resolves with 0
// on success or -1 on failure.
class GraphBrpcClient : public BrpcPsClient {
 public:
  GraphBrpcClient() {}
  virtual ~GraphBrpcClient() {}
  // given a batch of nodes, sample graph_neighbors for each of them
  virtual std::future<int32_t> batch_sample_neighbors(
      uint32_t table_id,
      int idx,
      std::vector<int64_t> node_ids,
      int sample_size,
      std::vector<std::vector<int64_t>>& res,
      std::vector<std::vector<float>>& res_weight,
      bool need_weight,
      int server_index = -1);
  // Pulls nodes [start, ...) with the given size/step from one server's
  // shard and decodes them into `res`.
  virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
                                               int type_id,
                                               int idx,
                                               int server_index,
                                               int start,
                                               int size,
                                               int step,
                                               std::vector<FeatureNode>& res);
  // Randomly samples `sample_size` node ids from one server's shard.
  virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                   int type_id,
                                                   int idx,
                                                   int server_index,
                                                   int sample_size,
                                                   std::vector<int64_t>& ids);
  // Reads the named string features for a batch of nodes into `res`.
  virtual std::future<int32_t> get_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      std::vector<std::vector<std::string>>& res);
  // Writes the named string features for a batch of nodes.
  virtual std::future<int32_t> set_node_feat(
      const uint32_t& table_id,
      int idx,
      const std::vector<int64_t>& node_ids,
      const std::vector<std::string>& feature_names,
      const std::vector<std::vector<std::string>>& features);
  // Clears all nodes of the given type/index from the table.
  virtual std::future<int32_t> clear_nodes(uint32_t table_id,
                                           int type_id,
                                           int idx);
  // Adds nodes; `is_weighted_list` optionally flags each node as weighted.
  virtual std::future<int32_t> add_graph_node(
      uint32_t table_id,
      int idx,
      std::vector<int64_t>& node_id_list,
      std::vector<bool>& is_weighted_list);
  virtual std::future<int32_t> remove_graph_node(
      uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
  virtual int32_t Initialize();
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Maps a node id to the index of the server that owns it.
  int get_server_index_by_id(int64_t id);
  // Wires up the in-process short-circuit used by GraphPsService_Stub.
  void set_local_channel(int index) {
    this->local_channel = GetCmdChannel(index);
  }
  void set_local_graph_service(GraphBrpcService* graph_service) {
    this->graph_service = graph_service;
  }
  GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
                                     int thread_num = 1) {
    return GraphPsService_Stub(
        channel, local_channel, graph_service, thread_num);
  }
 private:
  int shard_num;
  size_t server_size;  // number of PS servers, cached in Initialize()
  ::google::protobuf::RpcChannel* local_channel;  // non-owning
  GraphBrpcService* graph_service;                // non-owning
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <thread> // NOLINT
#include <utility>
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
// Shared guard for request handlers: when the requested table id was not
// found, record error code -1 with a descriptive message on the response and
// bail out of the calling handler.
#define CHECK_TABLE_EXIST(table, request, response)        \
  if (table == NULL) {                                     \
    std::string err_msg("table not found with table_id:"); \
    err_msg.append(std::to_string(request.table_id()));    \
    set_response_code(response, -1, err_msg.c_str());      \
    return -1;                                             \
  }
// Creates and wires up the PS service declared in the server config:
// instantiate it by registered class name, configure and initialize it, and
// register it with the underlying brpc server (which does not take
// ownership). Returns 0 on success, -1 on any failure.
int32_t GraphBrpcServer::Initialize() {
  auto &service_config = _config.downpour_server_param().service_param();
  if (!service_config.has_service_class()) {
    LOG(ERROR) << "miss service_class in ServerServiceParameter";
    return -1;
  }
  auto *service =
      CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
  if (service == NULL) {
    LOG(ERROR) << "service is unregistered, service_name:"
               << service_config.service_class();
    return -1;
  }
  // _service owns the instance from here on.
  _service.reset(service);
  if (service->Configure(this) != 0 || service->Initialize() != 0) {
    LOG(ERROR) << "service initialize failed, service_name:"
               << service_config.service_class();
    return -1;
  }
  // SERVER_DOESNT_OWN_SERVICE: brpc must not delete what _service owns.
  if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
    LOG(ERROR) << "service add to brpc failed, service:"
               << service_config.service_class();
    return -1;
  }
  return 0;
}
// Returns the (non-owning) peer-to-peer command channel for the given
// server; valid indices are established by build_peer2peer_connection.
brpc::Channel *GraphBrpcServer::GetCmdChannel(size_t server_index) {
  auto &channel_owner = _pserver_channels[server_index];
  return channel_owner.get();
}
// Starts the brpc server on ip:port (thread count is the max of the trainer
// count and hardware concurrency) and registers this node with the PS
// environment.
// NOTE(review): both the failure and the success path return 0, so callers
// cannot distinguish them from the return value — confirm whether callers
// rely on logging/registration side effects instead.
uint64_t GraphBrpcServer::Start(const std::string &ip, uint32_t port) {
  std::unique_lock<std::mutex> lock(mutex_);
  std::string ip_port = ip + ":" + std::to_string(port);
  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
  brpc::ServerOptions options;
  int num_threads = std::thread::hardware_concurrency();
  auto trainers = _environment->GetTrainers();
  options.num_threads = trainers > num_threads ? trainers : num_threads;
  if (_server.Start(ip_port.c_str(), &options) != 0) {
    LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
    return 0;
  }
  _environment->RegistePsServer(ip, port, _rank);
  return 0;
}
// Opens a pooled baidu_std channel to every PS server so this server can
// issue peer-to-peer requests (used by sample_neighbors_across_multi_servers).
// Falls back to a numeric-IP endpoint when the hostname form fails to
// connect. Returns 0 on success, -1 when any peer stays unreachable.
int32_t GraphBrpcServer::build_peer2peer_connection(int rank) {
  this->rank = rank;
  auto _env = Environment();
  brpc::ChannelOptions options;
  options.protocol = "baidu_std";
  options.timeout_ms = 500000;
  options.connection_type = "pooled";
  options.connect_timeout_ms = 10000;
  options.max_retry = 3;
  std::vector<PSHost> server_list = _env->GetPsServers();
  _pserver_channels.resize(server_list.size());
  std::ostringstream os;
  std::string server_ip_port;
  for (size_t i = 0; i < server_list.size(); ++i) {
    server_ip_port.assign(server_list[i].ip.c_str());
    server_ip_port.append(":");
    server_ip_port.append(std::to_string(server_list[i].port));
    _pserver_channels[i].reset(new brpc::Channel());
    if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
      VLOG(0) << "GraphServer connect to Server:" << server_ip_port
              << " Failed! Try again.";
      // Retry with the endpoint in numeric (integer IP) form.
      std::string int_ip_port =
          GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
      if (_pserver_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
        LOG(ERROR) << "GraphServer connect to Server:" << int_ip_port
                   << " Failed!";
        return -1;
      }
    }
    os << server_ip_port << ",";
  }
  LOG(INFO) << "servers peer2peer connection success:" << os.str();
  return 0;
}
// Clears all nodes of the given type/index from the graph table.
// params(0): type_id (int); params(1): idx (int).
int32_t GraphBrpcService::clear_nodes(Table *table,
                                      const PsRequestMessage &request,
                                      PsResponseMessage &response,
                                      brpc::Controller *cntl) {
  // Added: table and argument validation, consistent with every sibling
  // handler — the original dereferenced `table` and params(0)/params(1)
  // unchecked.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "clear_nodes request requires at least 2 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  ((GraphTable *)table)->clear_nodes(type_id, idx_);
  return 0;
}
// Adds a batch of nodes to the graph table.
// params(0): idx (int); params(1): packed int64 node ids;
// params(2, optional): packed bool per node — per-node "is weighted" flags.
int32_t GraphBrpcService::add_graph_node(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response, -1, "add_graph_node request requires at least 2 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // Legacy (pre-idx) layout kept for reference:
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);
  std::vector<bool> is_weighted_list;
  if (request.params_size() == 3) {
    size_t weight_list_size = request.params(2).size() / sizeof(bool);
    bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
    is_weighted_list = std::vector<bool>(is_weighted_buffer,
                                         is_weighted_buffer + weight_list_size);
  }
  // Legacy (pre-idx) layout kept for reference:
  // if (request.params_size() == 2) {
  //   size_t weight_list_size = request.params(1).size() / sizeof(bool);
  //   bool *is_weighted_buffer = (bool *)(request.params(1).c_str());
  //   is_weighted_list = std::vector<bool>(is_weighted_buffer,
  //                                        is_weighted_buffer +
  //                                        weight_list_size);
  // }
  ((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
  return 0;
}
// Removes a batch of nodes from the graph table.
// params(0): idx (int); params(1): packed int64 node ids.
int32_t GraphBrpcService::remove_graph_node(Table *table,
                                            const PsRequestMessage &request,
                                            PsResponseMessage &response,
                                            brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "remove_graph_node request requires at least 2 arguments");
    return 0;
  }
  const int idx = *(int *)(request.params(0).c_str());
  const auto &id_bytes = request.params(1);
  const int64_t *ids_begin = (const int64_t *)(id_bytes.c_str());
  const int64_t *ids_end = ids_begin + id_bytes.size() / sizeof(int64_t);
  std::vector<int64_t> node_ids(ids_begin, ids_end);
  ((GraphTable *)table)->remove_graph_node(idx, node_ids);
  return 0;
}
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
// Registers one handler member function per PS command id and then pushes
// shard information to the tables. Unregistered command ids are rejected in
// service().
int32_t GraphBrpcService::Initialize() {
  _is_initialize_shard_info = false;
  _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::StopServer;
  _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::LoadOneTable;
  _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::LoadAllTable;
  _service_handler_map[PS_PRINT_TABLE_STAT] = &GraphBrpcService::PrintTableStat;
  _service_handler_map[PS_BARRIER] = &GraphBrpcService::Barrier;
  _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::StartProfiler;
  _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::StopProfiler;
  _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
  _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBORS] =
      &GraphBrpcService::graph_random_sample_neighbors;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
      &GraphBrpcService::graph_random_sample_nodes;
  _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
      &GraphBrpcService::graph_get_node_feat;
  _service_handler_map[PS_GRAPH_CLEAR] = &GraphBrpcService::clear_nodes;
  _service_handler_map[PS_GRAPH_ADD_GRAPH_NODE] =
      &GraphBrpcService::add_graph_node;
  _service_handler_map[PS_GRAPH_REMOVE_GRAPH_NODE] =
      &GraphBrpcService::remove_graph_node;
  _service_handler_map[PS_GRAPH_SET_NODE_FEAT] =
      &GraphBrpcService::graph_set_node_feat;
  _service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
      &GraphBrpcService::sample_neighbors_across_multi_servers;
  // _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
  //     &GraphBrpcService::use_neighbors_sample_cache;
  // _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
  //     &GraphBrpcService::load_graph_split_config;
  // Shard initialization: the server_list shard info only becomes available
  // from the environment after the server has started.
  InitializeShardInfo();
  return 0;
}
// One-time, thread-safe propagation of shard info (rank, server count) to
// every registered table. Uses double-checked locking: the unlocked read
// skips the mutex once initialization has happened.
int32_t GraphBrpcService::InitializeShardInfo() {
  if (!_is_initialize_shard_info) {
    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
    if (_is_initialize_shard_info) {
      return 0;
    }
    server_size = _server->Environment()->GetPsServers().size();
    auto &table_map = *(_server->GetTable());
    // Iterate by reference: the original `for (auto itr : table_map)`
    // copied every map entry (and bumped its shared_ptr refcount) per
    // iteration for no benefit.
    for (auto &itr : table_map) {
      itr.second->SetShard(_rank, server_size);
    }
    _is_initialize_shard_info = true;
  }
  return 0;
}
// Central dispatch for all PS graph commands: validates the request, looks
// up the handler registered for the command id in Initialize(), invokes it,
// and maps a non-zero handler return into the response error fields.
void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
                               const PsRequestMessage *request,
                               PsResponseMessage *response,
                               google::protobuf::Closure *done) {
  // Guarantees `done` runs exactly once, whichever path returns.
  brpc::ClosureGuard done_guard(done);
  std::string log_label("ReceiveCmd-");
  if (!request->has_table_id()) {
    // Fixed: error message typo "tabel_id" -> "table_id".
    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
    return;
  }
  response->set_err_code(0);
  response->set_err_msg("");
  // May be NULL for commands that do not target a table; each handler guards
  // with CHECK_TABLE_EXIST as needed.
  auto *table = _server->GetTable(request->table_id());
  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
  auto itr = _service_handler_map.find(request->cmd_id());
  if (itr == _service_handler_map.end()) {
    std::string err_msg(
        "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
    err_msg.append(std::to_string(request->cmd_id()));
    set_response_code(*response, -1, err_msg.c_str());
    return;
  }
  serviceFunc handler_func = itr->second;
  int service_ret = (this->*handler_func)(table, *request, *response, cntl);
  if (service_ret != 0) {
    response->set_err_code(service_ret);
    // Preserve a handler-supplied message; only fill in a generic one.
    if (!response->has_err_msg()) {
      response->set_err_msg("server internal error");
    }
  }
}
// Forwards a barrier request from one trainer to the target table.
// params(0): barrier type string; client_id identifies the trainer.
int32_t GraphBrpcService::Barrier(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 1) {
    set_response_code(response,
                      -1,
                      "PsRequestMessage.params is requeired at least 1 for "
                      "num of sparse_key");
    return 0;
  }
  table->Barrier(request.client_id(), request.params(0));
  return 0;
}
// Serializes the table's statistics pair (two int64 counters) into the
// response data field via a BinaryArchive.
int32_t GraphBrpcService::PrintTableStat(Table *table,
                                         const PsRequestMessage &request,
                                         PsResponseMessage &response,
                                         brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  const std::pair<int64_t, int64_t> stat = table->PrintTableStat();
  paddle::framework::BinaryArchive archive;
  archive << stat.first << stat.second;
  response.set_data(std::string(archive.Buffer(), archive.Length()));
  return 0;
}
// Loads a single table from storage.
// params(0): path; params(1): load_param passed through to Table::Load.
int32_t GraphBrpcService::LoadOneTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 2) {
    set_response_code(
        response,
        -1,
        "PsRequestMessage.datas is requeired at least 2 for path & load_param");
    return -1;
  }
  const int load_ret = table->Load(request.params(0), request.params(1));
  if (load_ret != 0) {
    set_response_code(response, -1, "table load failed");
    return -1;
  }
  return 0;
}
// Applies LoadOneTable to every registered table with the same request
// arguments; stops and reports failure at the first table that cannot load.
int32_t GraphBrpcService::LoadAllTable(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  auto &tables = *(_server->GetTable());
  for (auto &entry : tables) {
    const int one_ret =
        LoadOneTable(entry.second.get(), request, response, cntl);
    if (one_ret != 0) {
      LOG(ERROR) << "load table[" << entry.first << "] failed";
      return -1;
    }
  }
  return 0;
}
// Shuts the server down in response to PS_STOP_SERVER. Stop() runs on a
// detached helper thread so this handler can reply to the client first
// (Stop() presumably blocks until the brpc server quits — confirm).
int32_t GraphBrpcService::StopServer(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
  GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
  std::thread t_stop([p_server]() {
    p_server->Stop();
    LOG(INFO) << "Server Stoped";
  });
  // Wake any thread waiting on the server's exported condition variable.
  p_server->export_cv()->notify_all();
  t_stop.detach();
  return 0;
}
// Stops CPU profiling and dumps the results to a per-rank file.
// NOTE(review): the "%s" placeholder is fed the integer `_rank`; paddle's
// Sprintf formats arbitrary types, so this works, but "%d" would express the
// intent better — confirm before changing.
int32_t GraphBrpcService::StopProfiler(Table *table,
                                       const PsRequestMessage &request,
                                       PsResponseMessage &response,
                                       brpc::Controller *cntl) {
  platform::DisableProfiler(platform::EventSortingKey::kDefault,
                            string::Sprintf("server_%s_profile", _rank));
  return 0;
}
// Enables CPU profiling for this server process; results are written when
// StopProfiler is invoked.
int32_t GraphBrpcService::StartProfiler(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl) {
  platform::EnableProfiler(platform::ProfilerState::kCPU);
  return 0;
}
// Serves PS_PULL_GRAPH_LIST: serializes a slice of graph nodes into the
// response attachment for the client-side decoder.
// params(0..4): type_id, idx, start, size, step — raw 4-byte ints, matching
// GraphBrpcClient::pull_graph_list.
int32_t GraphBrpcService::pull_graph_list(Table *table,
                                          const PsRequestMessage &request,
                                          PsResponseMessage &response,
                                          brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 5) {
    set_response_code(
        response, -1, "pull_graph_list request requires at least 5 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx = *(int *)(request.params(1).c_str());
  int start = *(int *)(request.params(2).c_str());
  int size = *(int *)(request.params(3).c_str());
  int step = *(int *)(request.params(4).c_str());
  // Legacy (pre-type_id/idx) layout kept for reference:
  // int start = *(int *)(request.params(0).c_str());
  // int size = *(int *)(request.params(1).c_str());
  // int step = *(int *)(request.params(2).c_str());
  std::unique_ptr<char[]> buffer;
  int actual_size;
  // The table serializes the nodes into `buffer`; `false` is an opaque flag
  // of pull_graph_list — TODO confirm its meaning in GraphTable.
  ((GraphTable *)table)
      ->pull_graph_list(
          type_id, idx, start, size, buffer, actual_size, false, step);
  cntl->response_attachment().append(buffer.get(), actual_size);
  return 0;
}
// Samples up to `sample_size` neighbors for each requested node and streams
// the per-node buffers back in the response attachment:
// [node_num (size_t)][actual size per node (int)][concatenated buffers].
// params(0): idx (int); params(1): packed int64 node ids;
// params(2): sample_size (int); params(3): need_weight (bool).
int32_t GraphBrpcService::graph_random_sample_neighbors(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fixed: the message claimed "3" while the check requires 4 params.
    set_response_code(
        response,
        -1,
        "graph_random_sample_neighbors request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // Fixed: clients pack sample_size with sizeof(int) (see the re-issued
  // requests in sample_neighbors_across_multi_servers); reading it as
  // int64_t dereferenced 4 bytes past the end of the param string.
  int sample_size = *(int *)(request.params(2).c_str());
  bool need_weight = *(bool *)(request.params(3).c_str());
  std::vector<std::shared_ptr<char>> buffers(node_num);
  std::vector<int> actual_sizes(node_num, 0);
  ((GraphTable *)table)
      ->random_sample_neighbors(
          idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  cntl->response_attachment().append(actual_sizes.data(),
                                     sizeof(int) * node_num);
  for (size_t idx = 0; idx < node_num; ++idx) {
    cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
  }
  return 0;
}
// Randomly samples node ids of a given type/index from this shard and
// appends the serialized result to the response attachment.
// params(0): type_id (int); params(1): idx (int); params(2): size (int).
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // Added: table and argument validation, consistent with sibling handlers.
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_random_sample_nodes request requires at least 3 arguments");
    return 0;
  }
  int type_id = *(int *)(request.params(0).c_str());
  int idx_ = *(int *)(request.params(1).c_str());
  // Fixed: the client packs this param with sizeof(int); reading it as
  // int64_t dereferenced 4 bytes past the end of the param string.
  size_t size = *(int *)(request.params(2).c_str());
  std::unique_ptr<char[]> buffer;
  int actual_size;
  if (((GraphTable *)table)
          ->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
      0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  } else {
    cntl->response_attachment().append(NULL, 0);
  }
  return 0;
}
// Reads string features for a batch of nodes and streams them back as
// [length (size_t)][bytes] records, iterating features first, then nodes.
// params(0): idx (int); params(1): packed int64 node ids;
// params(2): feature names joined by '\t'.
int32_t GraphBrpcService::graph_get_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 3) {
    set_response_code(
        response,
        -1,
        "graph_get_node_feat request requires at least 3 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // Legacy (pre-idx) layout kept for reference:
  // size_t node_num = request.params(0).size() / sizeof(int64_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  // feature[feat_idx][node_idx] receives the value for that pair.
  std::vector<std::vector<std::string>> feature(
      feature_names.size(), std::vector<std::string>(node_num));
  ((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      size_t feat_len = feature[feat_idx][node_idx].size();
      cntl->response_attachment().append(&feat_len, sizeof(size_t));
      cntl->response_attachment().append(feature[feat_idx][node_idx].data(),
                                         feat_len);
    }
  }
  return 0;
}
// Serves PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER: samples neighbors for a node
// batch whose ids may live on OTHER servers. Ids owned by this server are
// sampled locally; the rest are re-issued as PS_GRAPH_SAMPLE_NEIGHBORS
// requests to their owners. The merged reply preserves the caller's node
// order: [node_num (size_t)][actual size per node (int)][per-node buffers].
int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl) {
  // sleep(5);
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    set_response_code(response,
                      -1,
                      "sample_neighbors_across_multi_servers request requires "
                      "at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  // NOTE(review): params(2)/params(3) are packed downstream with sizeof(int)
  // and sizeof(bool) (see the re-issued requests below), yet both are read
  // here as int64_t, which reads past the end of the param strings — confirm
  // the client encoding and align with graph_random_sample_neighbors.
  int sample_size = *(int64_t *)(request.params(2).c_str());
  bool need_weight = *(int64_t *)(request.params(3).c_str());
  // size_t node_num = request.params(0).size() / sizeof(int64_t),
  // size_of_size_t = sizeof(size_t);
  // int64_t *node_data = (int64_t *)(request.params(0).c_str());
  // int sample_size = *(int64_t *)(request.params(1).c_str());
  // bool need_weight = *(int64_t *)(request.params(2).c_str());
  // std::vector<int64_t> res = ((GraphTable
  // *)table).filter_out_non_exist_nodes(node_data, sample_size);
  // Map each involved server to a dense request slot (same scheme as the
  // client-side bucketing in GraphBrpcClient).
  std::vector<int> request2server;
  std::vector<int> server2request(server_size, -1);
  std::vector<int64_t> local_id;
  std::vector<int> local_query_idx;
  size_t rank = GetRank();
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    if (server2request[server_index] == -1) {
      server2request[server_index] = request2server.size();
      request2server.push_back(server_index);
    }
  }
  // If this server owns some of the ids, move its bucket to the back so that
  // buckets [0, remote_call_num) are all remote.
  if (server2request[rank] != -1) {
    auto pos = server2request[rank];
    std::swap(request2server[pos],
              request2server[(int)request2server.size() - 1]);
    server2request[request2server[pos]] = pos;
    server2request[request2server[(int)request2server.size() - 1]] =
        request2server.size() - 1;
  }
  size_t request_call_num = request2server.size();
  std::vector<std::shared_ptr<char>> local_buffers;
  std::vector<int> local_actual_sizes;
  // seq[i]: the bucket the i-th node was routed to; used to restore the
  // caller's original node order when merging.
  std::vector<size_t> seq;
  std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
    int server_index =
        ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
    int request_idx = server2request[server_index];
    node_id_buckets[request_idx].push_back(node_data[query_idx]);
    query_idx_buckets[request_idx].push_back(query_idx);
    seq.push_back(request_idx);
  }
  size_t remote_call_num = request_call_num;
  if (request2server.size() != 0 &&
      static_cast<size_t>(request2server.back()) == rank) {
    remote_call_num--;
    local_buffers.resize(node_id_buckets.back().size());
    local_actual_sizes.resize(node_id_buckets.back().size());
  }
  cntl->response_attachment().append(&node_num, sizeof(size_t));
  // local_fut gates the merge callback on local sampling having finished.
  auto local_promise = std::make_shared<std::promise<int32_t>>();
  std::future<int> local_fut = local_promise->get_future();
  std::vector<bool> failed(server_size, false);
  // Merge callback: waits for local sampling, then stitches remote responses
  // and local buffers back into the original node order. Captures this
  // handler's stack by reference — safe only because fut.get() below blocks
  // until the merge has completed.
  std::function<void(void *)> func = [&,
                                      node_id_buckets,
                                      query_idx_buckets,
                                      request_call_num](void *done) {
    local_fut.get();
    std::vector<int> actual_size;
    auto *closure = (DownpourBrpcClosure *)done;
    std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
        remote_call_num);
    size_t fail_num = 0;
    for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
      if (closure->check_response(request_idx, PS_GRAPH_SAMPLE_NEIGHBORS) !=
          0) {
        ++fail_num;
        failed[request2server[request_idx]] = true;
      } else {
        // Skip the leading node_num in each remote reply.
        auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
        res[request_idx].reset(new butil::IOBufBytesIterator(res_io_buffer));
        size_t num;
        res[request_idx]->copy_and_forward(&num, sizeof(size_t));
      }
    }
    int size;
    int local_index = 0;
    // First pass: gather per-node sizes in original order (0 for failures).
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        size = 0;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        res[seq[i]]->copy_and_forward(&size, sizeof(int));
      } else {
        size = local_actual_sizes[local_index++];
      }
      actual_size.push_back(size);
    }
    cntl->response_attachment().append(actual_size.data(),
                                       actual_size.size() * sizeof(int));
    local_index = 0;
    // Second pass: copy the per-node payloads in the same order.
    for (size_t i = 0; i < node_num; i++) {
      if (fail_num > 0 && failed[seq[i]]) {
        continue;
      } else if (static_cast<size_t>(request2server[seq[i]]) != rank) {
        // NOTE(review): variable-length array — a GCC/Clang extension, not
        // standard C++.
        char temp[actual_size[i] + 1];
        res[seq[i]]->copy_and_forward(temp, actual_size[i]);
        cntl->response_attachment().append(temp, actual_size[i]);
      } else {
        char *temp = local_buffers[local_index++].get();
        cntl->response_attachment().append(temp, actual_size[i]);
      }
    }
    closure->set_promise_value(0);
  };
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(remote_call_num, func);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fan out the remote PS_GRAPH_SAMPLE_NEIGHBORS requests. Param layout must
  // match graph_random_sample_neighbors: idx (int), packed ids,
  // sample_size (int), need_weight (bool).
  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
    int server_index = request2server[request_idx];
    closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
    closure->request(request_idx)->set_table_id(request.table_id());
    closure->request(request_idx)->set_client_id(rank);
    size_t node_num = node_id_buckets[request_idx].size();
    closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)node_id_buckets[request_idx].data(),
                     sizeof(int64_t) * node_num);
    closure->request(request_idx)
        ->add_params((char *)&sample_size, sizeof(int));
    closure->request(request_idx)
        ->add_params((char *)&need_weight, sizeof(bool));
    PsService_Stub rpc_stub(
        ((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
    // GraphPsService_Stub rpc_stub =
    //    getServiceStub(GetCmdChannel(server_index));
    closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
    rpc_stub.service(closure->cntl(request_idx),
                     closure->request(request_idx),
                     closure->response(request_idx),
                     closure);
  }
  // Sample the locally-owned bucket (always the last one, per the swap above).
  if (server2request[rank] != -1) {
    ((GraphTable *)table)
        ->random_sample_neighbors(idx_,
                                  node_id_buckets.back().data(),
                                  sample_size,
                                  local_buffers,
                                  local_actual_sizes,
                                  need_weight);
  }
  local_promise.get()->set_value(0);
  // With no remote requests the closure never fires; run the merge inline.
  if (remote_call_num == 0) func(closure);
  // Block until the merge callback has run — it references this frame.
  fut.get();
  return 0;
}
// Parses a set_node_feat request and writes the features into the table.
// params(0): idx (int); params(1): packed int64 node ids;
// params(2): feature names joined by '\t';
// params(3): values serialized as [length (size_t)][bytes], iterating
// features first, then nodes (see GraphBrpcClient::set_node_feat).
int32_t GraphBrpcService::graph_set_node_feat(Table *table,
                                              const PsRequestMessage &request,
                                              PsResponseMessage &response,
                                              brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  if (request.params_size() < 4) {
    // Fixed: the message claimed "3" while the check requires 4 params.
    set_response_code(
        response,
        -1,
        "graph_set_node_feat request requires at least 4 arguments");
    return 0;
  }
  int idx_ = *(int *)(request.params(0).c_str());
  size_t node_num = request.params(1).size() / sizeof(int64_t);
  int64_t *node_data = (int64_t *)(request.params(1).c_str());
  std::vector<int64_t> node_ids(node_data, node_data + node_num);
  std::vector<std::string> feature_names =
      paddle::string::split_string<std::string>(request.params(2), "\t");
  std::vector<std::vector<std::string>> features(
      feature_names.size(), std::vector<std::string>(node_num));
  // Walk the length-prefixed value stream in the same order the client
  // serialized it: outer loop over features, inner loop over nodes.
  const char *buffer = request.params(3).c_str();
  for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
    for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
      size_t feat_len = *(size_t *)(buffer);
      buffer += sizeof(size_t);
      features[feat_idx][node_idx] = std::string(buffer, feat_len);
      buffer += feat_len;
    }
  }
  ((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
  return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// brpc-based parameter server specialized for graph tables. In addition to
// the standard PSServer lifecycle it exposes peer-to-peer command channels
// (see GetCmdChannel) so servers can forward graph requests to each other.
class GraphBrpcServer : public PSServer {
 public:
  GraphBrpcServer() {}
  virtual ~GraphBrpcServer() {}
  // Returns the rpc service object owned by this server.
  PsBaseService *get_service() { return _service.get(); }
  // Starts the brpc server listening on ip:port (defined in the .cc file).
  virtual uint64_t Start(const std::string &ip, uint32_t port);
  // Builds direct channels to the other servers; `rank` is this server's id.
  virtual int32_t build_peer2peer_connection(int rank);
  // Channel for sending commands to the server at `server_index`.
  virtual brpc::Channel *GetCmdChannel(size_t server_index);
  // Idempotent shutdown: the stoped_ flag (guarded by mutex_) makes repeated
  // calls no-ops. Gives in-flight rpcs up to 1000 ms to finish.
  virtual int32_t Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_) return 0;
    stoped_ = true;
    // cv_.notify_all();
    _server.Stop(1000);
    _server.Join();
    return 0;
  }
  int32_t Port();
  // Exposes the internal condition variable to external waiters.
  std::condition_variable *export_cv() { return &cv_; }

 private:
  virtual int32_t Initialize();
  mutable std::mutex mutex_;  // guards stoped_
  std::condition_variable cv_;
  bool stoped_ = false;
  int rank;
  brpc::Server _server;
  std::shared_ptr<PsBaseService> _service;
  // One channel per peer pserver, indexed by server rank.
  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class GraphBrpcService;
// Member-function-pointer type for GraphBrpcService command handlers: each
// handler serves one PS command against `table`, filling `response`.
typedef int32_t (GraphBrpcService::*serviceFunc)(
    Table *table,
    const PsRequestMessage &request,
    PsResponseMessage &response,
    brpc::Controller *cntl);
// rpc service for graph operations: pull node lists, sample neighbors/nodes,
// get/set node features, plus generic table admin commands (load, profiler,
// stop). Requests are dispatched via _service_handler_map keyed by cmd id.
class GraphBrpcService : public PsBaseService {
 public:
  virtual int32_t Initialize() override;
  // Entry point invoked by brpc; routes `request` to the matching handler.
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override;

 protected:
  // cmd id -> member-function handler.
  std::unordered_map<int32_t, serviceFunc> _service_handler_map;
  int32_t InitializeShardInfo();
  // --- graph data-plane handlers ---
  int32_t pull_graph_list(Table *table,
                          const PsRequestMessage &request,
                          PsResponseMessage &response,
                          brpc::Controller *cntl);
  int32_t graph_random_sample_neighbors(Table *table,
                                        const PsRequestMessage &request,
                                        PsResponseMessage &response,
                                        brpc::Controller *cntl);
  int32_t graph_random_sample_nodes(Table *table,
                                    const PsRequestMessage &request,
                                    PsResponseMessage &response,
                                    brpc::Controller *cntl);
  int32_t graph_get_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  int32_t graph_set_node_feat(Table *table,
                              const PsRequestMessage &request,
                              PsResponseMessage &response,
                              brpc::Controller *cntl);
  // --- graph mutation handlers ---
  int32_t clear_nodes(Table *table,
                      const PsRequestMessage &request,
                      PsResponseMessage &response,
                      brpc::Controller *cntl);
  int32_t add_graph_node(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  int32_t remove_graph_node(Table *table,
                            const PsRequestMessage &request,
                            PsResponseMessage &response,
                            brpc::Controller *cntl);
  // --- generic admin / control handlers ---
  int32_t Barrier(Table *table,
                  const PsRequestMessage &request,
                  PsResponseMessage &response,
                  brpc::Controller *cntl);
  int32_t LoadOneTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t LoadAllTable(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t StopServer(Table *table,
                     const PsRequestMessage &request,
                     PsResponseMessage &response,
                     brpc::Controller *cntl);
  int32_t StartProfiler(Table *table,
                        const PsRequestMessage &request,
                        PsResponseMessage &response,
                        brpc::Controller *cntl);
  int32_t StopProfiler(Table *table,
                       const PsRequestMessage &request,
                       PsResponseMessage &response,
                       brpc::Controller *cntl);
  int32_t PrintTableStat(Table *table,
                         const PsRequestMessage &request,
                         PsResponseMessage &response,
                         brpc::Controller *cntl);
  // Fans a sampling request out to the servers owning the nodes and merges
  // the results (implemented in the .cc file).
  int32_t sample_neighbors_across_multi_servers(Table *table,
                                                const PsRequestMessage &request,
                                                PsResponseMessage &response,
                                                brpc::Controller *cntl);
  int32_t use_neighbors_sample_cache(Table *table,
                                     const PsRequestMessage &request,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl);
  int32_t load_graph_split_config(Table *table,
                                  const PsRequestMessage &request,
                                  PsResponseMessage &response,
                                  brpc::Controller *cntl);

 private:
  bool _is_initialize_shard_info;
  std::mutex _initialize_shard_mutex;
  // NOTE(review): serviceHandlerFunc is not declared in this header's
  // visible portion — presumably it comes from PsBaseService; verify.
  std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
  std::vector<float> _ori_values;
  const int sample_nodes_ranges = 23;
  size_t server_size;
  std::shared_ptr<::ThreadPool> task_pool;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/profiler.h"
// Upper bound on the number of heter groups (group ids must stay below it).
DEFINE_int32(heter_world_size, 100, "group size");  // group max size
DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
namespace paddle {
namespace distributed {
// Storage for the HeterClient singletons (trainer-side and switch-side)
// declared in heter_client.h; mtx_ guards switch_s_instance_ creation.
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_;
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
// Reads the current micro-batch id from the "microbatch_id" variable in
// `scope`. The id is stored as the first element of a float LoDTensor; when
// the tensor lives on GPU it is first copied back to host memory through the
// context's CUDA stream. Returns -1 if the tensor is on a non-CPU place and
// CUDA support is not compiled in.
int GetMicroId(const platform::DeviceContext& ctx,
               const framework::Scope* scope) {
  framework::Variable* var = scope->FindVar("microbatch_id");
  // FindVar returns nullptr when the variable is absent; fail loudly instead
  // of dereferencing a null pointer below.
  PADDLE_ENFORCE_NOT_NULL(
      var,
      platform::errors::NotFound(
          "variable microbatch_id is not found in the given scope."));
  PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(),
                    true,
                    platform::errors::InvalidArgument(
                        "the type of micro id shoulde be LoDTensor."));
  auto micro_id = -1;
  auto* tensor = var->GetMutable<framework::LoDTensor>();
  if (platform::is_cpu_place(tensor->place())) {
    auto data = reinterpret_cast<const float*>(tensor->data());
    micro_id = static_cast<int>(data[0]);
  } else {
#ifdef PADDLE_WITH_CUDA
    // Copy the (small) tensor to a host buffer before reading element 0.
    std::vector<char> temp;
    temp.resize(tensor->numel() * framework::DataTypeSize(tensor->dtype()));
    char* temp_ptr = temp.data();
    auto stream =
        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
    memory::Copy(platform::CPUPlace(),
                 temp_ptr,
                 tensor->place(),
                 tensor->data(),
                 tensor->numel() * framework::DataTypeSize(tensor->dtype()),
                 stream);
    float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
    micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
  }
  return micro_id;
}
// Sends the stop command to all connected heter workers and blocks until the
// rpc round-trip completes.
void HeterClient::Stop() {
  auto status = StopHeterWorker();
  status.wait();
}
// Broadcasts PS_STOP_SERVER to every xpu channel. The -1 table id is
// implicitly converted to uint32_t by SendCmd's signature.
std::future<int32_t> HeterClient::StopHeterWorker() {
  return SendCmd(-1, PS_STOP_SERVER, {});
}
// Broadcasts the start-profiler command to every xpu channel.
std::future<int32_t> HeterClient::StartProfiler() {
  return SendCmd(-1, PS_START_PROFILER, {});
}
// Broadcasts the stop-profiler command to every xpu channel.
std::future<int32_t> HeterClient::StopProfiler() {
  return SendCmd(-1, PS_STOP_PROFILER, {});
}
void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = FLAGS_pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
previous_xpu_channels_.resize(previous_xpu_list_.size());
for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
previous_xpu_channels_[i].reset(new brpc::Channel());
if (previous_xpu_channels_[i]->Init(
previous_xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
}
// Serializes the `send_var_name` variables from `scope` into an rpc request
// and issues an async SendAndRecvVariable call. The target channel is picked
// by micro-batch id: "forward" goes to the next-stage workers
// (xpu_channels_), "backward" to the previous stage; "send_to_switch" is a
// currently-disabled stub that returns without issuing any rpc.
void HeterClient::SendAndRecvAsync(
    const platform::DeviceContext& ctx,
    const framework::Scope& scope,
    const std::string& message_name,
    const std::vector<std::string>& send_var_name,
    const std::vector<std::string>& recv_var_name,
    const std::string& mode) {
  platform::RecordEvent record_event("HeterClient->SendAndRecvAsync",
                                     platform::TracerEventType::Communication,
                                     1);
  const platform::DeviceContext* p_ctx = &ctx;
  const framework::Scope* p_scope = &scope;
  const std::vector<std::string> send_var_name_val = send_var_name;
  const std::vector<std::string> recv_var_name_val = recv_var_name;
  VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name;
  brpc::Channel* channel = nullptr;
  distributed::MultiVarMsg request;
  // Completion callback: only checks the rpc for transport failure.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PADDLE_ENFORCE_NE(
        closure->cntl.Failed(),
        true,
        platform::errors::Unimplemented(
            "HeterClient::SendAndRecv meets brpc error, error message is %s",
            closure->cntl.ErrorText()));
    VLOG(4) << "call heter_worker success";
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                              send_var_name_val,
                                              recv_var_name_val,
                                              *p_ctx,
                                              p_scope,
                                              &request,
                                              &request_io_buffer);
  int micro_id = GetMicroId(ctx, p_scope);  // global
  auto minibatch_id = micro_id / 10;
  VLOG(4) << "micro_id: " << micro_id;
  // select channel according to micro id
  if (mode == "forward") {
    int num = minibatch_id % xpu_channels_.size();
    channel = xpu_channels_[num].get();
  } else if (mode == "backward") {
    int num = minibatch_id % previous_xpu_channels_.size();
    channel = previous_xpu_channels_[num].get();
  } else if (mode == "send_to_switch") {
    VLOG(4) << "calling switch service";
    // The switch path is disabled for now:
    // auto promise = std::make_shared<std::promise<int32_t>>();
    // closure->add_promise(promise);
    // std::future<int> fut = promise->get_future();
    // int idx = 1;  // for test
    // LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
    // channel = xpu_channels_[idx].get();  // to adapt the send_and_recv op
    // ::paddle::distributed::PsService_Stub stub(channel);
    // stub.SendToSwitch(&closure->cntl, &request, &closure->response,
    // closure); fut.wait();
    // No rpc takes ownership of `closure` on this early-return path, so it
    // must be freed here (the original code leaked it).
    delete closure;
    VLOG(4) << "calling switch service done";
    return;
  }
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendAndRecvVariable(
      &closure->cntl, &request, &closure->response, closure);
}
// Broadcasts one PS command (`cmd_id` with string `params`) to every xpu
// channel in parallel. Returns a future that resolves to 0 when all
// responses checked out, or -1 if any response failed. The closure deletes
// itself after its callback runs (DownpourBrpcClosure convention — confirm
// in its definition).
std::future<int32_t> HeterClient::SendCmd(
    uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
  size_t request_call_num = xpu_channels_.size();
  // Shared closure that waits for all request_call_num sub-rpcs and folds
  // their statuses into a single promise value.
  paddle::distributed::DownpourBrpcClosure* closure =
      new paddle::distributed::DownpourBrpcClosure(
          request_call_num, [request_call_num, cmd_id](void* done) {
            int ret = 0;
            auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
            for (size_t i = 0; i < request_call_num; ++i) {
              if (closure->check_response(i, cmd_id) != 0) {
                ret = -1;
                break;
              }
            }
            closure->set_promise_value(ret);
          });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  for (size_t i = 0; i < request_call_num; ++i) {
    closure->request(i)->set_cmd_id(cmd_id);
    closure->request(i)->set_table_id(table_id);
    closure->request(i)->set_client_id(trainer_id_);
    for (const auto& param : params) {
      closure->request(i)->add_params(param);
    }
    ::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
    closure->cntl(i)->set_timeout_ms(
        FLAGS_pserver_timeout_ms);  // cmd msg don't limit timeout for save/load
    rpc_stub.service(
        closure->cntl(i), closure->request(i), closure->response(i), closure);
  }
  return fut;
}
// Serializes `send_var_names` from `scope` into a MultiVarMsg (metadata in
// the proto, tensor bytes in the request attachment) and sends it to the
// first switch endpoint via SendToSwitch, blocking until the rpc completes.
// Always returns 0.
// NOTE(review): unlike the Send(group_id, ...) overload below, this one
// never deletes `closure` after fut.wait() — looks like a leak; confirm
// whether brpc/OnHeterRpcDone frees it elsewhere.
int HeterClient::Send(const platform::DeviceContext& ctx,
                      const framework::Scope& scope,
                      const std::string& message_name,
                      const std::vector<std::string>& send_var_names) {
  const framework::Scope* p_scope = &scope;  // note: const
  // Completion callback: unblocks the waiter first, then reports transport
  // failure (the promise is set before the Failed() check).
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      PADDLE_ENFORCE_NE(
          closure->cntl.Failed(),
          true,
          platform::errors::Unimplemented(
              "HeterClient::SendToSwitch meets brpc error, error message is %s",
              closure->cntl.ErrorText()));
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  auto& request_io_buffer = closure->cntl.request_attachment();
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req send_var_names(<string>)
  for (auto& send_var_name : send_var_names) {
    request.add_send_var_names(send_var_name);
  }
  // 3. set req var_messages(<VarMessage>)
  for (auto& send_var_name : send_var_names) {
    auto* send_var_msg = request.add_var_messages();
    send_var_msg->set_varname(send_var_name);
    framework::Variable* var = p_scope->FindVar(send_var_name);
    butil::IOBuf temp_iobuf;
    if (var->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
    } else if (var->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
    }
    request_io_buffer.append(temp_iobuf);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fall back to the first xpu channel when no switch channel was set up.
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  // brpc::Channel* channel = xpu_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  VLOG(4) << "waiting SendToSwitch response result......";
  fut.wait();
  VLOG(4) << "Send done";
  return 0;
}
// Sends a raw byte buffer (`data_ptr`, `data_size`) tagged with variable
// names and per-variable lengths to the first switch endpoint, blocking
// until the rpc finishes. Always returns 0; rpc failures are only logged.
int HeterClient::Send(int group_id,
                      const std::vector<std::string>& var_names,
                      const std::vector<int64_t>& vars_size,
                      void* data_ptr,
                      int64_t data_size) {
  // Completion callback: unblock the waiter, then log transport failures.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      LOG(ERROR) << "Send meets brpc error, err msg is %s"
                 << closure->cntl.ErrorText();
    }
  });
  distributed::MultiVarMsg request;
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  std::string message_name = "send and save";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto& send_var_name : var_names) {
    request.add_send_var_names(send_var_name);
  }
  // Per-variable byte lengths, so the receiver can slice the attachment.
  for (auto var_len : vars_size) {
    request.add_vars_len(var_len);
  }
  // Payload travels in the request attachment, not the proto body.
  auto& request_buffer = closure->cntl.request_attachment();
  request_buffer.append(reinterpret_cast<void*>(data_ptr), data_size);
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fall back to the first xpu channel when no switch channel was set up.
  if (send_switch_channels_.empty()) {
    LOG(ERROR) << "send_switch_channels_ is null, get xpu_channels_[0]";
    if (xpu_channels_.empty()) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    send_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = send_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
  fut.wait();
  delete closure;
  return 0;
}
// Requests `recv_var_names` from the switch (RecvFromSwitch), blocks for the
// response, then deserializes the returned tensors into `recv_scope` on CPU.
// Always returns 0; rpc failures are only logged.
// NOTE(review): `closure` is never deleted here, unlike the
// Recv(group_id, ...) overload below — looks like a leak; confirm.
int HeterClient::Recv(const platform::DeviceContext& ctx,
                      framework::Scope& recv_scope,  // NOLINT
                      const std::string& message_name,
                      const std::vector<std::string>& recv_var_names) {
  // Completion callback: unblock the waiter, then log transport failures.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    VLOG(4) << "Recv service call done";
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      VLOG(4) << "HeterClient::RecvFromSwitch meets "
                 "brpc error, error message is %s"
              << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  // 1. set req message_name(string)
  request.set_message_name(message_name);
  request.set_group_id(0);
  // 2. set req recv_var_names(<string>)
  for (auto& recv_var_name : recv_var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  // Fall back to the second xpu channel when no recv-switch channel exists.
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    recv_switch_channels_.push_back(xpu_channels_[1]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto& res_io_buffer = closure->cntl.response_attachment();
  VLOG(4) << "entering DeserializeFromMultiVarMsgAndIOBuf";
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      closure->response, &res_io_buffer, cpu_dev_ctx, &recv_scope);
  VLOG(4) << "Recv done";
  return 0;
}
// Requests the raw bytes for `var_names` in `group_id` from the switch and
// copies up to `data_size` bytes of the response attachment into `data_ptr`.
// Blocks until the rpc completes; always returns 0.
int HeterClient::Recv(int group_id,
                      const std::vector<std::string>& var_names,
                      void* data_ptr,
                      int64_t data_size) {
  // Completion callback: unblock the waiter, then log transport failures.
  OnHeterRpcDone* closure = new OnHeterRpcDone([](void* done) {
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    int ret = 0;
    closure->set_promise_value(ret);
    if (closure->cntl.Failed()) {
      LOG(ERROR) << "Recv meets brpc error, err msg is %s"
                 << closure->cntl.ErrorText();
    }
  });
  closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
  distributed::MultiVarMsg request;
  std::string message_name = "query and recv";
  request.set_message_name(message_name);
  request.set_group_id(group_id);
  for (auto& recv_var_name : var_names) {
    request.add_recv_var_names(recv_var_name);
  }
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int> fut = promise->get_future();
  if (recv_switch_channels_.empty()) {
    LOG(ERROR) << "peer_switch_channels_ is null, get xpu_channels_[1]";
    if (xpu_channels_.size() < 2) {
      LOG(ERROR) << "xpu_channels_ is null";
    }
    // NOTE(review): pushes xpu_channels_[0] although the log above (and the
    // sibling Recv overload) reference xpu_channels_[1] — confirm which
    // fallback channel is intended.
    recv_switch_channels_.push_back(xpu_channels_[0]);
  }
  brpc::Channel* channel = recv_switch_channels_[0].get();
  ::paddle::distributed::PsService_Stub stub(channel);
  stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
  fut.wait();
  VLOG(4) << "RecvFromSwitch done";
  // save in worker
  auto& res_io_buffer = closure->cntl.response_attachment();
  butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
  io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(data_ptr), data_size);
  delete closure;
  VLOG(4) << "Recv done";
  return 0;
}
} // namespace distributed
} // end namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_int32(pserver_timeout_ms);
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
// Callback type invoked with the OnHeterRpcDone instance itself.
typedef std::function<void(void*)> HeterRpcCallbackFunc;
// brpc completion closure bundling the controller, request/response
// messages, and a list of promises used to unblock synchronous callers.
// Run() is invoked by brpc when the rpc finishes; it forwards to handler_.
// Note: Run() does not delete this — ownership stays with the caller.
class OnHeterRpcDone : public google::protobuf::Closure {
 public:
  explicit OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
  virtual ~OnHeterRpcDone() {}
  void Run() { handler_(this); }
  // Registers a promise that set_promise_value() will fulfill.
  void add_promise(std::shared_ptr<std::promise<int32_t>>& promise) {  // NOLINT
    _promises.push_back(promise);
  }
  // Fulfills every registered promise with `value` (0 = success).
  void set_promise_value(int value) {
    for (auto& promise : _promises) {
      promise->set_value(value);
    }
  }
  int CheckResponse() { return 0; }
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
  HeterRpcCallbackFunc handler_;
  MultiVariableMessage request;
  MultiVariableMessage response;
  PsResponseMessage ps_response;
  brpc::Controller cntl;
  // PsRequestMessage *request(size_t i) { return &_requests[i]; }
  // PsResponseMessage *response(size_t i) { return &_responses[i]; }
  // std::vector<PsRequestMessage> _requests;
  // std::vector<PsResponseMessage> _responses;
  // std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
// Client used by trainers and switch nodes to exchange variables with heter
// workers / peer switches over brpc. Obtained through the GetInstance /
// GetSwitchInstance singletons.
class HeterClient {
 public:
  virtual ~HeterClient() {}
  // Creates one brpc channel per endpoint in `node_list`, storing them in
  // peer_switch_channels_ or peer_worker_channels_ depending on `peer_role`.
  // `need_encrypt` enables ssl for switch channels.
  void InitClientChannels(bool need_encrypt,
                          const std::vector<std::string>& node_list,
                          int32_t peer_role) {
    brpc::ChannelOptions options;
    options.protocol = "baidu_std";
    options.connection_type = "single";
    options.timeout_ms = FLAGS_pserver_timeout_ms;
    std::vector<std::shared_ptr<brpc::Channel>>* client_channels = nullptr;
    if (peer_role == PEER_ROLE_IS_SWITCH) {
#ifdef PADDLE_WITH_ARM_BRPC
      if (need_encrypt) {
        options.mutable_ssl_options();
      }
      options.connection_type = "";
      VLOG(4) << "ssl enabled in arm";
#else
      options.ssl_options.enable = need_encrypt;
#endif
      client_channels = &peer_switch_channels_;
    } else if (peer_role == PEER_ROLE_IS_WORKER) {
      client_channels = &peer_worker_channels_;
    } else {
      LOG(ERROR) << "init switch client failed, peer_role not valid";
      // Bail out: the original fell through and dereferenced the null
      // `client_channels` pointer below.
      return;
    }
    (*client_channels).resize(node_list.size());
    for (size_t i = 0; i < node_list.size(); ++i) {
      (*client_channels)[i].reset(new brpc::Channel());
      if ((*client_channels)[i]->Init(node_list[i].c_str(), "", &options) !=
          0) {
        VLOG(0) << "client channel init failed! try again";
        // Retry with the numeric (int-ip) form of the endpoint.
        auto ip_port = paddle::string::Split(node_list[i], ':');
        std::string ip = ip_port[0];
        int port = std::stoi(ip_port[1]);
        std::string int_ip_port = GetIntTypeEndpoint(ip, port);
        if ((*client_channels)[i]->Init(int_ip_port.c_str(), "", &options) !=
            0) {
          LOG(ERROR) << "client channel init failed! peer ip_port = "
                     << int_ip_port;
        }
      }
    }
    VLOG(4) << "InitClientChannels success";
  }
  // Connects to the next-stage and previous-stage heter workers.
  void CreateClient2XpuConnection();
  // Serializes variables from `scope` and sends them asynchronously; `mode`
  // selects forward/backward/switch routing.
  void SendAndRecvAsync(const platform::DeviceContext& ctx,
                        const framework::Scope& scope,
                        const std::string& message_name,
                        const std::vector<std::string>& send_var_name,
                        const std::vector<std::string>& recv_var_name,
                        const std::string& mode = "forward");
  // Sends a raw byte buffer tagged with variable names/lengths to the switch.
  int Send(int group_id,
           const std::vector<std::string>& var_names,
           const std::vector<int64_t>& vars_len,
           void* data_ptr,
           int64_t data_size);
  // Serializes named variables from `scope` and sends them to the switch.
  int Send(const platform::DeviceContext& ctx,
           const framework::Scope& scope,
           const std::string& message_name,
           const std::vector<std::string>& send_var_names);
  // Receives raw bytes for `var_names` from the switch into `data_ptr`.
  int Recv(int group_id,
           const std::vector<std::string>& var_names,
           void* data_ptr,
           int64_t data_size);
  // Receives named variables from the switch into `recv_scope`.
  int Recv(const platform::DeviceContext& ctx,
           framework::Scope& recv_scope,  // NOLINT
           const std::string& message_name,
           const std::vector<std::string>& recv_var_names);
  // HeterClient singleton; first call wires endpoints and connections.
  static std::shared_ptr<HeterClient> GetInstance(
      const std::vector<std::string>& endpoints,
      const std::vector<std::string>& previous_endpoints,
      const int& trainer_id) {
    if (nullptr == s_instance_) {
      s_instance_.reset(new HeterClient());
      s_instance_->SetXpuList(endpoints);
      s_instance_->SetPreviousXpuList(previous_endpoints);
      s_instance_->SetTrainerID(trainer_id);
      s_instance_->CreateClient2XpuConnection();
    }
    return s_instance_;
  }
  // switch client singleton (creation guarded by mtx_)
  static std::shared_ptr<HeterClient> GetSwitchInstance(
      const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
    std::unique_lock<std::mutex> lock(mtx_);
    if (peer_endpoints.empty()) {
      VLOG(4) << "init switch client failed, null peer_endpoints";
    } else {
      // Only read peer_endpoints[0] when it exists; the original logged it
      // unconditionally, indexing into an empty vector.
      VLOG(4) << "peer role is: " << peer_role
              << ", addr is: " << peer_endpoints[0];
    }
    if (switch_s_instance_ == nullptr) {
      switch_s_instance_.reset(new HeterClient());
      switch_s_instance_->SetPeerSwitchList(peer_endpoints);
      switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role);
    }
    return switch_s_instance_;
  }
  void SetPeerSwitchList(const std::vector<std::string>& peer_endpoints) {
    peer_switch_list_ = peer_endpoints;
  }
  void SetPeerWorkerList(const std::vector<std::string>& worker_endpoints) {
    peer_worker_list_ = worker_endpoints;
  }
  // Stops all connected heter workers and waits for completion.
  void Stop();
  // Broadcasts `cmd_id` with `params` to every xpu channel.
  std::future<int32_t> SendCmd(uint32_t table_id,
                               int cmd_id,
                               const std::vector<std::string>& params);
  std::future<int32_t> StartProfiler();
  std::future<int32_t> StopProfiler();
  std::future<int32_t> StopHeterWorker();
  std::vector<std::string>& GetXpuList() { return xpu_list_; }
  void SetXpuList(const std::vector<std::string>& xpu_list) {
    xpu_list_ = xpu_list;
  }
  void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
    previous_xpu_list_ = xpu_list;
  }
  void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }

 public:
  std::vector<std::string> send_switch_list_;
  std::vector<std::string> recv_switch_list_;
  std::vector<std::string> peer_switch_list_;
  std::vector<std::string> peer_worker_list_;
  std::vector<std::shared_ptr<brpc::Channel>> send_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> recv_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_switch_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> peer_worker_channels_;

 private:
  HeterClient() {}
  HeterClient& operator=(const HeterClient&);
  HeterClient(const HeterClient&);
  static std::shared_ptr<HeterClient> s_instance_;
  static std::mutex mtx_;  // guards switch_s_instance_ creation
  static std::shared_ptr<HeterClient> switch_s_instance_;
  // Channels to the next-stage / previous-stage heter workers.
  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
  std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
  // DISABLE_COPY_AND_ASSIGN(HeterClient);
  std::vector<std::string> xpu_list_;
  std::vector<std::string> previous_xpu_list_;
  int trainer_id_;
};
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/heter_server.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace distributed {
// DEFINE_string(cert_path, "./cert.pem", "cert.pem path");
// DEFINE_string(key_path, "./key.pem", "key.pem path");
// Storage for the HeterServer singleton declared in heter_server.h.
std::shared_ptr<HeterServer> HeterServer::s_instance_ = nullptr;
std::mutex HeterServer::mtx_;
// Forwards a message-name -> handler registration to the rpc service.
void HeterServer::RegisterServiceHandler(std::string message_name,
                                         HeterServiceHandler func) {
  service_.RegisterServiceHandler(message_name, func);
}
// Starts the main heter rpc endpoint on endpoint_, marks the server ready,
// then blocks on cv_ until Stop flips stoped_. Optionally enables ssl with
// the baked-in /cert.pem and /key.pem certificate pair.
void HeterServer::StartHeterService(bool neeed_encrypt) {
  server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
#ifdef PADDLE_WITH_ARM_BRPC
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
#else
    options.ssl_options.default_cert.certificate = "/cert.pem";
    options.ssl_options.default_cert.private_key = "/key.pem";
#endif
  }
  if (server_.Start(endpoint_.c_str(), &options) != 0) {
    VLOG(0) << "HeterServer start fail. Try again.";
    auto ip_port = paddle::string::Split(endpoint_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // Retry with the numeric-ip endpoint. (The original retried with the
    // same `endpoint_`, leaving `int_ip_port` computed but unused even
    // though the error message reports it.)
    if (server_.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "HeterServer start failed, ip_port= " << int_ip_port;
    }
  } else {
    VLOG(0) << "heter server start success! listen on " << endpoint_;
  }
  // Publish readiness to WaitServerReady() waiters.
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Park this thread until a stop is requested.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
// Starts the inter-switch rpc endpoint on endpoint_inter_, marks the server
// ready, then blocks on cv_ until Stop flips stoped_. Mirrors
// StartHeterService but uses the secondary server/endpoint pair.
void HeterServer::StartHeterInterService(bool neeed_encrypt) {
  server_inter_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
  brpc::ServerOptions options;
  if (neeed_encrypt) {
#ifdef PADDLE_WITH_ARM_BRPC
    options.mutable_ssl_options()->default_cert.certificate = "/cert.pem";
    options.mutable_ssl_options()->default_cert.private_key = "/key.pem";
#else
    options.ssl_options.default_cert.certificate = "/cert.pem";
    options.ssl_options.default_cert.private_key = "/key.pem";
#endif
  }
  if (server_inter_.Start(endpoint_inter_.c_str(), &options) != 0) {
    VLOG(4) << "switch inter server start fail. Try again.";
    auto ip_port = paddle::string::Split(endpoint_inter_, ':');
    std::string ip = ip_port[0];
    int port = std::stoi(ip_port[1]);
    std::string int_ip_port = GetIntTypeEndpoint(ip, port);
    // Retry with the numeric-ip endpoint. (The original retried with the
    // same `endpoint_inter_`, leaving `int_ip_port` computed but unused
    // even though the error message reports it.)
    if (server_inter_.Start(int_ip_port.c_str(), &options) != 0) {
      LOG(ERROR) << "switch inter server start failed, ip_port= "
                 << int_ip_port;
    }
  } else {
    VLOG(4) << "switch inter server server start success! listen on "
            << endpoint_inter_;
  }
  // Publish readiness to WaitServerReady() waiters.
  {
    std::lock_guard<std::mutex> lock(this->mutex_ready_);
    stoped_ = false;
    ready_ = 1;
  }
  condition_ready_.notify_all();
  VLOG(4) << "stopped: " << stoped_ << ", ready_: " << ready_;
  // Park this thread until a stop is requested.
  std::unique_lock<std::mutex> running_lock(mutex_);
  cv_.wait(running_lock, [&] {
    VLOG(4) << "Heter Server is Stop? " << stoped_;
    return stoped_;
  });
  VLOG(4) << "start service done";
}
// Forwards the fan-in (number of upstream CPU workers that must report stop
// before the service exits) to the underlying HeterService.
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }
// Blocks until StartHeterService/StartHeterInterService has marked the
// server ready (ready_ == 1) under mutex_ready_.
void HeterServer::WaitServerReady() {
  std::unique_lock<std::mutex> ready_lock(this->mutex_ready_);
  while (this->ready_ != 1) {
    condition_ready_.wait(ready_lock);
  }
}
// Copies the raw bytes of every variable carried in `request`'s attachment
// into the in-memory shard of `group_id`, then flags each variable as
// produced so QueryInSwitchWithShard can hand it out.
// Returns 0 on success, -1 when group_id is out of range.
int SendAndRecvVariableHandler::SaveInSwitchWithShard(
    const MultiVarMsg* request,
    PsResponseMessage* response,
    brpc::Controller* cntl) {
  VLOG(4) << "entering SaveInSwitchWithShard";
  int32_t group_id = request->group_id();
  if (group_id >= FLAGS_heter_world_size) {
    // _local_shards holds exactly FLAGS_heter_world_size shards; indexing
    // past the end would be out-of-bounds, so reject the request instead of
    // continuing (the original only logged and fell through).
    LOG(ERROR) << "group id exceed maxmium";
    return -1;
  }
  auto& local_shard = _local_shards[group_id];
  auto& request_io_buffer = cntl->request_attachment();
  butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
  for (int idx = 0; idx < request->send_var_names_size(); idx++) {
    const auto& var_name = request->send_var_names(idx);
    const auto& var_size = request->vars_len(idx);
    // Block until any previous value for this var has been consumed, so we
    // never overwrite bytes a reader has not picked up yet.
    WaitForVarsConsumed(group_id, var_name);
    std::unique_lock<std::mutex> lk(scope_mutex_);
    auto& value = local_shard[var_name];
    value.resize(var_size);
    io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
                                   var_size);
    // Publish: mark the variable as produced for this group.
    vars_ready_flag[group_id][var_name] = 1;
    VLOG(4) << "saved var_name: " << var_name << " is saved ready!";
  }
  VLOG(4) << "SaveInSwitchWithShard success";
  return 0;
}
// Streams the previously saved bytes of each requested variable from the
// group's shard into the response attachment, then marks the variables as
// consumed so producers may overwrite them. Always returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithShard(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithShard";
  int32_t group_id = request->group_id();
  // NOTE(review): unlike SaveInSwitchWithShard there is no range check on
  // group_id before indexing _local_shards — confirm callers guarantee it.
  VLOG(4) << "group id: " << group_id;
  auto& local_shard = _local_shards[group_id];
  auto& response_io_buffer = cntl->response_attachment();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto msg_name = request->message_name();
  response->set_message_name(msg_name);
  for (auto& req_var_name : req_var_names) {
    VLOG(4) << "req var name: " << req_var_name;
    response->add_send_var_names(req_var_name);
    // Block until SaveInSwitchWithShard has produced this variable; after
    // the wait the key is assumed to be present (find() result unchecked).
    WaitForVarsProduced(group_id, req_var_name);
    std::unique_lock<std::mutex> lk(scope_mutex_);
    auto itr = local_shard.find(req_var_name);
    auto& value = itr.value();
    response_io_buffer.append(value.data(), value.size());
    value.resize(0);  // clear the buffered bytes
    // Mark consumed so the producer side may write the next value.
    vars_ready_flag[group_id][req_var_name] = 0;
    VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
  }
  VLOG(4) << "heter server QueryInSwitchWithShard done";
  return 0;
}
// Deserializes the variables carried in `request` into the handler's switch
// scope (local_scope_ptr), then flags them as produced (group 0) for
// QueryInSwitchWithScope. Always returns 0.
int SendAndRecvVariableHandler::SaveInSwitchWithScope(
    const MultiVarMsg* request,
    PsResponseMessage* response,
    brpc::Controller* cntl) {
  VLOG(4) << "entering SaveInSwitchWithScope";
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  auto message_name = request->message_name();
  VLOG(4) << "message_name in heter server: " << message_name;
  auto send_var_nums = request->send_var_names_size();
  std::vector<std::string> send_var_names(send_var_nums);
  for (int idx = 0; idx < send_var_nums; idx++) {
    send_var_names[idx] = request->var_messages(idx).varname();
  }
  // Wait BEFORE taking scope_mutex_: WaitForVarsConsumed acquires
  // scope_mutex_ internally, so calling it with the lock held (as the
  // original code did) self-deadlocks on the non-recursive mutex. This also
  // matches the lock ordering used by SaveInSwitchWithShard.
  for (const auto& var_name : send_var_names) {
    WaitForVarsConsumed(0, var_name);
  }
  std::unique_lock<std::mutex> lk(scope_mutex_);
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    LOG(ERROR) << "local_scope_ptr is null in SaveInSwitchWithScope";
  }
  for (const auto& var_name : send_var_names) {
    // Missing vars are only logged; Deserialize below creates them as needed.
    auto* var_exist_ptr = local_scope->FindVar(var_name);
    if (!var_exist_ptr) {
      VLOG(4) << "not find var: " << var_name << " in local_scope";
    }
  }
  auto& request_io_buffer = cntl->request_attachment();
  distributed::DeserializeFromMultiVarMsgAndIOBuf(
      *request, &request_io_buffer, cpu_dev_ctx, local_scope);
  lk.unlock();
  // Publish: mark every variable as produced for consumers.
  for (const auto& var_name : send_var_names) {
    std::unique_lock<std::mutex> flag_lk(scope_mutex_);
    vars_ready_flag[0][var_name] = 1;
  }
  VLOG(4) << "SaveInSwitchWithScope success";
  return 0;
}
// Serializes each requested variable out of the handler's switch scope into
// the response (message + attachment), then marks the variables consumed
// (group 0). Counterpart of SaveInSwitchWithScope. Always returns 0.
int SendAndRecvVariableHandler::QueryInSwitchWithScope(
    const MultiVarMsg* request, MultiVarMsg* response, brpc::Controller* cntl) {
  VLOG(4) << "entering QueryInSwitchWithScope";
  auto local_scope = local_scope_ptr.get();
  if (!local_scope) {
    // NOTE(review): only logged; a null scope would crash at FindVar below —
    // confirm local_scope_ptr is always set (HeterService ctor does so).
    LOG(INFO) << "local_scope is null";
  }
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  platform::CPUPlace cpu_place;
  auto& cpu_dev_ctx = *pool.Get(cpu_place);
  // get req message_name & req_var_names
  auto msg_name = request->message_name();
  auto req_var_nums = request->recv_var_names_size();
  std::vector<std::string> req_var_names(req_var_nums);
  for (int var_idx = 0; var_idx < req_var_nums; ++var_idx) {
    req_var_names[var_idx] = request->recv_var_names(var_idx);
  }
  auto& response_io_buffer = cntl->response_attachment();
  // 1. fill message_name(string)
  response->set_message_name(msg_name);
  // 2. fill var_names(string)
  for (auto& req_var_name : req_var_names) {
    response->add_send_var_names(req_var_name);
  }
  // 3. fill var_messages(VarMessage)
  for (auto& req_var_name : req_var_names) {
    // Block until SaveInSwitchWithScope has produced the variable.
    WaitForVarsProduced(0, req_var_name);
    auto* send_var_msg = response->add_var_messages();
    send_var_msg->set_varname(req_var_name);
    framework::Variable* var_ptr;
    var_ptr = local_scope->FindVar(req_var_name);
    if (!var_ptr) {
      // NOTE(review): null var_ptr is only logged, then dereferenced by
      // IsType below — confirm the wait guarantees presence.
      LOG(INFO) << "local_scope not find var: " << req_var_name;
    }
    butil::IOBuf temp_iobuf;
    if (var_ptr->IsType<framework::LoDTensor>()) {
      SerializeLodTensor(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    } else if (var_ptr->IsType<phi::SelectedRows>()) {
      SerializeSelectedRows(var_ptr, cpu_dev_ctx, send_var_msg, &temp_iobuf);
    }
    response_io_buffer.append(temp_iobuf);
  }
  // Mark all requested variables consumed so producers may overwrite them.
  for (auto& req_var_name : req_var_names) {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    vars_ready_flag[0][req_var_name] = 0;
  }
  VLOG(4) << "heter server QueryInSwitchWithScope done";
  return 0;
}
} // end namespace distributed
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);
namespace paddle {
namespace distributed {
// Short aliases for the protobuf message types used throughout this header.
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
// Handler signature for generic Ps commands (stop server / profiler control).
using serviceHandler =
    std::function<int32_t(const PsRequestMessage& request,
                          PsResponseMessage& response,  // NOLINT
                          brpc::Controller* cntl)>;
// Handler signature for variable send/recv messages dispatched by name.
using HeterServiceHandler =
    std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
// Generic RPC-completion callback; the argument is the finished closure.
using HeterRpcCallbackFunc = std::function<void(void*)>;
// Abstract base for message handlers: carries the (non-owning) scope and
// device context a concrete handler operates on.
class ServiceHandlerBase {
 public:
  ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}
  virtual ~ServiceHandlerBase() {}
  // Both setters store raw, non-owning pointers; the caller retains ownership
  // and must keep the objects alive for the handler's lifetime.
  void SetScope(const framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  // Processes one request/response pair; returns 0 on success.
  virtual int Handle(const MultiVarMsg* request,
                     MultiVarMsg* response,
                     brpc::Controller* cntl) = 0;

 protected:
  const platform::DeviceContext* dev_ctx_;
  const framework::Scope* scope_;
};
// Minibatch index -> minibatch scope, shared with HeterPipelineTrainer.
using SharedMiniScope =
    std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
// Minibatch index -> list of per-microbatch scopes.
using SharedMicroScope = std::shared_ptr<std::unordered_map<
    int,
    std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
// Minibatch index -> queue of (message_name, microbatch_index) work items.
using SharedTaskQueue = std::shared_ptr<
    std::unordered_map<int,
                       std::shared_ptr<::paddle::framework::BlockingQueue<
                           std::pair<std::string, int>>>>>;
// Resizable byte buffer: the value type stored in a switch-side
// SparseTableShard. Thin wrapper over std::vector<char>.
class ValueInSwitch {
 public:
  ValueInSwitch() {}
  ~ValueInSwitch() {}
  // Raw pointer to the underlying contiguous storage.
  char* data() { return _data.data(); }
  // Number of valid bytes currently held.
  size_t size() { return _data.size(); }
  // Adjusts the valid-byte count (grows with zero bytes, or truncates).
  void resize(size_t new_size) { _data.resize(new_size); }
  // Returns surplus capacity to the allocator.
  void shrink_to_fit() { _data.shrink_to_fit(); }

 private:
  std::vector<char> _data;
};
// Handler that moves variables between trainers and switch nodes. Two modes
// coexist:
//  * pipeline-trainer mode (Handle): deserializes incoming variables into
//    per-microbatch scopes and notifies the trainer via a blocking queue;
//  * switch mode (Save*/Query*): buffers raw variable bytes in _local_shards
//    (one shard per group), synchronized by vars_ready_flag produce/consume
//    flags under scope_mutex_.
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
 public:
  SendAndRecvVariableHandler() {
    this->num_microbatch_ = 0;
    this->num_minibatch_ = 0;
    // One shard per group; FLAGS_heter_world_size bounds valid group ids.
    _local_shards.reset(new shard_type[FLAGS_heter_world_size]);
  }
  virtual ~SendAndRecvVariableHandler() {}
  // Shares the trainer's minibatch-scope map with this handler.
  void SetMiniScopes(SharedMiniScope mini_scopes) {
    mini_scopes_ = mini_scopes;
    num_minibatch_ = mini_scopes_->size();
  }
  // Shares the per-minibatch microbatch scopes. The microbatch count is read
  // from the first entry — assumes every minibatch has the same number of
  // microbatch scopes (TODO confirm).
  void SetMicroScopes(SharedMicroScope micro_scopes) {
    micro_scopes_ = micro_scopes;
    for (auto& scope_pair : (*micro_scopes_)) {
      // auto mini_idx = scope_pair.first;
      auto& micro_scopes = scope_pair.second;
      num_microbatch_ = micro_scopes->size();
      break;
    }
  }
  // Number of per-minibatch task queues currently registered.
  int GetThreadNum() {
    std::unique_lock<std::mutex> lk(scope_mutex_);
    return (*task_queue_).size();
  }
  int SaveInSwitchWithScope(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  // Busy-waits until the previous value of `var_name` in `group_id` has been
  // consumed (flag == 0). NOTE(review): spins without backoff and the timeout
  // below is commented out, so this can spin indefinitely — confirm intended.
  // Callers must NOT hold scope_mutex_ (it is locked on each iteration).
  void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 0) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not consumed exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  // Busy-waits until `var_name` in `group_id` has been produced (flag == 1).
  // Same caveats as WaitForVarsConsumed (disabled timeout, locks
  // scope_mutex_ internally).
  void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
    // timeline_.Start();
    while (true) {
      {
        std::lock_guard<std::mutex> lock(scope_mutex_);
        if (vars_ready_flag[group_id][var_name] == 1) {
          break;
        }
      }
      /*
      timeline_.Pause();
      if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
        VLOG(0) << "vars not produced exceed 10 miniutes";
        break;
      }
      */
    }
    return;
  }
  int SaveInSwitchWithShard(const MultiVarMsg* request,
                            PsResponseMessage* response,
                            brpc::Controller* cntl);
  int QueryInSwitchWithShard(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);
  int QueryInSwitchWithScope(const MultiVarMsg* request,
                             MultiVarMsg* response,
                             brpc::Controller* cntl);
  void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
  // Pipeline-trainer entry point: reads the "microbatch_id" variable from the
  // request, routes the payload into the matching micro scope (creating mini/
  // micro scopes and a task queue on first sight of a minibatch), enqueues a
  // (message_name, microbatch_index) task, and serializes the requested
  // reply variables back to the caller.
  int Handle(const MultiVarMsg* request,
             MultiVarMsg* response,
             brpc::Controller* cntl) override {
    LOG(INFO) << "entered Handle";
    platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle",
                                       platform::TracerEventType::Communication,
                                       1);
    FLAGS_eager_delete_tensor_gb = -1;
    // get microID from request
    // deserialize variable to micro scope
    // Push to heter worker's task_queue
    std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
        new paddle::framework::Scope());
    auto& local_scope = *(local_scope_ptr.get());
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::CPUPlace cpu_place;
    auto& cpu_dev_ctx = *pool.Get(cpu_place);
    auto message_name = request->message_name();
    auto& request_io_buffer = cntl->request_attachment();
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, cpu_dev_ctx, &local_scope);
    auto* var = local_scope.FindVar("microbatch_id");
    PADDLE_ENFORCE_NE(var,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "Not find variable microbatch_id in scope."));
    auto* tensor = var->GetMutable<framework::LoDTensor>();
    auto data = reinterpret_cast<const float*>(tensor->data());
    auto micro_id = static_cast<int>(data[0]);
    VLOG(4) << "micro_id in heter server: " << micro_id;
    // micro_id encodes minibatch * 10 + microbatch — assumes at most 10
    // microbatches per minibatch (TODO confirm).
    int minibatch_index = micro_id / 10;
    int microbatch_index = micro_id % 10;
    // check minibatch_index is in mini_scopes_
    std::unique_lock<std::mutex> lk(scope_mutex_);
    if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
      lk.unlock();
      PADDLE_ENFORCE_EQ(
          (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
          1,
          platform::errors::InvalidArgument(
              "minibatch index should in current trainer"));
    } else {
      // create mini scope & micro scopes
      auto* minibatch_scope = &(scope_->NewScope());
      (*mini_scopes_)[minibatch_index] = minibatch_scope;
      (*micro_scopes_)[minibatch_index].reset(
          new std::vector<paddle::framework::Scope*>{});
      for (int i = 0; i < num_microbatch_; i++) {
        auto* micro_scope = &(minibatch_scope->NewScope());
        (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
      }
      (*task_queue_)[minibatch_index].reset(
          new ::paddle::framework::BlockingQueue<
              std::pair<std::string, int>>());
      lk.unlock();
    }
    auto* micro_scope =
        (*((*micro_scopes_)[minibatch_index]))[microbatch_index];
    // NOTE(review): the request attachment is deserialized a second time
    // here, into the micro scope with the handler's device context — confirm
    // the IOBuf supports being consumed twice in this path.
    distributed::DeserializeFromMultiVarMsgAndIOBuf(
        *request, &request_io_buffer, *dev_ctx_, micro_scope);
    // blocking queue handles multi thread
    VLOG(4) << "Handle in HeterServer: " << message_name << ", "
            << microbatch_index;
    VLOG(4) << "task_queue_ size: " << task_queue_->size();
    (*task_queue_)[minibatch_index]->Push(
        std::make_pair(message_name, microbatch_index));
    // Echo back the variables the caller asked for, read from the temp scope.
    auto response_var_nums = request->recv_var_names_size();
    std::vector<std::string> response_var_names(response_var_nums),
        empty_var_names{};
    for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
      response_var_names[var_idx] = request->recv_var_names(var_idx);
    }
    auto& response_io_buffer = cntl->response_attachment();
    distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
                                                response_var_names,
                                                empty_var_names,
                                                *dev_ctx_,
                                                &local_scope,
                                                response,
                                                &response_io_buffer);
    VLOG(4) << "Handle over";
    return 0;
  }

 public:
  using shard_type = SparseTableShard<std::string, ValueInSwitch>;
  // Scope used by the *WithScope switch paths (set by the HeterService ctor).
  std::shared_ptr<paddle::framework::Scope> local_scope_ptr;  // for switch
  // group_id -> var_name -> produced(1)/consumed(0), guarded by scope_mutex_.
  std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
      vars_ready_flag;
  // One byte-buffer shard per group id; sized FLAGS_heter_world_size.
  std::unique_ptr<shard_type[]> _local_shards;
  platform::Timer timeline_;

 private:
  // share with HeterPipelineTrainer
  SharedMiniScope mini_scopes_{nullptr};
  SharedMicroScope micro_scopes_{nullptr};
  int num_microbatch_;
  int num_minibatch_;
  std::mutex scope_mutex_;
  bool is_first_stage_ = false;
  bool is_last_stage_ = false;
  SharedTaskQueue task_queue_;
};
// brpc service implementing the heter PS protocol. Generic Ps commands
// (stop / profiler) are dispatched through _service_handler_map; variable
// send/recv messages are dispatched by message name through handler_map_.
// On switch nodes it additionally proxies traffic between peers
// (SendToSwitch / SendS2S / SendToWorker / RecvFromSwitch).
class HeterService : public PsService {
 public:
  HeterService() {
    _service_handler_map[PS_STOP_SERVER] =
        std::bind(&HeterService::stop_heter_worker,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_START_PROFILER] =
        std::bind(&HeterService::start_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    _service_handler_map[PS_STOP_PROFILER] =
        std::bind(&HeterService::stop_profiler,
                  this,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);
    // Scope backing the switch-side Save/Query *WithScope paths.
    service_handler_.local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }
  virtual ~HeterService() {}
  // Generic Ps command entry: routes request->cmd_id() to the matching
  // handler and reflects a non-zero handler result into the response.
  virtual void service(::google::protobuf::RpcController* controller,
                       const PsRequestMessage* request,
                       PsResponseMessage* response,
                       ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    response->set_err_code(0);
    response->set_err_msg("");
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    auto itr = _service_handler_map.find(request->cmd_id());
    if (itr == _service_handler_map.end()) {
      // NOTE(review): err_msg is built but never written into `response`, so
      // an unknown cmd_id is reported as success (err_code 0) — confirm.
      std::string err_msg(
          "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
      err_msg.append(std::to_string(request->cmd_id()));
      return;
    }
    serviceHandler handler = itr->second;
    int service_ret = handler(*request, *response, cntl);
    VLOG(4) << "handler in service ret: " << service_ret;
    if (service_ret != 0) {
      response->set_err_code(service_ret);
      response->set_err_msg("server internal error");
    }
  }
  // Dispatches a variable message to the handler registered under its
  // message_name; aborts (enforce) when the name is unregistered.
  virtual void SendAndRecvVariable(
      ::google::protobuf::RpcController* controller,
      const MultiVarMsg* request,
      MultiVarMsg* response,
      ::google::protobuf::Closure* done) {
    // This object helps you to call done->Run() in RAII style. If you need
    // to process the request asynchronously, pass done_guard.release().
    brpc::ClosureGuard done_guard(done);
    std::string message_name = request->message_name();
    VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
    auto itr = handler_map_.find(message_name);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
    PADDLE_ENFORCE_NE(
        itr,
        handler_map_.end(),
        platform::errors::InvalidArgument(
            "HeterService::SendAndRecvVariable Get illegal message_name: %s "
            "which is not in HeterService::handler_map_",
            message_name));
    itr->second(request, response, cntl);
    // We don't want to call done->Run() here, release the guard.
    // done_guard.release();
  }
  // Switch-side pull: hands previously stored variables back to a worker.
  virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
                              const MultiVarMsg* request,
                              MultiVarMsg* response,
                              ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.QueryInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "QueryInSwitchWithScope failed!";
    }
    // response->set_message_name(message_name);
  }
  // Forwards a worker's payload to the peer switch via SendS2S.
  virtual void SendToSwitch(::google::protobuf::RpcController* controller,
                            const MultiVarMsg* request,
                            PsResponseMessage* response,
                            ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendToSwitch";
    brpc::ClosureGuard done_guard(done);
    std::shared_ptr<HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
    if (switch_client_ptr_->peer_switch_channels_.empty()) {
      LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
    }
    brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // proxy: create a new OnHeterRpcDone object (or reset one inside the
    // OnHeterRpcDone class)
    OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
      auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
      int ret = closure->CheckResponse();
      closure->set_promise_value(ret);
      if (closure->cntl.Failed()) {
        PADDLE_ENFORCE_NE(
            closure->cntl.Failed(),
            true,
            platform::errors::Unimplemented(
                "HeterClient::SendS2S meets brpc error, error message is %s",
                closure->cntl.ErrorText()));
      }
    });
    auto& std_cntl = closure2->cntl;
    std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
    std_cntl.request_attachment().append(cntl->request_attachment().movable());
    auto promise = std::make_shared<std::promise<int32_t>>();
    closure2->add_promise(promise);
    std::future<int> fut = promise->get_future();
    // brpc::Controller std_cntl;
    // std_cntl.request_attachment().append(cntl->request_attachment().movable());
    PsService_Stub stub(channel);
    stub.SendS2S(&std_cntl, request, response, closure2);
    // NOTE(review): the response attachment is appended BEFORE fut.wait();
    // if SendS2S completes asynchronously the attachment may still be empty
    // at this point — confirm intended ordering.
    cntl->response_attachment().append(
        std_cntl.response_attachment().movable());
    fut.wait();
    VLOG(4) << "SendToSwitch done";
    delete closure2;
  }
  // Switch-side store: saves the peer's payload into the local shard.
  void SendS2S(::google::protobuf::RpcController* controller,
               const MultiVarMsg* request,
               PsResponseMessage* response,
               ::google::protobuf::Closure* done) {
    VLOG(4) << "entering SendS2S";
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    // int ret = service_handler_.SaveInSwitchWithScope(request, response,
    // cntl);
    int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
    // std::string message_name = request->message_name();
    // auto itr = handler_map_.find(message_name);
    // if (itr == handler_map_.end()) {
    //    LOG(ERROR) << "can not find func handler";
    //}
    // int ret = itr->second(request, response, cntl);
    if (ret != 0) {
      LOG(ERROR) << "SaveInSwitchWithScope failed";
    }
    std::string err_msg = "ok";
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(ret);
    VLOG(4) << "heter server SendS2S done";
  }
  // Relays a request to peer worker 0; the real reply is delivered through
  // the caller-supplied closure.
  void SendToWorker(::google::protobuf::RpcController* controller,
                    const MultiVarMsg* request,
                    PsResponseMessage* response,
                    ::google::protobuf::Closure* done) {
    brpc::ClosureGuard done_guard(done);
    brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
    VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
    std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
        HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
    VLOG(4) << "in switch client, peer worker 0: "
            << switch_client_ptr_->peer_worker_list_[0];
    brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();
    auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
    PsService_Stub stub(channel);
    stub.SendAndRecvVariable(controller, request, &closure->response, done);
    // fill response content
    std::string err_msg("pass to worker");
    response->set_err_msg(err_msg.c_str());
    response->set_err_code(0);
  }
  // Registers the handler invoked for messages named `message_name`.
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func) {
    handler_map_[message_name] = func;
  }
  void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
  void SetInterEndpoint(const std::string& end_point) {
    endpoint_inter_ = end_point;
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    peer_endpoints_ = peer_endpoints;
  }
  // Number of upstream workers that must report stop before exit.
  void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
  void ForceExit() {
    VLOG(3) << "heter service force exit";
    is_exit_ = true;
    return;
  }
  bool IsExit() { return is_exit_; }

 private:
  int32_t stop_profiler(const PsRequestMessage& request,
                        PsResponseMessage& response,  // NOLINT
                        brpc::Controller* cntl) {
    platform::DisableProfiler(
        platform::EventSortingKey::kDefault,
        string::Sprintf("heter_worker_%s_profile", endpoint_));
    return 0;
  }
  int32_t start_profiler(const PsRequestMessage& request,
                         PsResponseMessage& response,  // NOLINT
                         brpc::Controller* cntl) {
    platform::EnableProfiler(platform::ProfilerState::kAll);
    return 0;
  }
  // Records a stopping worker; flips is_exit_ once all fan_in_ workers have
  // reported. NOTE(review): fan_in_ is uninitialized until SetFanin is
  // called — confirm it is always set before PS_STOP_SERVER arrives.
  int32_t stop_heter_worker(const PsRequestMessage& request,
                            PsResponseMessage& response,  // NOLINT
                            brpc::Controller* cntl) {
    auto client_id = request.client_id();
    stop_cpu_worker_set_.insert(client_id);
    if (stop_cpu_worker_set_.size() == fan_in_) {
      is_exit_ = true;
    }
    return 0;
  }

 private:
  SendAndRecvVariableHandler service_handler_;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;
  std::unordered_map<int32_t, serviceHandler> _service_handler_map;
  std::unordered_map<std::string, HeterServiceHandler> handler_map_;
  std::unordered_set<int> stop_cpu_worker_set_;
  uint32_t fan_in_;
  bool is_exit_ = false;
};
// Process-wide singleton that owns the two brpc servers (main + inter) and
// the HeterService. StartHeterService / StartHeterInterService block their
// calling thread until Stop() is invoked from another thread; readiness is
// observable through WaitServerReady().
class HeterServer {
 public:
  HeterServer() : ready_(0) {}
  virtual ~HeterServer() {}
  // Signals the serving thread to exit and shuts down the main brpc server.
  // NOTE(review): server_inter_ is not stopped/joined here — confirm whether
  // the inter server is expected to die with the process.
  void Stop() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (stoped_ == true) return;
    if (!IsExit()) {
      service_.ForceExit();
    }
    stoped_ = true;
    cv_.notify_all();
    server_.Stop(1000);
    server_.Join();
  }
  bool IsStop() {
    std::unique_lock<std::mutex> lock(mutex_);
    return stoped_;
  }
  bool IsExit() { return service_.IsExit(); }
  void RegisterServiceHandler(std::string message_name,
                              HeterServiceHandler func);
  // Both Start* methods block the caller until Stop(); run them on a
  // dedicated thread.
  void StartHeterService(bool need_encrypt = false);
  void StartHeterInterService(bool need_encrypt = false);
  void SetEndPoint(const std::string& endpoint) {
    this->endpoint_ = endpoint;
    service_.SetEndpoint(endpoint);
  }
  // Resets the switch-mode scope; requires SetServiceHandler to have been
  // called first (request_handler_ is dereferenced unchecked).
  void SetLocalScope() {
    request_handler_->local_scope_ptr =
        std::make_shared<paddle::framework::Scope>();
  }
  void SetInterEndpoint(const std::string& endpoint) {
    this->endpoint_inter_ = endpoint;
    service_.SetInterEndpoint(endpoint);
  }
  void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
    this->peer_endpoints_ = peer_endpoints;
    service_.SetPeerEndPoints(peer_endpoints);
  }
  void SetFanin(const int& fan_in);
  void SetServiceHandler(
      std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
    request_handler_ = request_handler;
  }
  void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
    request_handler_->SetMiniScopes(mini_scopes);
  }
  void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
    request_handler_->SetMicroScopes(micro_scopes);
  }
  int GetThreadNum() { return request_handler_->GetThreadNum(); }
  void SetTaskQueue(SharedTaskQueue task_queue) {
    request_handler_->SetTaskQueue(task_queue);
  }
  // HeterWrapper singleton
  static std::shared_ptr<HeterServer> GetInstance() {
    std::unique_lock<std::mutex> lock(mtx_);
    if (s_instance_ == nullptr) {
      s_instance_.reset(new HeterServer());
    }
    return s_instance_;
  }
  void WaitServerReady();

 private:
  static std::shared_ptr<HeterServer> s_instance_;
  mutable std::mutex mutex_;
  static std::mutex mtx_;
  std::condition_variable cv_;
  std::condition_variable condition_ready_;
  bool stoped_ = true;
  std::string endpoint_;
  std::string endpoint_inter_;
  // for switch
  std::vector<std::string> peer_endpoints_;

 protected:
  brpc::Server server_;
  brpc::Server server_inter_;
  HeterService service_;
  std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
  DISABLE_COPY_AND_ASSIGN(HeterServer);
  std::mutex mutex_ready_;
  // 0 until a Start* method has brought the server up; guarded by
  // mutex_ready_ / condition_ready_.
  int ready_;
};
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// Register the built-in PSClient implementations so PSClientFactory::Create
// can instantiate them by class name from the config.
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
// Initializes the client from the combined PS configuration: mirrors the
// server-side table parameters into the worker config, builds one
// ValueAccessor per table, then delegates to the subclass Initialize().
// Returns the subclass Initialize() result.
int32_t PSClient::Configure(
    const PSParameter &config,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
    PSEnvironment &env,
    size_t client_id) {
  _env = &env;
  _config = config;
  _dense_pull_regions = regions;
  _client_id = client_id;
  // Worker tables mirror the server-side table configuration.
  _config.mutable_worker_param()
      ->mutable_downpour_worker_param()
      ->mutable_downpour_table_param()
      ->CopyFrom(_config.server_param()
                     .downpour_server_param()
                     .downpour_table_param());
  const auto &work_param = _config.worker_param().downpour_worker_param();
  for (int i = 0; i < work_param.downpour_table_param_size(); ++i) {
    // NOTE(review): CREATE_PSCORE_CLASS can yield null for an unregistered
    // accessor_class and is dereferenced unchecked — confirm registration is
    // guaranteed upstream.
    auto *accessor = CREATE_PSCORE_CLASS(
        ValueAccessor,
        work_param.downpour_table_param(i).accessor().accessor_class());
    accessor->Configure(work_param.downpour_table_param(i).accessor());
    accessor->Initialize();
    // Table id -> owning accessor.
    _table_accessors[work_param.downpour_table_param(i).table_id()].reset(
        accessor);
  }
  return Initialize();
}
// Factory: validates the service configuration and instantiates the
// registered PSClient subclass named by client_class.
// Returns nullptr when any required config section is missing or the class
// name is not registered. The caller owns the returned client.
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  if (!config.downpour_server_param().service_param().has_client_class()) {
    LOG(ERROR) << "miss client_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }
  const auto &service_param = config.downpour_server_param().service_param();
  PSClient *client =
      CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
  if (client == nullptr) {
    LOG(ERROR) << "client is not registered, server_name:"
               << service_param.client_class();
    return nullptr;
  }
  TableManager::Instance().Initialize();
  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
  return client;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
// Callback invoked when an asynchronous PS request completes; the argument
// is the closure carrying the call state. (`using` alias for consistency
// with the other aliases in this codebase.)
using PSClientCallBack = std::function<void(void *)>;
class PSClientClosure : public google::protobuf::Closure {
public:
explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
virtual ~PSClientClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) { // NOLINT
_promises.push_back(promise);
}
void add_timer(std::shared_ptr<CostTimer> &timer) { // NOLINT
_timers.push_back(timer);
}
protected:
PSClientCallBack _callback;
std::vector<std::shared_ptr<CostTimer>> _timers;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
// Abstract parameter-server client interface. Concrete implementations
// (brpc-based, local in-process, ...) provide the table pull/push transport;
// all bulk operations return futures that complete when the request is done.
class PSClient {
 public:
  PSClient() {}
  virtual ~PSClient() {}
  PSClient(PSClient &&) = delete;
  PSClient(const PSClient &) = delete;
  // One-time setup: stores config/regions/environment and the client id.
  // `final` so subclasses customize via Initialize() instead.
  virtual int32_t Configure(  // NOLINT
      const PSParameter &config,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>
          &regions,
      PSEnvironment &_env,
      size_t client_id) final;  // NOLINT
  virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
                                                int pserver_connect_timeout_ms,
                                                int max_retry) = 0;
  // Trigger data eviction (shrink) on a table.
  virtual std::future<int32_t> Shrink(uint32_t table_id,
                                      const std::string threshold) = 0;
  // Load data for all tables.
  virtual std::future<int32_t> Load(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Load data for one specific table.
  virtual std::future<int32_t> Load(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;
  // Save all tables; the value accessor may apply different save conditions
  // depending on `mode`.
  virtual std::future<int32_t> Save(const std::string &epoch,
                                    const std::string &mode) = 0;
  // Save one specific table; save conditions again depend on `mode`.
  virtual std::future<int32_t> Save(uint32_t table_id,
                                    const std::string &epoch,
                                    const std::string &mode) = 0;
  // Clear table data (all tables / one table).
  virtual std::future<int32_t> Clear() = 0;
  virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
  // Pull the dense part of the parameters and scatter it block-by-block into
  // the local parameter regions.
  // The regions buffer must not be reused before the future completes.
  // The client unpacks values per block and hands them to multiple senders;
  // each sender batches requests of the same block and accumulates fill
  // buffers; the server extracts the configured dimension of each block and
  // the replies are unpacked into the accumulated buffers.
  virtual std::future<int32_t> PullDense(Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;  // reserved
  // firstly push dense param for parameter server
  // this is necessary because dense weight initialized in trainer on cold
  // start
  virtual std::future<int32_t> PushDenseParam(const Region *regions,
                                              size_t region_num,
                                              size_t table_id) = 0;
  virtual std::future<int32_t> PushDense(const Region *regions,
                                         size_t region_num,
                                         size_t table_id) = 0;
  // Pull with `keys`, filling the results into `select_values`.
  // There are `num` keys and `num` values; each value occupies select_size
  // bytes. Neither buffer may be reused before the future completes.
  // Keys from multiple threads are merged, grouped and sent per server; when
  // replies arrive, the buffers are walked and values assigned.
  // `is_training` distinguishes training from prediction requests; the
  // server treats feature admission differently for each.
  virtual std::future<int32_t> PullSparse(float **select_values,
                                          size_t table_id,
                                          const uint64_t *keys,
                                          size_t num,
                                          bool is_training) = 0;
  // Optional: pull sparse parameter values; default is unimplemented (-1).
  virtual std::future<int32_t> PullSparseParam(float **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num,
                                               bool is_training) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  // Optional: pull by pointer (zero-copy); default is unimplemented (-1).
  virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
                                               size_t table_id,
                                               const uint64_t *keys,
                                               size_t num) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
  // Make sure every batched pending request is actually sent out.
  virtual std::future<int32_t> Flush() = 0;
  // Graceful server shutdown.
  virtual std::future<int32_t> StopServer() = 0;
  // server profiler
  virtual std::future<int32_t> StartProfiler() = 0;
  virtual std::future<int32_t> StopProfiler() = 0;
  virtual std::future<int32_t> Barrier(size_t table_id,
                                       uint32_t barrier_type) = 0;
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float> *values,
                                            std::vector<uint64_t> *keys,
                                            int pserver_idx) = 0;
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t *total_send_data,
                                              void *done) = 0;
  // recv table from server and save it in LodTensor
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string &path) = 0;
  virtual void FinalizeWorker() = 0;
  // Send a client-to-client message; default is unimplemented (-1).
  virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
                                                    int to_client_id,
                                                    const std::string &msg) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  // client2client message handling: std::function<int32_t(int, int, const
  // std::string&)> -> ret (msg_type, from_client_id, msg)
  typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
  // Register a handler for one client2client message type (overwrites any
  // previous handler for that type).
  virtual int RegisteClient2ClientMsgHandler(int msg_type,
                                             MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
  // Dispatch an incoming client2client message to its registered handler;
  // returns -1 for unknown message types.
  virtual int HandleClient2ClientMsg(int msg_type,
                                     int from_client_id,
                                     const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
      return -1;
    }
    return itr->second(msg_type, from_client_id, msg);
  }
  // Look up the value accessor configured for `table_id`; NULL if absent.
  virtual ValueAccessor *GetTableAccessor(size_t table_id) {
    auto itr = _table_accessors.find(table_id);
    if (itr == _table_accessors.end()) {
      return NULL;
    }
    return itr->second.get();
  }
  virtual size_t GetServerNums() = 0;
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float *total_send_data,
                                                    size_t total_send_data_size,
                                                    void *done) = 0;
  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      size_t num,
      void *done) = 0;
  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t *keys,
      const float **update_values,
      uint32_t num,
      void *done,
      int pserver_idx) = 0;
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t *keys,
                                               const float **update_values,
                                               size_t num,
                                               void *done) = 0;
  virtual std::future<int32_t> PushSparse(size_t table_id,
                                          const uint64_t *keys,
                                          const float **update_values,
                                          size_t num) = 0;
  // for save cache; default implementations are unimplemented (-1).
  virtual std::future<int32_t> CacheShuffle(
      uint32_t table_id,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> CacheShuffleMultiTable(
      std::vector<int> tables,
      const std::string &path,
      const std::string &mode,
      const std::string &cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> SaveCache(uint32_t table_id,
                                         const std::string &path,
                                         const std::string &mode) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
                                                 double &cache_threshold) {
    VLOG(0) << "Did not implement";
    std::promise<int32_t> promise;
    std::future<int> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }

 protected:
  virtual int32_t Initialize() = 0;
  size_t _client_id;
  PSParameter _config;
  std::map<uint64_t, std::vector<paddle::distributed::Region>>
      _dense_pull_regions;
  PSEnvironment *_env;
  std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
  std::unordered_map<int32_t, MsgHandlerFunc>
      _msg_handler_map;  // handles client2client messages
};
// Bundles one async request's payload with its table id, a cost timer and
// the promise through which completion is reported to the caller.
template <class T>
class AsyncRequestTask {
 public:
  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
  AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
      : _table_id(table_id),
        _timer(timer),
        _promise(std::make_shared<std::promise<int32_t>>()) {
    _data = std::move(data);  // takes ownership of the caller's payload
  }
  // NOTE: despite the copy-constructor-like signature this MOVES the payload
  // out of `data` and shares its timer/promise (hence the NOLINT).
  AsyncRequestTask(AsyncRequestTask &data)  // NOLINT
      : _table_id(data.table_id()),
        _timer(data.timer()),
        _promise(data.promise()) {
    _data = std::move(data.data());
  }
  ~AsyncRequestTask() {}
  inline T &data() { return _data; }
  inline size_t table_id() { return _table_id; }
  inline std::shared_ptr<CostTimer> &timer() { return _timer; }
  inline std::future<int32_t> get_future() { return _promise->get_future(); }
  inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }

 private:
  T _data;
  size_t _table_id;
  std::shared_ptr<CostTimer> _timer;
  std::shared_ptr<std::promise<int32_t>> _promise;
};
// Registry hook so concrete PSClient subclasses can be created by name.
REGISTER_PSCORE_REGISTERER(PSClient);
// Factory that instantiates the PSClient implementation selected by `config`.
class PSClientFactory {
 public:
  static PSClient *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace paddle {
namespace distributed {
int32_t PsLocalClient::Initialize() {
  // Build one in-process table per configured downpour table. Local mode is
  // single-shard, so every table is shard 0 of 1.
  const auto& downpour_param = _config.server_param().downpour_server_param();
  TableManager::Instance().Initialize();
  const int table_count = downpour_param.downpour_table_param_size();
  for (int idx = 0; idx < table_count; ++idx) {
    const auto& table_param = downpour_param.downpour_table_param(idx);
    auto* table = CREATE_PSCORE_CLASS(Table, table_param.table_class());
    table->SetShard(0, 1);
    table->Initialize(table_param, _config.fs_client_param());
    _table_map[table_param.table_id()].reset(table);
  }
  return 0;
}
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  // TODO: thresholded shrink is not implemented for the local client; no-op.
  return done();
}
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
                                           const std::string& mode) {
  // Load every registered table from `epoch` with the given mode; completes
  // synchronously in local mode.
  for (auto& it : _table_map) {
    Load(it.first, epoch, mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // Load a single table synchronously.
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
  return done();
}
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  // Save every registered table; delegates to the single-table overload.
  for (auto& it : _table_map) {
    Save(it.first, epoch, mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // Flush pending updates before persisting the table.
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
  return done();
}
::std::future<int32_t> PsLocalClient::Clear() {
  // TODO: clearing is not implemented for the local client; no-op.
  return done();
}
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
  // TODO: clearing is not implemented for the local client; no-op.
  return done();
}
::std::future<int32_t> PsLocalClient::Flush() {
  // Local pushes are applied synchronously, so there is nothing to flush.
  return done();
}
::std::future<int32_t> PsLocalClient::StopServer() {
  // No remote server process exists in local mode.
  return done();
}
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  // Pull the entire dense shard into a temporary buffer, then scatter it
  // byte-by-byte across the caller-provided regions.
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
  // Single local shard holds the whole dense dimension.
  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
  std::vector<float> region_buffer;
  region_buffer.resize(num_per_shard);
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.pull_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  table_ptr->Pull(table_context);
  // table_ptr->PullDense(region_buffer.data(), region_buffer.size());
  size_t region_idx = 0;       // index of the region currently being filled
  size_t region_data_idx = 0;  // write offset (bytes) inside that region
  size_t shard_data_size = num_per_shard;
  size_t shard_buffer_remain = shard_data_size * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain,
      region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));
  size_t index = 0;  // read offset (bytes) into region_buffer
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    if (region.size - region_data_idx >= shard_buffer_remain) {
      // All remaining shard bytes fit into the current region; done.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else if (region.size - region_data_idx == 0) {
      // Current region is full; advance to the next region.
      ++region_idx;
      region_data_idx = 0;
    } else {
      // Fill the rest of this region and continue with the next one.
      memcpy((void*)(region.data + region_data_idx),
             (uint8_t*)(void*)(region_buffer.data()) + index,
             region.size - region_data_idx);
      shard_buffer_remain -= (region.size - region_data_idx);
      index += (region.size - region_data_idx);
      ++region_idx;
      region_data_idx = 0;
    }
  }
  return done();
}
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.push_context.is_param = true;
table_context.num = region_buffer.size();
table_ptr->Push(table_context);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done();
}
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
    void* callback) {
  // Push a raw dense gradient buffer straight into the local table.
  // `callback` is an owned PSClientClosure that rpc-based clients would run
  // asynchronously; the local push is synchronous, so it is simply released.
  // Fix: dropped the stray personal debug tag ("wxx ") from the log message.
  VLOG(1) << "push_dense_raw_gradient";
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* table_ptr = GetTable(table_id);
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  table_ptr->Push(table_context);
  delete closure;  // deleting a null closure is a safe no-op
  return done();
}
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  // Concatenate all gradient regions into one dense buffer and push it.
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
  std::vector<float> region_buffer;
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    // Guard against caller regions overflowing the dense buffer.
    PADDLE_ENFORCE_LE(
        offset + data_num,
        data_size,
        platform::errors::PreconditionNotMet(
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  // table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
  return done();
}
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// //将key拆分到各shard请求,并记录原始对应value指针
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//将key拆分到各shard请求,并记录原始对应value指针
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = keys;
table_context.pull_context.ptr_values = select_values;
table_context.use_ptr = true;
table_context.num = num;
// table_ptr->PullSparsePtr(select_values, keys, num);
table_ptr->Pull(table_context);
return done();
}
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
return done();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace paddle {
namespace distributed {
class Table;
// In-process PSClient implementation: every table lives in this process, so
// each "async" API completes synchronously and returns an already-fulfilled
// future via done().
class PsLocalClient : public PSClient {
 public:
  PsLocalClient() {}
  virtual ~PsLocalClient() { _running = false; }
  // No peer connections exist in local mode; succeed immediately.
  virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
                                                int pslib_connect_timeout_ms,
                                                int max_retry) {
    return 0;
  }
  virtual ::std::future<int32_t> Shrink(uint32_t table_id,
                                        const std::string threshold) override;
  virtual ::std::future<int32_t> Load(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Load(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Save(uint32_t table_id,
                                      const std::string& epoch,
                                      const std::string& mode) override;
  virtual ::std::future<int32_t> Clear() override;
  virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
  virtual ::std::future<int32_t> StopServer() override;
  virtual void FinalizeWorker() override {}
  virtual ::std::future<int32_t> PullDense(Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDense(const Region* regions,
                                           size_t region_num,
                                           size_t table_id);
  virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
                                                size_t region_num,
                                                size_t table_id);
  // Value-copy sparse pull is not supported locally; reports success (0)
  // without filling select_values.
  virtual ::std::future<int32_t> PullSparse(float** select_values,
                                            size_t table_id,
                                            const uint64_t* keys,
                                            size_t num,
                                            bool is_training) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
                                               size_t table_id,
                                               const uint64_t* keys,
                                               size_t num);
  // No-op in local mode; reports success (0).
  virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual ::std::future<int32_t> PushSparse(size_t table_id,
                                            const uint64_t* keys,
                                            const float** update_values,
                                            size_t num);
  virtual ::std::future<int32_t> Flush();
  // server profiler — no-ops in local mode; all report success (0).
  virtual std::future<int32_t> StartProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  };
  virtual std::future<int32_t> StopProfiler() {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual std::future<int32_t> PullGeoParam(size_t table_id,
                                            std::vector<float>* values,
                                            std::vector<uint64_t>* keys,
                                            int pserver_idx) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual std::future<int32_t> PushGlobalStep(int table_id,
                                              int64_t* total_send_data,
                                              void* done) {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  // recv table from server and save it in LodTensor — nothing to receive
  // locally.
  virtual int32_t RecvAndSaveTable(const uint64_t table_id,
                                   const std::string& path) {
    return 0;
  }
  virtual ::std::future<int32_t> SendClient2ClientMsg(
      int msg_type, int to_client_id, const std::string& msg) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  // Single-process deployment behaves as one (implicit) server.
  virtual size_t GetServerNums() { return 1; }
  virtual std::future<int32_t> PushDenseRawGradient(int table_id,
                                                    float* total_send_data,
                                                    size_t total_send_data_size,
                                                    void* callback) override;
  virtual std::future<int32_t> PushSparseRawGradient(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      size_t num,
      void* callback) override;
  virtual std::future<int32_t> PushSparseRawGradientPartial(
      size_t table_id,
      const uint64_t* keys,
      const float** update_values,
      uint32_t num,
      void* done,
      int pserver_idx) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }
  virtual std::future<int32_t> PushSparseParam(size_t table_id,
                                               const uint64_t* keys,
                                               const float** update_values,
                                               size_t num,
                                               void* done) override {
    std::promise<int32_t> prom;
    std::future<int32_t> fut = prom.get_future();
    prom.set_value(0);
    return fut;
  }

 private:
  virtual int32_t Initialize() override;
  // Build a future already fulfilled with success (0).
  std::future<int32_t> done() {
    std::shared_ptr<std::promise<int32_t>> prom =
        std::make_shared<std::promise<int32_t>>();
    std::future<int32_t> fut = prom->get_future();
    prom->set_value(0);
    return fut;
  }
  // Dense dim per shard; the +1 mirrors the remote sharding formula.
  inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
                                   uint32_t shard_num) {
    return dense_dim_total / shard_num + 1;
  }
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
    return &_table_map;
  }
  // Table lookup by id; logs and returns NULL when not registered.
  inline Table* GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    LOG(ERROR) << "table not found " << table_id;
    return NULL;
  }
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
  bool _running = false;
  bool _flushing = false;

 private:
  float _mae = 0;
  float _mse = 0;
  uint16_t _push_times = 0;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace paddle {
namespace distributed {
// Local (in-process) PS server: every operation is a no-op because the
// tables are hosted directly by PsLocalClient in the same process.
class PsLocalServer : public PSServer {
 public:
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  virtual uint64_t Start() { return 0; }
  virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
  virtual int32_t Stop() { return 0; }
  virtual int32_t Configure(
      const PSParameter &config,
      PSEnvironment &env,
      size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
    return 0;
  }

 private:
  virtual int32_t Initialize() { return 0; }
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace distributed {
std::vector<std::string> GraphPyService::split(std::string& str,
                                               const char pattern) {
  // Tokenize `str` on `pattern`. std::getline semantics apply: an empty
  // input yields no tokens and a trailing delimiter adds no empty token.
  std::vector<std::string> pieces;
  std::stringstream reader(str);
  std::string token;
  while (std::getline(reader, token, pattern)) {
    pieces.push_back(token);
  }
  return pieces;
}
void GraphPyService::add_table_feat_conf(std::string table_name,
                                         std::string feat_name,
                                         std::string feat_dtype,
                                         int feat_shape) {
  // Register (or override) one feature config entry for a node table.
  // Unknown table names are silently ignored (only the trailing log runs).
  if (feature_to_id.find(table_name) != feature_to_id.end()) {
    int idx = feature_to_id[table_name];
    VLOG(0) << "for table name" << table_name << " idx = " << idx;
    if (table_feat_mapping[idx].find(feat_name) ==
        table_feat_mapping[idx].end()) {
      // First time this feature is seen: assign the next sequential id.
      VLOG(0) << "for table name not found,make a new one";
      int res = (int)table_feat_mapping[idx].size();
      table_feat_mapping[idx][feat_name] = res;
      VLOG(0) << "seq id = " << table_feat_mapping[idx][feat_name];
    }
    int feat_idx = table_feat_mapping[idx][feat_name];
    VLOG(0) << "table_name " << table_name << " mapping id " << idx;
    VLOG(0) << " feat name " << feat_name << " feat id" << feat_idx;
    if (static_cast<size_t>(feat_idx) < table_feat_conf_feat_name[idx].size()) {
      // override existing slot
      table_feat_conf_feat_name[idx][feat_idx] = feat_name;
      table_feat_conf_feat_dtype[idx][feat_idx] = feat_dtype;
      table_feat_conf_feat_shape[idx][feat_idx] = feat_shape;
    } else {
      // append a new slot
      table_feat_conf_feat_name[idx].push_back(feat_name);
      table_feat_conf_feat_dtype[idx].push_back(feat_dtype);
      table_feat_conf_feat_shape[idx].push_back(feat_shape);
    }
  }
  VLOG(0) << "add conf over";
}
// No-op placeholders: free-function graph mutation hooks that are not
// implemented here (the real logic lives on the client/service classes).
void add_graph_node(std::string name,
                    std::vector<int64_t> node_ids,
                    std::vector<bool> weight_list) {}
void remove_graph_node(std::string name, std::vector<int64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str,
                            int shard_num,
                            std::vector<std::string> node_types,
                            std::vector<std::string> edge_types) {
  // Register edge/node type -> id mappings, size the per-table feature
  // config containers, then parse "ip:port;ip:port;..." into the server
  // host list.
  // Fix: removed unused locals (`std::istringstream stream(ips_str)` and
  // `std::string ip`) and replaced C-style casts with static_cast.
  set_shard_num(shard_num);
  set_num_node_types(node_types.size());
  id_to_edge = edge_types;
  for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
    int res = static_cast<int>(edge_to_id.size());
    edge_to_id[edge_types[table_id]] = res;
  }
  id_to_feature = node_types;
  for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
    int res = static_cast<int>(feature_to_id.size());
    feature_to_id[node_types[table_id]] = res;
  }
  table_feat_mapping.resize(node_types.size());
  this->table_feat_conf_feat_name.resize(node_types.size());
  this->table_feat_conf_feat_dtype.resize(node_types.size());
  this->table_feat_conf_feat_shape.resize(node_types.size());
  server_size = 0;
  std::vector<std::string> ips_list = split(ips_str, ';');
  int index = 0;
  VLOG(0) << "start to build server";
  for (auto ips : ips_list) {
    auto ip_and_port = split(ips, ':');
    server_list.push_back(ip_and_port[0]);
    port_list.push_back(ip_and_port[1]);
    uint32_t port = stoul(ip_and_port[1]);
    auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
    host_sign_list.push_back(ph_host.SerializeToString());
    index++;
  }
  VLOG(0) << "build server done";
}
void GraphPyClient::start_client() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
void GraphPyServer::start_server(bool block) {
  // Create, configure and start the brpc graph server for this rank, then
  // optionally block the calling thread until the server's condition
  // variable is notified.
  std::string ip = server_list[rank];
  uint32_t port = std::stoul(port_list[rank]);
  ::paddle::distributed::PSParameter server_proto = this->GetServerProto();
  auto _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(&this->host_sign_list,
                       this->host_sign_list.size());  // test
  pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
      (paddle::distributed::GraphBrpcServer*)
      paddle::distributed::PSServerFactory::Create(server_proto));
  VLOG(0) << "pserver-ptr created ";
  // Configure with a single empty program (no server-side sub-programs).
  std::vector<framework::ProgramDesc> empty_vec;
  framework::ProgramDesc empty_prog;
  empty_vec.push_back(empty_prog);
  pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
  pserver_ptr->Start(ip, port);
  pserver_ptr->build_peer2peer_connection(rank);
  std::condition_variable* cv_ = pserver_ptr->export_cv();
  if (block) {
    // NOTE(review): the mutex is function-local and there is no predicate on
    // the wait, so this relies entirely on the server-side notify and is
    // susceptible to spurious wakeups — confirm this is intended.
    std::mutex mutex_;
    std::unique_lock<std::mutex> lock(mutex_);
    cv_->wait(lock);
  }
}
::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
  // Generate server proto desc: GraphBrpc service settings plus one sparse
  // (graph) table parameter. The per-table loop below is currently disabled.
  ::paddle::distributed::PSParameter server_fleet_desc;
  ::paddle::distributed::ServerParameter* server_proto =
      server_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);
  // for (auto& tuple : this->table_id_map) {
  //  VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}
  return server_fleet_desc;
}
::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
  // Generate the worker-side proto desc: a worker sparse table plus a mirror
  // of the server's GraphBrpc service config and its sparse table, so the
  // client and server agree on the deployment. Per-table loops are disabled.
  ::paddle::distributed::PSParameter worker_fleet_desc;
  ::paddle::distributed::WorkerParameter* worker_proto =
      worker_fleet_desc.mutable_worker_param();
  ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
      worker_proto->mutable_downpour_worker_param();
  // for (auto& tuple : this->table_id_map) {
  //  VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* worker_sparse_table_proto =
      downpour_worker_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(worker_sparse_table_proto);
  //}
  ::paddle::distributed::ServerParameter* server_proto =
      worker_fleet_desc.mutable_server_param();
  ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
      server_proto->mutable_downpour_server_param();
  ::paddle::distributed::ServerServiceParameter* server_service_proto =
      downpour_server_proto->mutable_service_param();
  server_service_proto->set_service_class("GraphBrpcService");
  server_service_proto->set_server_class("GraphBrpcServer");
  server_service_proto->set_client_class("GraphBrpcClient");
  server_service_proto->set_start_server_port(0);
  server_service_proto->set_server_thread_num(12);
  // for (auto& tuple : this->table_id_map) {
  //  VLOG(0) << " make a new table " << tuple.second;
  ::paddle::distributed::TableParameter* sparse_table_proto =
      downpour_server_proto->add_downpour_table_param();
  // std::vector<std::string> feat_name;
  // std::vector<std::string> feat_dtype;
  // std::vector<int32_t> feat_shape;
  // for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
  //   if (tuple.first == table_feat_conf_table_name[i]) {
  //     feat_name.push_back(table_feat_conf_feat_name[i]);
  //     feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
  //     feat_shape.push_back(table_feat_conf_feat_shape[i]);
  //   }
  // }
  // std::string table_type;
  // if (tuple.second < this->num_node_types) {
  //   table_type = "node";
  // } else {
  //   table_type = "edge";
  // }
  GetDownpourSparseTableProto(sparse_table_proto);
  //}
  return worker_fleet_desc;
}
// Loads an edge file for edge type `name` into table 0.
// The load command is "e>" (edges stored as $1 -> $2) or "e<" (reversed,
// edges from $2 to $1) followed by the edge-type name. Names that are not
// registered in edge_to_id are silently ignored.
void GraphPyClient::load_edge_file(std::string name,
                                   std::string filepath,
                                   bool reverse) {
  if (edge_to_id.find(name) == edge_to_id.end()) {
    return;  // unknown edge type: nothing to load
  }
  const std::string params = std::string("e") + (reverse ? "<" : ">") + name;
  auto status = get_ps_client()->Load(0, std::string(filepath), params);
  status.wait();
}
void GraphPyClient::clear_nodes(std::string name) {
if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->clear_nodes(0, 0, idx);
status.wait();
} else if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->clear_nodes(0, 1, idx);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = get_ps_client()->clear_nodes(table_id);
// status.wait();
// }
}
// Adds nodes (with a parallel per-node weight flag list) to the edge table
// identified by `name`; unknown edge types are ignored.
void GraphPyClient::add_graph_node(std::string name,
                                   std::vector<int64_t>& node_ids,
                                   std::vector<bool>& weight_list) {
  auto it = edge_to_id.find(name);
  if (it == edge_to_id.end()) return;
  auto status =
      get_ps_client()->add_graph_node(0, it->second, node_ids, weight_list);
  status.wait();
}
// Deletes the given nodes from the edge table identified by `name`;
// unknown edge types are a no-op.
void GraphPyClient::remove_graph_node(std::string name,
                                      std::vector<int64_t>& node_ids) {
  auto it = edge_to_id.find(name);
  if (it == edge_to_id.end()) return;
  auto status = get_ps_client()->remove_graph_node(0, it->second, node_ids);
  status.wait();
}
// Loads a node file into table 0. The load command is 'n' followed by the
// node type name; names not registered in feature_to_id are ignored.
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
  if (feature_to_id.find(name) == feature_to_id.end()) {
    return;  // unknown node type: nothing to load
  }
  const std::string params = "n" + name;
  auto status = get_ps_client()->Load(0, std::string(filepath), params);
  status.wait();
}
// Samples up to `sample_size` neighbors for each node in `node_ids` from the
// edge table identified by `name`. Returned layout:
//   res.first[0]: all sampled neighbor ids, flattened across source nodes
//   res.first[1]: cumulative slice offsets separating each source node's
//                 neighbors inside res.first[0] (one entry per boundary; the
//                 last node's end is implied by res.first[0].size())
//   res.first[2]: (only when return_edges) the source node id repeated once
//                 per sampled neighbor
//   res.second:   (only when return_weight) edge weights parallel to
//                 res.first[0]
// Unknown edge types yield empty results.
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
                                      std::vector<int64_t> node_ids,
                                      int sample_size,
                                      bool return_weight,
                                      bool return_edges) {
  std::vector<std::vector<int64_t>> v;  // per-source-node neighbor ids
  std::vector<std::vector<float>> v1;   // per-source-node edge weights
  if (edge_to_id.find(name) != edge_to_id.end()) {
    int idx = edge_to_id[name];
    auto status = get_ps_client()->batch_sample_neighbors(
        0, idx, node_ids, sample_size, v, v1, return_weight);
    status.wait();
  }
  // if (this->table_id_map.count(name)) {
  //   uint32_t table_id = this->table_id_map[name];
  //   auto status = worker_ptr->batch_sample_neighbors(
  //       table_id, node_ids, sample_size, v, v1, return_weight);
  //   status.wait();
  // }
  // res.first[0]: neighbors (nodes)
  // res.first[1]: slice index
  // res.first[2]: src nodes
  // res.second: edges weight
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
  res.first.push_back({});
  res.first.push_back({});
  if (return_edges) res.first.push_back({});
  for (size_t i = 0; i < v.size(); i++) {
    // Flatten this node's neighbors (and weights / src ids when requested).
    for (size_t j = 0; j < v[i].size(); j++) {
      // res.first[0].push_back(v[i][j].first);
      res.first[0].push_back(v[i][j]);
      if (return_edges) res.first[2].push_back(node_ids[i]);
      if (return_weight) res.second.push_back(v1[i][j]);
    }
    // No offset entry after the final node: slice offsets only mark the
    // boundaries *between* consecutive nodes' neighbor ranges.
    if (i == v.size() - 1) break;
    if (i == 0) {
      res.first[1].push_back(v[i].size());
    } else {
      res.first[1].push_back(v[i].size() + res.first[1].back());
    }
  }
  return res;
}
std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
std::vector<int64_t> v;
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->random_sample_nodes(
0, 1, idx, server_index, sample_size, v);
status.wait();
} else if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->random_sample_nodes(
0, 0, idx, server_index, sample_size, v);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status =
// worker_ptr->random_sample_nodes(table_id, server_index, sample_size,
// v);
// status.wait();
// }
return v;
}
// Fetches feature values for the given nodes of node type `name`.
// Returns a matrix shaped [feature_names.size()][node_ids.size()] of
// string-encoded feature values (the "(name, dtype, ndarray)" triple is
// carried in these strings). Unknown node types leave the matrix as the
// default-constructed empty strings.
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
    std::string name,
    std::vector<int64_t> node_ids,
    std::vector<std::string> feature_names) {
  std::vector<std::vector<std::string>> feats(
      feature_names.size(), std::vector<std::string>(node_ids.size()));
  auto it = feature_to_id.find(name);
  if (it != feature_to_id.end()) {
    auto status = get_ps_client()->get_node_feat(
        0, it->second, node_ids, feature_names, feats);
    status.wait();
  }
  return feats;
}
// Writes feature values for the given nodes of node type `name`.
// `features` is shaped [feature_names.size()][node_ids.size()], mirroring
// the layout returned by get_node_feat. Unknown node types are a no-op.
void GraphPyClient::set_node_feat(
    std::string name,
    std::vector<int64_t> node_ids,
    std::vector<std::string> feature_names,
    const std::vector<std::vector<std::string>> features) {
  auto it = feature_to_id.find(name);
  if (it == feature_to_id.end()) return;
  auto status = get_ps_client()->set_node_feat(
      0, it->second, node_ids, feature_names, features);
  status.wait();
}
std::vector<FeatureNode> GraphPyClient::pull_graph_list(
std::string name, int server_index, int start, int size, int step) {
std::vector<FeatureNode> res;
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
// size, step, res);
// status.wait();
// }
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->pull_graph_list(
0, 1, idx, server_index, start, size, step, res);
status.wait();
} else if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->pull_graph_list(
0, 0, idx, server_index, start, size, step, res);
status.wait();
}
return res;
}
// Asks the remote servers to stop, at most once. mutex_ serializes
// concurrent callers; stoped_ is set only when the stop RPC reports
// success (status 0), so a failed attempt may be retried later.
void GraphPyClient::StopServer() {
  VLOG(0) << "going to stop server";
  std::unique_lock<std::mutex> lock(mutex_);
  if (stoped_) return;
  auto status = this->worker_ptr->StopServer();
  if (status.get() == 0) stoped_ = true;
}
// Releases client-side resources held by the underlying brpc worker.
void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace distributed {
// Shared base for the graph parameter-server wrappers (GraphPyServer /
// GraphPyClient). Holds cluster topology (host/port lists, shard count)
// and the registered edge/node types with their per-type feature schemas,
// and emits the GraphTable proto configuration both sides share.
class GraphPyService {
 protected:
  std::vector<std::string> server_list, port_list, host_sign_list;
  int server_size, shard_num;
  int num_node_types;
  // Type name -> index used when addressing tables by type id.
  std::unordered_map<std::string, int> edge_to_id, feature_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  std::vector<std::unordered_map<std::string, int>> table_feat_mapping;
  // Per node-type feature schema, indexed by feature/node type id.
  std::vector<std::vector<std::string>> table_feat_conf_feat_name;
  std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
  std::vector<std::vector<int>> table_feat_conf_feat_shape;
 public:
  int get_shard_num() { return shard_num; }
  void set_shard_num(int shard_num) { this->shard_num = shard_num; }
  // Fills `sparse_table_proto` with the GraphTable configuration shared by
  // server and worker: table meta (id/class/shard/type), the graph
  // parameter (edge/node types plus per-type feature schemas) and the
  // accessor class.
  void GetDownpourSparseTableProto(
      ::paddle::distributed::TableParameter* sparse_table_proto) {
    sparse_table_proto->set_table_id(0);
    sparse_table_proto->set_table_class("GraphTable");
    sparse_table_proto->set_shard_num(shard_num);
    sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
    ::paddle::distributed::TableAccessorParameter* accessor_proto =
        sparse_table_proto->mutable_accessor();
    ::paddle::distributed::GraphParameter* graph_proto =
        sparse_table_proto->mutable_graph_parameter();
    graph_proto->set_task_pool_size(24);
    graph_proto->set_table_name("cpu_graph_table");
    graph_proto->set_use_cache(false);
    for (size_t i = 0; i < id_to_edge.size(); i++)
      graph_proto->add_edge_types(id_to_edge[i]);
    for (size_t i = 0; i < id_to_feature.size(); i++) {
      graph_proto->add_node_types(id_to_feature[i]);
      // One GraphFeature message per node type, carrying that type's
      // (name, dtype, shape) triples.
      ::paddle::distributed::GraphFeature* g_f =
          graph_proto->add_graph_feature();
      for (size_t x = 0; x < table_feat_conf_feat_name[i].size(); x++) {
        g_f->add_name(table_feat_conf_feat_name[i][x]);
        g_f->add_dtype(table_feat_conf_feat_dtype[i][x]);
        g_f->add_shape(table_feat_conf_feat_shape[i][x]);
      }
    }
    accessor_proto->set_accessor_class("CommMergeAccessor");
  }
  void set_server_size(int server_size) { this->server_size = server_size; }
  void set_num_node_types(int num_node_types) {
    this->num_node_types = num_node_types;
  }
  // NOTE(review): this returns its own argument — the parameter shadows the
  // `server_size` member, so callers always get back what they pass in.
  // It likely should be `int get_server_size() { return this->server_size; }`;
  // signature kept unchanged here to avoid breaking existing callers —
  // confirm call sites before fixing.
  int get_server_size(int server_size) { return server_size; }
  std::vector<std::string> split(std::string& str, const char pattern);
  void set_up(std::string ips_str,
              int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types);
  void add_table_feat_conf(std::string node_type,
                           std::string feat_name,
                           std::string feat_dtype,
                           int32_t feat_shape);
};
// Server-side wrapper: owns a GraphBrpcServer instance and this server's
// rank within the cluster.
class GraphPyServer : public GraphPyService {
 public:
  GraphPyServer() {}
  // Records this server's rank, then runs the common service setup.
  void set_up(std::string ips_str,
              int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types,
              int rank) {
    set_rank(rank);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  int GetRank() { return rank; }
  void set_rank(int rank) { this->rank = rank; }
  // Starts the brpc graph server (defined elsewhere); `block` controls
  // whether the call waits.
  void start_server(bool block = true);
  // Builds the server-side PSParameter configuration proto.
  ::paddle::distributed::PSParameter GetServerProto();
  std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
    return pserver_ptr;
  }
 protected:
  int rank;
  std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr;
  std::thread* server_thread;
};
// Client-side wrapper: owns a GraphBrpcClient and exposes the graph
// operations (load / sample / node-feature get-set) built on top of it.
class GraphPyClient : public GraphPyService {
 public:
  // Records this client's id, then runs the common service setup.
  void set_up(std::string ips_str,
              int shard_num,
              std::vector<std::string> node_types,
              std::vector<std::string> edge_types,
              int client_id) {
    set_client_id(client_id);
    GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
  }
  std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() {
    return worker_ptr;
  }
  // Short-circuits RPC through an in-process server: routes this client's
  // traffic on `local_channel_index` directly to `server`'s service.
  void bind_local_server(int local_channel_index, GraphPyServer& server) {
    worker_ptr->set_local_channel(local_channel_index);
    worker_ptr->set_local_graph_service(
        (paddle::distributed::GraphBrpcService*)server.get_ps_server()
            ->get_service());
  }
  // Asks the remote servers to stop (guarded so it runs at most once).
  void StopServer();
  void FinalizeWorker();
  void load_edge_file(std::string name, std::string filepath, bool reverse);
  void load_node_file(std::string name, std::string filepath);
  void clear_nodes(std::string name);
  void add_graph_node(std::string name,
                      std::vector<int64_t>& node_ids,
                      std::vector<bool>& weight_list);
  void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
  int get_client_id() { return client_id; }
  void set_client_id(int client_id) { this->client_id = client_id; }
  void start_client();
  // Returns {flattened neighbors, slice offsets[, src nodes]} plus optional
  // parallel weights; see the definition for the exact layout.
  std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
  batch_sample_neighbors(std::string name,
                         std::vector<int64_t> node_ids,
                         int sample_size,
                         bool return_weight,
                         bool return_edges);
  std::vector<int64_t> random_sample_nodes(std::string name,
                                           int server_index,
                                           int sample_size);
  // Returns feature strings shaped [feature_names.size()][node_ids.size()].
  std::vector<std::vector<std::string>> get_node_feat(
      std::string name,
      std::vector<int64_t> node_ids,
      std::vector<std::string> feature_names);
  void set_node_feat(std::string node_type,
                     std::vector<int64_t> node_ids,
                     std::vector<std::string> feature_names,
                     const std::vector<std::vector<std::string>> features);
  std::vector<FeatureNode> pull_graph_list(
      std::string name, int server_index, int start, int size, int step = 1);
  // Builds the worker-side PSParameter configuration proto.
  ::paddle::distributed::PSParameter GetWorkerProto();
 protected:
  mutable std::mutex mutex_;  // guards stoped_ in StopServer
  int client_id;
  std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
  std::thread* client_thread;
  bool stoped_ = false;  // set once the stop RPC has succeeded
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using namespace std; // NOLINT
namespace paddle {
namespace distributed {
// Reads a PSParameter from a text-format protobuf file.
// Terminates the process (exit(-1)) when the file cannot be opened or the
// contents fail to parse.
paddle::distributed::PSParameter load_from_prototxt(
    const std::string& filename) {
  paddle::distributed::PSParameter param;
  const int fd = open(filename.c_str(), O_RDONLY);
  if (fd == -1) {
    VLOG(3) << "FATAL: fail to parse " << filename;
    exit(-1);
  }
  google::protobuf::io::FileInputStream input_stream(fd);
  const bool parsed =
      google::protobuf::TextFormat::Parse(&input_stream, &param);
  if (!parsed) {
    VLOG(3) << "FATAL: fail to parse " << filename;
    exit(-1);
  }
  close(fd);
  return param;
}
// Parses the whitespace-separated gflags string carried in the PS config
// and applies it to process-wide gflags state. When no flags are supplied,
// falls back to default brpc message-size/connection limits.
void PSCore::InitGFlag(const std::string& gflags) {
  VLOG(3) << "Init With Gflags:" << gflags;
  std::vector<std::string> flags = paddle::string::split_string(gflags);
  if (flags.empty()) {
    flags.push_back("-max_body_size=314217728");
    flags.push_back("-socket_max_unwritten_bytes=2048000000");
    flags.push_back("-max_connection_pool_size=1950");
  }
  // ParseCommandLineFlags expects an argv[0] program-name placeholder.
  flags.insert(flags.begin(), "exe default");
  // Build a char* argv view over the flag strings. std::vector replaces the
  // original `char* flags_ptr[flags.size()]` — a variable-length array is a
  // compiler extension, not standard C++.
  std::vector<char*> flags_ptr(flags.size());
  for (size_t i = 0; i < flags.size(); ++i) {
    // gflags wants mutable char*; the strings outlive the parse call.
    flags_ptr[i] = const_cast<char*>(flags[i].c_str());
  }
  int params_cnt = static_cast<int>(flags.size());
  char** params_ptr = flags_ptr.data();
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
// Initializes the server side of the PS: parses the text-format PSParameter
// from `dist_desc`, applies its gflags, registers the server host list and
// trainer count with the environment, then creates the PSServer via the
// factory and configures it with this server's `index` and the server
// sub-programs. Returns Configure()'s result; CHECK-aborts on failure.
int PSCore::InitServer(
    const std::string& dist_desc,
    const std::vector<std::string>* host_sign_list,
    int node_num,
    int index,
    int trainers,
    const std::vector<framework::ProgramDesc>& server_sub_program) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  _ps_env.SetTrainers(trainers);
  int ret = 0;
  _server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
      paddle::distributed::PSServerFactory::Create(_ps_param));
  ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
  CHECK(ret == 0) << "failed to configure server";
  return ret;
}
// Initializes the worker side: parses the config, applies its gflags,
// points the environment at the server list, configures the Communicator's
// PS client with the dense-parameter regions, and starts the communicator.
// Returns the client Configure() result.
// NOTE(review): the ParseFromString result is ignored, and unlike
// InitServer this does not call _ps_env.SetTrainers — confirm intended.
int PSCore::InitWorker(
    const std::string& dist_desc,
    const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
    const std::vector<std::string>* host_sign_list,
    int node_num,
    int index) {
  google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
  InitGFlag(_ps_param.init_gflags());
  _ps_env = paddle::distributed::PaddlePSEnvironment();
  _ps_env.SetPsServers(host_sign_list, node_num);
  int ret = 0;
  VLOG(1) << "PSCore::InitWorker";
  auto* communicator = Communicator::GetInstance();
  ret = communicator->GetPsClient()->Configure(
      _ps_param, regions, _ps_env, index);
  communicator->Start();
  return ret;
}
// Returns the client endpoints currently registered in the PS environment.
std::vector<uint64_t> PSCore::GetClientInfo() {
  return _ps_env.GetClientInfo();
}
// Opens direct client-to-client connections between PS clients, forwarding
// the timeout/retry settings to the underlying worker. Returns the
// client's result code.
int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
                                          int pserver_connect_timeout_ms,
                                          int max_retry) {
  return _worker_ptr->CreateClient2ClientConnection(
      pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
// Starts the PS server listening on ip:port and returns whatever
// PSServer::Start yields.
uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
  return _server_ptr->Start(ip, port);
}
// Shuts down the worker-side PS client. Always returns 0; the underlying
// FinalizeWorker result is not propagated.
int PSCore::FinalizeWorker() {
  _worker_ptr->FinalizeWorker();
  return 0;
}
// Sends the stop request to the servers and waits for it to complete.
// Always returns 0; the stop status value itself is not inspected.
int PSCore::StopServer() {
  auto stop_status = _worker_ptr->StopServer();
  stop_status.wait();
  return 0;
}
// Exposes the parsed PSParameter config for inspection by callers.
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/server.h"
namespace paddle {
namespace distributed {
class PSClient;
class PSServer;
class PsRequestMessage;
class PsResponseMessage;
class PsService;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
// Facade bundling the PS server and/or worker for one process: parses the
// textual PSParameter configuration, owns the PS environment, and forwards
// lifecycle calls to the underlying PSServer / PSClient.
class PSCore {
 public:
  explicit PSCore() {}
  virtual ~PSCore() {}
  // Parses `dist_desc` (text-format PSParameter), sets up gflags and the
  // environment, then creates and configures _server_ptr. Returns the
  // Configure() result code.
  virtual int InitServer(
      const std::string& dist_desc,
      const std::vector<std::string>* host_sign_list,
      int node_num,
      int index,
      int trainers,
      const std::vector<framework::ProgramDesc>& server_sub_program = {});
  // Worker-side counterpart: configures the communicator's PS client with
  // the dense-parameter regions and starts the communicator.
  virtual int InitWorker(
      const std::string& dist_desc,
      const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
          regions,
      const std::vector<std::string>* host_sign_list,
      int node_num,
      int index);
  // Starts the server on ip:port; returns PSServer::Start's value.
  virtual uint64_t RunServer(const std::string& ip, uint32_t port);
  virtual int StopServer();
  virtual int FinalizeWorker();
  virtual std::vector<uint64_t> GetClientInfo();
  virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
                                            int pserver_connect_timeout_ms,
                                            int max_retry);
  std::shared_ptr<paddle::distributed::PSServer>
      _server_ptr;  // pointer to server
  std::shared_ptr<paddle::distributed::PSClient>
      _worker_ptr;  // pointer to worker
  virtual paddle::distributed::PSParameter* GetParam();
 private:
  void InitGFlag(const std::string& gflags);
  paddle::distributed::PSParameter _ps_param;    // parsed config
  paddle::distributed::PaddlePSEnvironment _ps_env;  // cluster endpoints
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment