Commit f0ef3442 authored by yuguo960516yuguo's avatar yuguo960516yuguo

2.3.2-dtk-22.10.1

parent ad08b8ce
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/brpc_utils.h"
#include "paddle/fluid/distributed/ps/service/heter_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/depends/feature_value.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/profiler.h"
namespace google {
namespace protobuf {
class Closure;
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
DECLARE_int32(pserver_timeout_ms);
DECLARE_int32(heter_world_size);
DECLARE_int32(switch_send_recv_timeout_s);
using MultiVarMsg = MultiVariableMessage;
using VarMsg = VariableMessage;
using serviceHandler =
std::function<int32_t(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl)>;
using HeterServiceHandler =
std::function<int32_t(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>;
using HeterRpcCallbackFunc = std::function<void(void*)>;
class ServiceHandlerBase {
public:
ServiceHandlerBase() : dev_ctx_(nullptr), scope_(nullptr) {}
virtual ~ServiceHandlerBase() {}
void SetScope(const framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
virtual int Handle(const MultiVarMsg* request,
MultiVarMsg* response,
brpc::Controller* cntl) = 0;
protected:
const platform::DeviceContext* dev_ctx_;
const framework::Scope* scope_;
};
using SharedMiniScope =
std::shared_ptr<std::unordered_map<int, ::paddle::framework::Scope*>>;
using SharedMicroScope = std::shared_ptr<std::unordered_map<
int,
std::shared_ptr<std::vector<::paddle::framework::Scope*>>>>;
using SharedTaskQueue = std::shared_ptr<
std::unordered_map<int,
std::shared_ptr<::paddle::framework::BlockingQueue<
std::pair<std::string, int>>>>>;
class ValueInSwitch {
public:
ValueInSwitch() {}
~ValueInSwitch() {}
char* data() { return _data.data(); }
size_t size() { return _data.size(); }
void resize(size_t size) { _data.resize(size); }
void shrink_to_fit() { _data.shrink_to_fit(); }
private:
std::vector<char> _data;
};
class SendAndRecvVariableHandler final : public ServiceHandlerBase {
public:
SendAndRecvVariableHandler() {
this->num_microbatch_ = 0;
this->num_minibatch_ = 0;
_local_shards.reset(new shard_type[FLAGS_heter_world_size]);
}
virtual ~SendAndRecvVariableHandler() {}
void SetMiniScopes(SharedMiniScope mini_scopes) {
mini_scopes_ = mini_scopes;
num_minibatch_ = mini_scopes_->size();
}
void SetMicroScopes(SharedMicroScope micro_scopes) {
micro_scopes_ = micro_scopes;
for (auto& scope_pair : (*micro_scopes_)) {
// auto mini_idx = scope_pair.first;
auto& micro_scopes = scope_pair.second;
num_microbatch_ = micro_scopes->size();
break;
}
}
int GetThreadNum() {
std::unique_lock<std::mutex> lk(scope_mutex_);
return (*task_queue_).size();
}
int SaveInSwitchWithScope(const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl);
void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
// timeline_.Start();
while (true) {
{
std::lock_guard<std::mutex> lock(scope_mutex_);
if (vars_ready_flag[group_id][var_name] == 0) {
break;
}
}
/*
timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not consumed exceed 10 miniutes";
break;
}
*/
}
return;
}
void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
// timeline_.Start();
while (true) {
{
std::lock_guard<std::mutex> lock(scope_mutex_);
if (vars_ready_flag[group_id][var_name] == 1) {
break;
}
}
/*
timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not produced exceed 10 miniutes";
break;
}
*/
}
return;
}
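// Note (added for clarity, not in the original source): the two helpers above
// act as a simple handshake on vars_ready_flag, where
// vars_ready_flag[group_id][var_name] == 1 reads as "produced, waiting to be
// consumed" and == 0 as "consumed, slot free". The naming suggests the
// producing side spins in WaitForVarsConsumed until the flag drops to 0, and
// the consuming side spins in WaitForVarsProduced until it becomes 1; the
// commented-out timeline_ code hints at a timeout guard that is currently
// disabled.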
int SaveInSwitchWithShard(const MultiVarMsg* request,
PsResponseMessage* response,
brpc::Controller* cntl);
int QueryInSwitchWithShard(const MultiVarMsg* request,
MultiVarMsg* response,
brpc::Controller* cntl);
int QueryInSwitchWithScope(const MultiVarMsg* request,
MultiVarMsg* response,
brpc::Controller* cntl);
void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; }
int Handle(const MultiVarMsg* request,
MultiVarMsg* response,
brpc::Controller* cntl) override {
LOG(INFO) << "entered Handle";
platform::RecordEvent record_event("SendAndRecvVariableHandler->Handle",
platform::TracerEventType::Communication,
1);
FLAGS_eager_delete_tensor_gb = -1;
// get microID from request
// deserialize variable to micro scope
// Push to heter worker's task_queue
std::unique_ptr<paddle::framework::Scope> local_scope_ptr(
new paddle::framework::Scope());
auto& local_scope = *(local_scope_ptr.get());
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::CPUPlace cpu_place;
auto& cpu_dev_ctx = *pool.Get(cpu_place);
auto message_name = request->message_name();
auto& request_io_buffer = cntl->request_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, cpu_dev_ctx, &local_scope);
auto* var = local_scope.FindVar("microbatch_id");
PADDLE_ENFORCE_NE(var,
nullptr,
platform::errors::InvalidArgument(
"Not find variable microbatch_id in scope."));
auto* tensor = var->GetMutable<phi::DenseTensor>();
auto data = reinterpret_cast<const float*>(tensor->data());
auto micro_id = static_cast<int>(data[0]);
VLOG(4) << "micro_id in heter server: " << micro_id;
int minibatch_index = micro_id / 10;
int microbatch_index = micro_id % 10;
// check minibatch_index is in mini_scopes_
std::unique_lock<std::mutex> lk(scope_mutex_);
if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) {
lk.unlock();
PADDLE_ENFORCE_EQ(
(*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(),
1,
platform::errors::InvalidArgument(
"minibatch index should in current trainer"));
} else {
// create mini scope & micro scopes
auto* minibatch_scope = &(scope_->NewScope());
(*mini_scopes_)[minibatch_index] = minibatch_scope;
(*micro_scopes_)[minibatch_index].reset(
new std::vector<paddle::framework::Scope*>{});
for (int i = 0; i < num_microbatch_; i++) {
auto* micro_scope = &(minibatch_scope->NewScope());
(*((*micro_scopes_)[minibatch_index])).push_back(micro_scope);
}
(*task_queue_)[minibatch_index].reset(
new ::paddle::framework::BlockingQueue<
std::pair<std::string, int>>());
lk.unlock();
}
auto* micro_scope =
(*((*micro_scopes_)[minibatch_index]))[microbatch_index];
distributed::DeserializeFromMultiVarMsgAndIOBuf(
*request, &request_io_buffer, *dev_ctx_, micro_scope);
// blocking queue handles multi thread
VLOG(4) << "Handle in HeterServer: " << message_name << ", "
<< microbatch_index;
VLOG(4) << "task_queue_ size: " << task_queue_->size();
(*task_queue_)[minibatch_index]->Push(
std::make_pair(message_name, microbatch_index));
auto response_var_nums = request->recv_var_names_size();
std::vector<std::string> response_var_names(response_var_nums),
empty_var_names{};
for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
response_var_names[var_idx] = request->recv_var_names(var_idx);
}
auto& response_io_buffer = cntl->response_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(message_name,
response_var_names,
empty_var_names,
*dev_ctx_,
&local_scope,
response,
&response_io_buffer);
VLOG(4) << "Handle over";
return 0;
}
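// Illustrative sketch (not part of the original source): Handle() assumes the
// trainer encodes the scope location into the float variable "microbatch_id"
// as minibatch_index * 10 + microbatch_index, which the `/ 10` and `% 10`
// arithmetic above reverses. A hypothetical helper making the mapping
// explicit:
//
//   inline std::pair<int, int> DecodeMicroId(int micro_id) {
//     return {micro_id / 10 /*minibatch*/, micro_id % 10 /*microbatch*/};
//   }
//
// e.g. micro_id 23 -> minibatch scope 2, micro scope 3.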
public:
using shard_type = SparseTableShard<std::string, ValueInSwitch>;
std::shared_ptr<paddle::framework::Scope> local_scope_ptr; // for switch
std::unordered_map<uint32_t, std::unordered_map<std::string, uint32_t>>
vars_ready_flag;
std::unique_ptr<shard_type[]> _local_shards;
platform::Timer timeline_;
private:
// share with HeterPipelineTrainer
SharedMiniScope mini_scopes_{nullptr};
SharedMicroScope micro_scopes_{nullptr};
int num_microbatch_;
int num_minibatch_;
std::mutex scope_mutex_;
bool is_first_stage_ = false;
bool is_last_stage_ = false;
SharedTaskQueue task_queue_;
};
class HeterService : public PsService {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] =
std::bind(&HeterService::stop_heter_worker,
this,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
_service_handler_map[PS_START_PROFILER] =
std::bind(&HeterService::start_profiler,
this,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
_service_handler_map[PS_STOP_PROFILER] =
std::bind(&HeterService::stop_profiler,
this,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3);
service_handler_.local_scope_ptr =
std::make_shared<paddle::framework::Scope>();
}
virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller,
const PsRequestMessage* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
response->set_err_code(0);
response->set_err_msg("");
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
response->set_err_code(-1);
response->set_err_msg(err_msg);
return;
}
serviceHandler handler = itr->second;
int service_ret = handler(*request, *response, cntl);
VLOG(4) << "handler in service ret: " << service_ret;
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
}
virtual void SendAndRecvVariable(
::google::protobuf::RpcController* controller,
const MultiVarMsg* request,
MultiVarMsg* response,
::google::protobuf::Closure* done) {
// This object helps you to call done->Run() in RAII style. If you need
// to process the request asynchronously, pass done_guard.release().
brpc::ClosureGuard done_guard(done);
std::string message_name = request->message_name();
VLOG(0) << "SendAndRecvVariable message_name: " << message_name;
auto itr = handler_map_.find(message_name);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
LOG(INFO) << "SendAndRecvVariable(client addr) =" << cntl->remote_side();
PADDLE_ENFORCE_NE(
itr,
handler_map_.end(),
platform::errors::InvalidArgument(
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
"which is not in HeterService::handler_map_",
message_name));
itr->second(request, response, cntl);
// We don't want to call done->Run() here, release the guard.
// done_guard.release();
}
virtual void RecvFromSwitch(::google::protobuf::RpcController* controller,
const MultiVarMsg* request,
MultiVarMsg* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
// int ret = service_handler_.QueryInSwitchWithScope(request, response,
// cntl);
int ret = service_handler_.QueryInSwitchWithShard(request, response, cntl);
// std::string message_name = request->message_name();
// auto itr = handler_map_.find(message_name);
// int ret = itr->second(request, response, cntl);
if (ret != 0) {
LOG(ERROR) << "QueryInSwitchWithScope failed!";
}
// response->set_message_name(message_name);
}
virtual void SendToSwitch(::google::protobuf::RpcController* controller,
const MultiVarMsg* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
VLOG(4) << "entering SendToSwitch";
brpc::ClosureGuard done_guard(done);
std::shared_ptr<HeterClient> switch_client_ptr_ =
HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
if (switch_client_ptr_->peer_switch_channels_.empty()) {
LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
}
brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
// proxy: create a new OnHeterRpcDone object (or reset one inside the OnHeterRpcDone class)
OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
int ret = closure->CheckResponse();
closure->set_promise_value(ret);
if (closure->cntl.Failed()) {
PADDLE_ENFORCE_NE(
closure->cntl.Failed(),
true,
platform::errors::Unimplemented(
"HeterClient::SendS2S meets brpc error, error message is %s",
closure->cntl.ErrorText()));
}
});
auto& std_cntl = closure2->cntl;
std_cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
std_cntl.request_attachment().append(cntl->request_attachment().movable());
auto promise = std::make_shared<std::promise<int32_t>>();
closure2->add_promise(promise);
std::future<int> fut = promise->get_future();
// brpc::Controller std_cntl;
// std_cntl.request_attachment().append(cntl->request_attachment().movable());
PsService_Stub stub(channel);
stub.SendS2S(&std_cntl, request, response, closure2);
cntl->response_attachment().append(
std_cntl.response_attachment().movable());
fut.wait();
VLOG(4) << "SendToSwitch done";
delete closure2;
}
void SendS2S(::google::protobuf::RpcController* controller,
const MultiVarMsg* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
VLOG(4) << "entering SendS2S";
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
// int ret = service_handler_.SaveInSwitchWithScope(request, response,
// cntl);
int ret = service_handler_.SaveInSwitchWithShard(request, response, cntl);
// std::string message_name = request->message_name();
// auto itr = handler_map_.find(message_name);
// if (itr == handler_map_.end()) {
// LOG(ERROR) << "can not find func handler";
//}
// int ret = itr->second(request, response, cntl);
if (ret != 0) {
LOG(ERROR) << "SaveInSwitchWithScope failed";
}
std::string err_msg = "ok";
response->set_err_msg(err_msg.c_str());
response->set_err_code(ret);
VLOG(4) << "heter server SendS2S done";
}
void SendToWorker(::google::protobuf::RpcController* controller,
const MultiVarMsg* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
VLOG(4) << "in switch client, peer worker 0: "
<< switch_client_ptr_->peer_worker_list_[0];
brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PsService_Stub stub(channel);
stub.SendAndRecvVariable(controller, request, &closure->response, done);
// fill response content
std::string err_msg("pass to worker");
response->set_err_msg(err_msg.c_str());
response->set_err_code(0);
}
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func) {
handler_map_[message_name] = func;
}
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
void SetInterEndpoint(const std::string& end_point) {
endpoint_inter_ = end_point;
}
void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
peer_endpoints_ = peer_endpoints;
}
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
void ForceExit() {
VLOG(3) << "heter service force exit";
is_exit_ = true;
return;
}
bool IsExit() { return is_exit_; }
private:
int32_t stop_profiler(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl) {
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("heter_worker_%s_profile", endpoint_));
return 0;
}
int32_t start_profiler(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl) {
platform::EnableProfiler(platform::ProfilerState::kAll);
return 0;
}
int32_t stop_heter_worker(const PsRequestMessage& request,
PsResponseMessage& response, // NOLINT
brpc::Controller* cntl) {
auto client_id = request.client_id();
stop_cpu_worker_set_.insert(client_id);
if (stop_cpu_worker_set_.size() == fan_in_) {
is_exit_ = true;
}
return 0;
}
private:
SendAndRecvVariableHandler service_handler_;
std::string endpoint_;
std::string endpoint_inter_;
// for switch
std::vector<std::string> peer_endpoints_;
std::unordered_map<int32_t, serviceHandler> _service_handler_map;
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
std::unordered_set<int> stop_cpu_worker_set_;
uint32_t fan_in_;
bool is_exit_ = false;
};
class HeterServer {
public:
HeterServer() : ready_(0) {}
virtual ~HeterServer() {}
void Stop() {
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_ == true) return;
if (!IsExit()) {
service_.ForceExit();
}
stoped_ = true;
cv_.notify_all();
server_.Stop(1000);
server_.Join();
}
bool IsStop() {
std::unique_lock<std::mutex> lock(mutex_);
return stoped_;
}
bool IsExit() { return service_.IsExit(); }
void RegisterServiceHandler(std::string message_name,
HeterServiceHandler func);
void StartHeterService(bool need_encrypt = false);
void StartHeterInterService(bool need_encrypt = false);
void SetEndPoint(const std::string& endpoint) {
this->endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}
void SetLocalScope() {
request_handler_->local_scope_ptr =
std::make_shared<paddle::framework::Scope>();
}
void SetInterEndpoint(const std::string& endpoint) {
this->endpoint_inter_ = endpoint;
service_.SetInterEndpoint(endpoint);
}
void SetPeerEndPoints(const std::vector<std::string>& peer_endpoints) {
this->peer_endpoints_ = peer_endpoints;
service_.SetPeerEndPoints(peer_endpoints);
}
void SetFanin(const int& fan_in);
void SetServiceHandler(
std::shared_ptr<SendAndRecvVariableHandler> request_handler) {
request_handler_ = request_handler;
}
void SetMiniBatchScopes(SharedMiniScope mini_scopes) {
request_handler_->SetMiniScopes(mini_scopes);
}
void SetMicroBatchScopes(SharedMicroScope micro_scopes) {
request_handler_->SetMicroScopes(micro_scopes);
}
int GetThreadNum() { return request_handler_->GetThreadNum(); }
void SetTaskQueue(SharedTaskQueue task_queue) {
request_handler_->SetTaskQueue(task_queue);
}
// HeterWrapper singleton
static std::shared_ptr<HeterServer> GetInstance() {
std::unique_lock<std::mutex> lock(mtx_);
if (s_instance_ == nullptr) {
s_instance_.reset(new HeterServer());
}
return s_instance_;
}
void WaitServerReady();
private:
static std::shared_ptr<HeterServer> s_instance_;
mutable std::mutex mutex_;
static std::mutex mtx_;
std::condition_variable cv_;
std::condition_variable condition_ready_;
bool stoped_ = true;
std::string endpoint_;
std::string endpoint_inter_;
// for switch
std::vector<std::string> peer_endpoints_;
protected:
brpc::Server server_;
brpc::Server server_inter_;
HeterService service_;
std::shared_ptr<SendAndRecvVariableHandler> request_handler_;
DISABLE_COPY_AND_ASSIGN(HeterServer);
std::mutex mutex_ready_;
int ready_;
};
} // end namespace distributed
} // end namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/ps/service/coordinator_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, PsLocalClient);
REGISTER_PSCORE_CLASS(PSClient, GraphBrpcClient);
REGISTER_PSCORE_CLASS(PSClient, CoordinatorClient);
int32_t PSClient::Configure( // called in FleetWrapper::InitWorker
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> &regions,
PSEnvironment &env,
size_t client_id) {
_env = &env;
_config = config;
_dense_pull_regions = regions;
_client_id = client_id;
_config.mutable_worker_param()
->mutable_downpour_worker_param()
->mutable_downpour_table_param()
->CopyFrom(_config.server_param()
.downpour_server_param()
.downpour_table_param());
const auto &work_param = _config.worker_param().downpour_worker_param();
for (int i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->Configure(work_param.downpour_table_param(i).accessor());
accessor->Initialize();
_table_accessors[work_param.downpour_table_param(i).table_id()].reset(
accessor);
}
return Initialize();
}
PSClient *PSClientFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_client_class()) {
LOG(ERROR) << "miss client_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
return NULL;
}
TableManager::Instance().Initialize();
VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
return client;
}
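// Illustrative sketch (assumption, not part of the original source): the
// typical create-then-configure sequence the two functions above support.
// Building the PSParameter and the dense regions map is elided; only calls
// that appear in this file are used.
//
//   paddle::distributed::PSParameter ps_param;  // assume: filled from the fleet config
//   paddle::distributed::PaddlePSEnvironment env;
//   std::map<uint64_t, std::vector<paddle::distributed::Region>> regions;
//   PSClient* client = PSClientFactory::Create(ps_param);
//   if (client != nullptr) {
//     client->Configure(ps_param, regions, env, /*client_id=*/0);
//   }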
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
virtual ~PSClientClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) { // NOLINT
_promises.push_back(promise);
}
void add_timer(std::shared_ptr<CostTimer> &timer) { // NOLINT
_timers.push_back(timer);
}
protected:
PSClientCallBack _callback;
std::vector<std::shared_ptr<CostTimer>> _timers;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PSClient {
public:
PSClient() {}
virtual ~PSClient() {}
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t Configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env, // NOLINT
size_t client_id) final;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) = 0;
// trigger eviction of table data
virtual std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) = 0;
// load data for all tables
virtual std::future<int32_t> Load(const std::string &epoch,
const std::string &mode) = 0;
// load data for the specified table
virtual std::future<int32_t> Load(uint32_t table_id,
const std::string &epoch,
const std::string &mode) = 0;
// save data for all tables; depending on mode, the value_accessor may apply different save conditions
virtual std::future<int32_t> Save(const std::string &epoch,
const std::string &mode) = 0;
// save data for the specified table; depending on mode, the value_accessor may apply different save conditions
virtual std::future<int32_t> Save(uint32_t table_id,
const std::string &epoch,
const std::string &mode) = 0;
// clear table data
virtual std::future<int32_t> Clear() = 0;
virtual std::future<int32_t> Clear(uint32_t table_id) = 0;
// pull the dense part of the parameters and fill them block-wise into the
// local network parameters
// start and num are used to pull only a subset of the parameters
// the keys and values buffers must not be reused before the future completes
// the client unpacks values by block and hands them to multiple senders
// each sender aggregates requests for the same block, accumulating multiple
// requests into one fill buffer
// the server extracts and returns the configured dimension of the parameter
// block; the returned data is unpacked and written into the accumulated buffers
virtual std::future<int32_t> PullDense(Region *regions,
size_t region_num,
size_t table_id) = 0; // reserved
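// Illustrative sketch (assumption, not in the original header): preparing
// Region buffers for PullDense. Region is taken here to carry a raw byte
// pointer plus a size in bytes, matching the byte-wise memcpy arithmetic in
// PsLocalClient::PullDense; the buffer sizes are made up for the example.
//
//   std::vector<float> w0(256), w1(512);  // local dense parameter blocks
//   paddle::distributed::Region regs[2] = {
//       {reinterpret_cast<char*>(w0.data()), w0.size() * sizeof(float)},
//       {reinterpret_cast<char*>(w1.data()), w1.size() * sizeof(float)}};
//   auto fut = client->PullDense(regs, /*region_num=*/2, /*table_id=*/0);
//   fut.wait();  // w0/w1 must stay alive and untouched until the future ends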
// first push the dense params to the parameter server
// this is necessary because dense weights are initialized in the trainer on
// cold start
virtual std::future<int32_t> PushDenseParam(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> PushDense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// issue a pull request with keys and fill the results into values
// keys and values both contain num entries; each value occupies select_size
// bytes
// the keys and values buffers must not be reused before the future completes
// keys requested by multiple threads are merged, grouped, and scattered to the
// servers; once the results return, the buffers are traversed and values are
// assigned
// is_training distinguishes training requests from prediction requests; the
// server handles features and feature admission differently for each.
virtual std::future<int32_t> PullSparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training) = 0;
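// Illustrative sketch (assumption, not in the original header): calling
// PullSparse with caller-owned result buffers. value_dim is a made-up size;
// in practice it comes from the table accessor's select size.
//
//   const size_t value_dim = 8;
//   std::vector<uint64_t> keys = {10001, 10002, 10003};
//   std::vector<std::vector<float>> bufs(keys.size(),
//                                        std::vector<float>(value_dim));
//   std::vector<float*> select_values;
//   for (auto& b : bufs) select_values.push_back(b.data());
//   auto fut = client->PullSparse(select_values.data(), /*table_id=*/1,
//                                 keys.data(), keys.size(),
//                                 /*is_training=*/true);
//   fut.wait();  // keys/buffers must not be reused before the future completes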
virtual std::future<int32_t> PullSparseParam(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num,
bool is_training) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual ::std::future<int32_t> PullSparsePtr(char **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
// make sure all accumulated pending requests are actually sent out
virtual std::future<int32_t> Flush() = 0;
// graceful server shutdown
virtual std::future<int32_t> StopServer() = 0;
// server profiler
virtual std::future<int32_t> StartProfiler() = 0;
virtual std::future<int32_t> StopProfiler() = 0;
virtual std::future<int32_t> Barrier(size_t table_id,
uint32_t barrier_type) = 0;
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx) = 0;
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t *total_send_data,
void *done) = 0;
// recv table from server and save it in LodTensor
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path) = 0;
virtual void FinalizeWorker() = 0;
// client-to-client message sending
virtual std::future<int32_t> SendClient2ClientMsg(int msg_type,
int to_client_id,
const std::string &msg) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
// client2client message handling: std::function<int32_t(int, int, const std::string&)>
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int RegisteClient2ClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
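// Illustrative sketch (assumption, not in the original header): registering a
// client2client handler; msg_type 0 is an arbitrary example value.
//
//   client->RegisteClient2ClientMsgHandler(
//       0, [](int msg_type, int from_client_id, const std::string& msg) {
//         VLOG(3) << "msg " << msg_type << " from " << from_client_id
//                 << ": " << msg;
//         return 0;
//       });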
virtual int HandleClient2ClientMsg(int msg_type,
int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
LOG(WARNING) << "unknown client2client_msg type:" << msg_type;
return -1;
}
return itr->second(msg_type, from_client_id, msg);
}
virtual ValueAccessor *GetTableAccessor(size_t table_id) {
auto itr = _table_accessors.find(table_id);
if (itr == _table_accessors.end()) {
return NULL;
}
return itr->second.get();
}
virtual size_t GetServerNums() = 0;
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) = 0;
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) = 0;
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id,
const uint64_t *keys,
const float **update_values,
uint32_t num,
void *done,
int pserver_idx) = 0;
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) = 0;
virtual std::future<int32_t> PushSparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) = 0;
// for save cache
virtual std::future<int32_t> CacheShuffle(
uint32_t table_id,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables,
const std::string &path,
const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> GetCacheThreshold(
uint32_t table_id,
double &cache_threshold) { // NOLINT
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> Revert() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CheckSavePrePatchDone() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
protected:
virtual int32_t Initialize() = 0;
PSParameter _config;
std::map<uint64_t, std::vector<paddle::distributed::Region>>
_dense_pull_regions;
std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; // handles client2client messages
public:
size_t _client_id;
PSEnvironment *_env;
};
template <class T>
class AsyncRequestTask {
public:
AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
: _table_id(table_id),
_timer(timer),
_promise(std::make_shared<std::promise<int32_t>>()) {
_data = std::move(data);
}
AsyncRequestTask(AsyncRequestTask &data) // NOLINT
: _table_id(data.table_id()),
_timer(data.timer()),
_promise(data.promise()) {
_data = std::move(data.data());
}
~AsyncRequestTask() {}
inline T &data() { return _data; }
inline size_t table_id() { return _table_id; }
inline std::shared_ptr<CostTimer> &timer() { return _timer; }
inline std::future<int32_t> get_future() { return _promise->get_future(); }
inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }
private:
T _data;
size_t _table_id;
std::shared_ptr<CostTimer> _timer;
std::shared_ptr<std::promise<int32_t>> _promise;
};
REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
public:
static PSClient *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
#include "paddle/fluid/distributed/ps/table/table.h"
//#define pslib_debug_dense_compress
namespace paddle {
namespace distributed {
int32_t PsLocalClient::Initialize() {
const auto& downpour_param = _config.server_param().downpour_server_param();
TableManager::Instance().Initialize();
for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto* table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
table->SetShard(0, 1);
table->Initialize(downpour_param.downpour_table_param(i),
_config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
return 0;
}
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
const std::string threshold) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
Load(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = GetTable(table_id);
table_ptr->Load(epoch, mode);
return done();
}
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
const std::string& mode) {
// TODO
for (auto& it : _table_map) {
Save(it.first, epoch, mode);
}
return done();
}
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) {
// TODO
auto* table_ptr = GetTable(table_id);
table_ptr->Flush();
table_ptr->Save(epoch, mode);
return done();
}
::std::future<int32_t> PsLocalClient::Clear() {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
// TODO
return done();
}
::std::future<int32_t> PsLocalClient::Flush() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::StopServer() {
// no need
return done();
}
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
uint32_t num_per_shard =
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);
std::vector<float> region_buffer;
region_buffer.resize(num_per_shard);
TableContext table_context;
table_context.value_type = Dense;
table_context.pull_context.values = region_buffer.data();
table_context.num = region_buffer.size();
table_ptr->Pull(table_context);
// table_ptr->PullDense(region_buffer.data(), region_buffer.size());
size_t region_idx = 0;
size_t region_data_idx = 0;
size_t shard_data_size = num_per_shard;
size_t shard_buffer_remain = shard_data_size * sizeof(float);
PADDLE_ENFORCE_EQ(
shard_buffer_remain,
region_buffer.size() * sizeof(float),
platform::errors::PreconditionNotMet("pull dense size error."));
size_t index = 0;
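// Scatter the flat region_buffer back into the caller's regions: walk the
// regions in order, copying bytes until either the pulled shard buffer is
// exhausted or the current region is full, then move on to the next region.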
while (shard_buffer_remain > 0 && region_idx < region_num) {
auto& region = regions[region_idx];
if (region.size - region_data_idx >= shard_buffer_remain) {
memcpy((void*)(region.data + region_data_idx),
(uint8_t*)(void*)(region_buffer.data()) + index,
shard_buffer_remain);
region_data_idx += shard_buffer_remain;
shard_buffer_remain = 0;
} else if (region.size - region_data_idx == 0) {
++region_idx;
region_data_idx = 0;
} else {
memcpy((void*)(region.data + region_data_idx),
(uint8_t*)(void*)(region_buffer.data()) + index,
region.size - region_data_idx);
shard_buffer_remain -= (region.size - region_data_idx);
index += (region.size - region_data_idx);
++region_idx;
region_data_idx = 0;
}
}
return done();
}
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
0);
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.push_context.is_param = true;
table_context.num = region_buffer.size();
table_ptr->Push(table_context);
// table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
return done();
}
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback) {
VLOG(1) << "wxx push_dense_raw_gradient";
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = total_send_data;
table_context.num = total_send_data_size;
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr->Push(table_context);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
size_t region_num,
size_t table_id) {
auto* accessor = GetTableAccessor(table_id);
auto* table_ptr = GetTable(table_id);
std::vector<float> region_buffer;
region_buffer.resize(
DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
size_t data_size = region_buffer.size();
for (size_t i = 0, offset = 0; i < region_num; ++i) {
uint32_t data_num = regions[i].size / sizeof(float);
PADDLE_ENFORCE_LE(
offset + data_num,
data_size,
platform::errors::PreconditionNotMet(
"invalid dense size, cur pos[%d] data_num[%d] size[%d]",
offset,
data_num,
data_size));
memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
offset += data_num;
}
TableContext table_context;
table_context.value_type = Dense;
table_context.push_context.values = region_buffer.data();
table_context.num = region_buffer.size();
// table_ptr->PushDense(total_send_data, total_send_data_size);
table_ptr->Push(table_context);
return done();
}
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
// size_t table_id,
// const uint64_t* keys,
// size_t num) {
// // FIXME
// // auto timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// // auto local_timer =
// // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// // split the keys into per-shard requests and record the original value pointers they map to
// auto* accessor = GetTableAccessor(table_id);
// auto* table_ptr = GetTable(table_id);
// size_t value_size = accessor->select_size();
//
// // table_ptr->PullSparse(keys, num);
// std::vector<float> res_data;
// res_data.resize(num * value_size / sizeof(float));
// table_ptr->PullSparse(res_data.data(), keys, num);
// // memcpy(select_values[0], res_data->data(), res_data->size() *
// // sizeof(float));
// size_t offset = 0;
// for (int i = 0; i < num; ++i) {
// memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
// offset += value_size;
// }
//
// // return fut;
// return done();
//}
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num) {
// FIXME
// auto timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
// auto local_timer =
// std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
// split the keys into per-shard requests and record the original value pointers they map to
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.pull_context.keys = keys;
table_context.pull_context.ptr_values = select_values;
table_context.use_ptr = true;
table_context.num = num;
// table_ptr->PullSparsePtr(select_values, keys, num);
table_ptr->Pull(table_context);
return done();
}
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback) {
PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
delete closure;
return done();
}
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num) {
auto* table_ptr = GetTable(table_id);
TableContext table_context;
table_context.value_type = Sparse;
table_context.push_context.keys = keys;
table_context.push_context.ptr_values = update_values;
table_context.num = num;
table_context.use_ptr = true;
// table_ptr->PushSparse(keys, update_values, num);
table_ptr->Push(table_context);
return done();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/ps/service/ps_client.h"
namespace paddle {
namespace distributed {
class Table;
class PsLocalClient : public PSClient {
public:
PsLocalClient() {}
virtual ~PsLocalClient() { _running = false; }
virtual int32_t CreateClient2ClientConnection(int pslib_timeout_ms,
int pslib_connect_timeout_ms,
int max_retry) {
return 0;
}
virtual ::std::future<int32_t> Shrink(uint32_t table_id,
const std::string threshold) override;
virtual ::std::future<int32_t> Load(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> Load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> Save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> Save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> Clear() override;
virtual ::std::future<int32_t> Clear(uint32_t table_id) override;
virtual ::std::future<int32_t> StopServer() override;
virtual void FinalizeWorker() override {}
virtual ::std::future<int32_t> PullDense(Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> PushDense(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> PushDenseParam(const Region* regions,
size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> PullSparse(float** select_values,
size_t table_id,
const uint64_t* keys,
size_t num,
bool is_training) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual ::std::future<int32_t> PullSparsePtr(char** select_values,
size_t table_id,
const uint64_t* keys,
size_t num);
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual ::std::future<int32_t> PushSparse(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num);
virtual ::std::future<int32_t> Flush();
// server profiler
virtual std::future<int32_t> StartProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
};
virtual std::future<int32_t> StopProfiler() {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual std::future<int32_t> PullGeoParam(size_t table_id,
std::vector<float>* values,
std::vector<uint64_t>* keys,
int pserver_idx) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual std::future<int32_t> PushGlobalStep(int table_id,
int64_t* total_send_data,
void* done) {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
// recv table from server and save it in LodTensor
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string& path) {
return 0;
}
virtual ::std::future<int32_t> SendClient2ClientMsg(
int msg_type, int to_client_id, const std::string& msg) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual size_t GetServerNums() { return 1; }
virtual std::future<int32_t> PushDenseRawGradient(int table_id,
float* total_send_data,
size_t total_send_data_size,
void* callback) override;
virtual std::future<int32_t> PushSparseRawGradient(
size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* callback) override;
virtual std::future<int32_t> PushSparseRawGradientPartial(
size_t table_id,
const uint64_t* keys,
const float** update_values,
uint32_t num,
void* done,
int pserver_idx) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
virtual std::future<int32_t> PushSparseParam(size_t table_id,
const uint64_t* keys,
const float** update_values,
size_t num,
void* done) override {
std::promise<int32_t> prom;
std::future<int32_t> fut = prom.get_future();
prom.set_value(0);
return fut;
}
private:
virtual int32_t Initialize() override;
std::future<int32_t> done() {
std::shared_ptr<std::promise<int32_t>> prom =
std::make_shared<std::promise<int32_t>>();
std::future<int32_t> fut = prom->get_future();
prom->set_value(0);
return fut;
}
inline uint32_t DenseDimPerShard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
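// Example (added for illustration): dense_dim_total = 10, shard_num = 3
// gives 10 / 3 + 1 = 4 dims per shard, so the shards together always cover
// the full dense dimension.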
inline std::unordered_map<uint32_t, std::shared_ptr<Table>>* GetTable() {
return &_table_map;
}
inline Table* GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
}
LOG(ERROR) << "table not found " << table_id;
return NULL;
}
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
bool _running = false;
bool _flushing = false;
private:
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/ps/service/server.h"
namespace paddle {
namespace distributed {
class PsLocalServer : public PSServer {
public:
PsLocalServer() {}
virtual ~PsLocalServer() {}
virtual uint64_t Start() { return 0; }
virtual uint64_t Start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t Stop() { return 0; }
virtual int32_t Configure(
const PSParameter &config,
PSEnvironment &env,
size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
return 0;
}
private:
virtual int32_t Initialize() { return 0; }
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h"
#include <thread> // NOLINT
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace distributed {
std::vector<std::string> GraphPyService::split(std::string& str,
const char pattern) {
std::vector<std::string> res;
std::stringstream input(str);
std::string temp;
while (std::getline(input, temp, pattern)) {
res.push_back(temp);
}
return res;
}
void GraphPyService::add_table_feat_conf(std::string table_name,
std::string feat_name,
std::string feat_dtype,
int feat_shape) {
if (feature_to_id.find(table_name) != feature_to_id.end()) {
int idx = feature_to_id[table_name];
VLOG(0) << "for table name" << table_name << " idx = " << idx;
if (table_feat_mapping[idx].find(feat_name) ==
table_feat_mapping[idx].end()) {
VLOG(0) << "for table name not found,make a new one";
int res = (int)table_feat_mapping[idx].size();
table_feat_mapping[idx][feat_name] = res;
VLOG(0) << "seq id = " << table_feat_mapping[idx][feat_name];
}
int feat_idx = table_feat_mapping[idx][feat_name];
VLOG(0) << "table_name " << table_name << " mapping id " << idx;
VLOG(0) << " feat name " << feat_name << " feat id" << feat_idx;
if (static_cast<size_t>(feat_idx) < table_feat_conf_feat_name[idx].size()) {
// override
table_feat_conf_feat_name[idx][feat_idx] = feat_name;
table_feat_conf_feat_dtype[idx][feat_idx] = feat_dtype;
table_feat_conf_feat_shape[idx][feat_idx] = feat_shape;
} else {
// new
table_feat_conf_feat_name[idx].push_back(feat_name);
table_feat_conf_feat_dtype[idx].push_back(feat_dtype);
table_feat_conf_feat_shape[idx].push_back(feat_shape);
}
}
VLOG(0) << "add conf over";
}
void add_graph_node(std::string name,
std::vector<int64_t> node_ids,
std::vector<bool> weight_list) {}
void remove_graph_node(std::string name, std::vector<int64_t> node_ids) {}
void GraphPyService::set_up(std::string ips_str,
int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types) {
set_shard_num(shard_num);
set_num_node_types(node_types.size());
/*
int num_node_types;
std::unordered_map<std::string, uint32_t> edge_idx, feature_idx;
std::vector<std::unordered_map<std::string,uint32_t>> table_feat_mapping;
std::vector<std::vector<std::string>> table_feat_conf_feat_name;
std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
std::vector<std::vector<int32_t>> table_feat_conf_feat_shape;
*/
id_to_edge = edge_types;
for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
int res = (int)edge_to_id.size();
edge_to_id[edge_types[table_id]] = res;
}
id_to_feature = node_types;
for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
int res = (int)feature_to_id.size();
feature_to_id[node_types[table_id]] = res;
}
table_feat_mapping.resize(node_types.size());
this->table_feat_conf_feat_name.resize(node_types.size());
this->table_feat_conf_feat_dtype.resize(node_types.size());
this->table_feat_conf_feat_shape.resize(node_types.size());
std::istringstream stream(ips_str);
std::string ip;
server_size = 0;
std::vector<std::string> ips_list = split(ips_str, ';');
int index = 0;
VLOG(0) << "start to build server";
for (auto ips : ips_list) {
auto ip_and_port = split(ips, ':');
server_list.push_back(ip_and_port[0]);
port_list.push_back(ip_and_port[1]);
uint32_t port = stoul(ip_and_port[1]);
auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
host_sign_list.push_back(ph_host.SerializeToString());
index++;
}
VLOG(0) << "build server done";
}
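// Note (added for clarity): set_up expects ips_str as a ';'-separated list of
// "ip:port" endpoints, one per graph server, e.g.
//
//   "127.0.0.1:8245;127.0.0.1:8246"
//
// which split() above turns into server_list = {"127.0.0.1", "127.0.0.1"} and
// port_list = {"8245", "8246"}; the addresses here are made up for
// illustration.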
void GraphPyClient::start_client() {
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(&host_sign_list, servers_);
worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::Create(worker_proto));
worker_ptr->Configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
void GraphPyServer::start_server(bool block) {
std::string ip = server_list[rank];
uint32_t port = std::stoul(port_list[rank]);
::paddle::distributed::PSParameter server_proto = this->GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::Create(server_proto));
VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->Configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->Start(ip, port);
pserver_ptr->build_peer2peer_connection(rank);
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
std::mutex mutex_;
std::unique_lock<std::mutex> lock(mutex_);
cv_->wait(lock);
}
}
::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
// Generate server proto desc
::paddle::distributed::PSParameter server_fleet_desc;
::paddle::distributed::ServerParameter* server_proto =
server_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
// for (auto& tuple : this->table_id_map) {
// VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
// std::vector<std::string> feat_name;
// std::vector<std::string> feat_dtype;
// std::vector<int32_t> feat_shape;
// for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
// if (tuple.first == table_feat_conf_table_name[i]) {
// feat_name.push_back(table_feat_conf_feat_name[i]);
// feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
// feat_shape.push_back(table_feat_conf_feat_shape[i]);
// }
// }
// std::string table_type;
// if (tuple.second < this->num_node_types) {
// table_type = "node";
// } else {
// table_type = "edge";
// }
GetDownpourSparseTableProto(sparse_table_proto);
//}
return server_fleet_desc;
}
::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
::paddle::distributed::PSParameter worker_fleet_desc;
::paddle::distributed::WorkerParameter* worker_proto =
worker_fleet_desc.mutable_worker_param();
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();
// for (auto& tuple : this->table_id_map) {
// VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* worker_sparse_table_proto =
downpour_worker_proto->add_downpour_table_param();
// std::vector<std::string> feat_name;
// std::vector<std::string> feat_dtype;
// std::vector<int32_t> feat_shape;
// for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
// if (tuple.first == table_feat_conf_table_name[i]) {
// feat_name.push_back(table_feat_conf_feat_name[i]);
// feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
// feat_shape.push_back(table_feat_conf_feat_shape[i]);
// }
// }
// std::string table_type;
// if (tuple.second < this->num_node_types) {
// table_type = "node";
// } else {
// table_type = "edge";
// }
GetDownpourSparseTableProto(worker_sparse_table_proto);
//}
::paddle::distributed::ServerParameter* server_proto =
worker_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
// for (auto& tuple : this->table_id_map) {
// VLOG(0) << " make a new table " << tuple.second;
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
// std::vector<std::string> feat_name;
// std::vector<std::string> feat_dtype;
// std::vector<int32_t> feat_shape;
// for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
// if (tuple.first == table_feat_conf_table_name[i]) {
// feat_name.push_back(table_feat_conf_feat_name[i]);
// feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
// feat_shape.push_back(table_feat_conf_feat_shape[i]);
// }
// }
// std::string table_type;
// if (tuple.second < this->num_node_types) {
// table_type = "node";
// } else {
// table_type = "edge";
// }
GetDownpourSparseTableProto(sparse_table_proto);
//}
return worker_fleet_desc;
}
void GraphPyClient::load_edge_file(std::string name,
std::string filepath,
bool reverse) {
// 'e' means load edge
std::string params = "e";
if (reverse) {
// 'e<' means load edges from $2 to $1
params += "<" + name;
} else {
// 'e>' means load edges from $1 to $2
params += ">" + name;
}
if (edge_to_id.find(name) != edge_to_id.end()) {
auto status = get_ps_client()->Load(0, std::string(filepath), params);
status.wait();
}
// if (this->table_id_map.count(name)) {
// VLOG(0) << "loadding data with type " << name << " from " << filepath;
// uint32_t table_id = this->table_id_map[name];
// auto status =
// get_ps_client()->Load(table_id, std::string(filepath), params);
// status.wait();
// }
}
void GraphPyClient::clear_nodes(std::string name) {
if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->clear_nodes(0, 0, idx);
status.wait();
} else if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->clear_nodes(0, 1, idx);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = get_ps_client()->clear_nodes(table_id);
// status.wait();
// }
}
void GraphPyClient::add_graph_node(std::string name,
std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list) {
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status =
// get_ps_client()->add_graph_node(table_id, node_ids, weight_list);
// status.wait();
// }
if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status =
get_ps_client()->add_graph_node(0, idx, node_ids, weight_list);
status.wait();
}
}
void GraphPyClient::remove_graph_node(std::string name,
std::vector<int64_t>& node_ids) {
if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->remove_graph_node(0, idx, node_ids);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = get_ps_client()->remove_graph_node(table_id, node_ids);
// status.wait();
// }
}
void GraphPyClient::load_node_file(std::string name, std::string filepath) {
// 'n' means load nodes and 'node_type' follows
std::string params = "n" + name;
if (feature_to_id.find(name) != feature_to_id.end()) {
auto status = get_ps_client()->Load(0, std::string(filepath), params);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status =
// get_ps_client()->Load(table_id, std::string(filepath), params);
// status.wait();
// }
}
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<int64_t> node_ids,
int sample_size,
bool return_weight,
bool return_edges) {
std::vector<std::vector<int64_t>> v;
std::vector<std::vector<float>> v1;
if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->batch_sample_neighbors(
0, idx, node_ids, sample_size, v, v1, return_weight);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = worker_ptr->batch_sample_neighbors(
// table_id, node_ids, sample_size, v, v1, return_weight);
// status.wait();
// }
// res.first[0]: neighbors (nodes)
// res.first[1]: slice index
// res.first[2]: src nodes
// res.second: edges weight
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>> res;
res.first.push_back({});
res.first.push_back({});
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
// res.first[0].push_back(v[i][j].first);
res.first[0].push_back(v[i][j]);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v1[i][j]);
}
if (i == v.size() - 1) break;
if (i == 0) {
res.first[1].push_back(v[i].size());
} else {
res.first[1].push_back(v[i].size() + res.first[1].back());
}
}
return res;
}
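// Illustrative note (not part of the original source): a sketch, with
// hypothetical names, of how a caller can unpack the returned pair when both
// return_weight and return_edges are true:
//
//   auto res = client.batch_sample_neighbors("u2i", ids, 10, true, true);
//   const auto &neighbors = res.first[0];  // flat list of sampled neighbors
//   const auto &slice_idx = res.first[1];  // prefix sums splitting the list per source node
//   const auto &src_nodes = res.first[2];  // source node of each sampled neighbor
//   const auto &weights   = res.second;    // weight of each sampled edge
//
// Neighbors of ids[0] occupy neighbors[0 .. slice_idx[0]), neighbors of ids[1]
// occupy neighbors[slice_idx[0] .. slice_idx[1]), and so on.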
std::vector<int64_t> GraphPyClient::random_sample_nodes(std::string name,
int server_index,
int sample_size) {
std::vector<int64_t> v;
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->random_sample_nodes(
0, 1, idx, server_index, sample_size, v);
status.wait();
} else if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->random_sample_nodes(
0, 0, idx, server_index, sample_size, v);
status.wait();
}
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status =
// worker_ptr->random_sample_nodes(table_id, server_index, sample_size,
// v);
// status.wait();
// }
return v;
}
// (name, dtype, ndarray)
std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
std::string name,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names) {
std::vector<std::vector<std::string>> v(
feature_names.size(), std::vector<std::string>(node_ids.size()));
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status =
get_ps_client()->get_node_feat(0, idx, node_ids, feature_names, v);
status.wait();
}
// if (this->table_id_map.count(node_type)) {
// uint32_t table_id = this->table_id_map[node_type];
// auto status =
// worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
// status.wait();
// }
return v;
}
void GraphPyClient::set_node_feat(
std::string name,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features) {
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->set_node_feat(
0, idx, node_ids, feature_names, features);
status.wait();
}
// if (this->table_id_map.count(node_type)) {
// uint32_t table_id = this->table_id_map[node_type];
// auto status =
// worker_ptr->set_node_feat(table_id, node_ids, feature_names,
// features);
// status.wait();
// }
return;
}
std::vector<FeatureNode> GraphPyClient::pull_graph_list(
std::string name, int server_index, int start, int size, int step) {
std::vector<FeatureNode> res;
// if (this->table_id_map.count(name)) {
// uint32_t table_id = this->table_id_map[name];
// auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
// size, step, res);
// status.wait();
// }
if (feature_to_id.find(name) != feature_to_id.end()) {
int idx = feature_to_id[name];
auto status = get_ps_client()->pull_graph_list(
0, 1, idx, server_index, start, size, step, res);
status.wait();
} else if (edge_to_id.find(name) != edge_to_id.end()) {
int idx = edge_to_id[name];
auto status = get_ps_client()->pull_graph_list(
0, 0, idx, server_index, start, size, step, res);
status.wait();
}
return res;
}
void GraphPyClient::StopServer() {
VLOG(0) << "going to stop server";
std::unique_lock<std::mutex> lock(mutex_);
if (stoped_) return;
auto status = this->worker_ptr->StopServer();
if (status.get() == 0) stoped_ = true;
}
void GraphPyClient::FinalizeWorker() { this->worker_ptr->FinalizeWorker(); }
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace distributed {
class GraphPyService {
protected:
std::vector<std::string> server_list, port_list, host_sign_list;
int server_size, shard_num;
int num_node_types;
std::unordered_map<std::string, int> edge_to_id, feature_to_id;
std::vector<std::string> id_to_feature, id_to_edge;
std::vector<std::unordered_map<std::string, int>> table_feat_mapping;
std::vector<std::vector<std::string>> table_feat_conf_feat_name;
std::vector<std::vector<std::string>> table_feat_conf_feat_dtype;
std::vector<std::vector<int>> table_feat_conf_feat_shape;
public:
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
void GetDownpourSparseTableProto(
::paddle::distributed::TableParameter* sparse_table_proto) {
sparse_table_proto->set_table_id(0);
sparse_table_proto->set_table_class("GraphTable");
sparse_table_proto->set_shard_num(shard_num);
sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor();
// ::paddle::distributed::CommonAccessorParameter* common_proto =
// sparse_table_proto->mutable_common();
::paddle::distributed::GraphParameter* graph_proto =
sparse_table_proto->mutable_graph_parameter();
// ::paddle::distributed::GraphFeature* graph_feature =
// graph_proto->mutable_graph_feature();
graph_proto->set_task_pool_size(24);
graph_proto->set_table_name("cpu_graph_table");
graph_proto->set_use_cache(false);
for (size_t i = 0; i < id_to_edge.size(); i++)
graph_proto->add_edge_types(id_to_edge[i]);
for (size_t i = 0; i < id_to_feature.size(); i++) {
graph_proto->add_node_types(id_to_feature[i]);
auto feat_node = id_to_feature[i];
::paddle::distributed::GraphFeature* g_f =
graph_proto->add_graph_feature();
for (size_t x = 0; x < table_feat_conf_feat_name[i].size(); x++) {
g_f->add_name(table_feat_conf_feat_name[i][x]);
g_f->add_dtype(table_feat_conf_feat_dtype[i][x]);
g_f->add_shape(table_feat_conf_feat_shape[i][x]);
}
}
// Set GraphTable Parameter
// common_proto->set_table_name(table_name);
// common_proto->set_name(table_type);
// for (size_t i = 0; i < feat_name.size(); i++) {
// common_proto->add_params(feat_dtype[i]);
// common_proto->add_dims(feat_shape[i]);
// common_proto->add_attributes(feat_name[i]);
// }
// for (size_t i = 0; i < feat_name.size(); i++) {
// graph_feature->add_dtype(feat_dtype[i]);
// graph_feature->add_shape(feat_shape[i]);
// graph_feature->add_name(feat_name[i]);
// }
accessor_proto->set_accessor_class("CommMergeAccessor");
}
void set_server_size(int server_size) { this->server_size = server_size; }
void set_num_node_types(int num_node_types) {
this->num_node_types = num_node_types;
}
  int get_server_size() { return server_size; }
std::vector<std::string> split(std::string& str, const char pattern);
void set_up(std::string ips_str,
int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types);
void add_table_feat_conf(std::string node_type,
std::string feat_name,
std::string feat_dtype,
int32_t feat_shape);
};
class GraphPyServer : public GraphPyService {
public:
GraphPyServer() {}
void set_up(std::string ips_str,
int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types,
int rank) {
set_rank(rank);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
int GetRank() { return rank; }
void set_rank(int rank) { this->rank = rank; }
void start_server(bool block = true);
::paddle::distributed::PSParameter GetServerProto();
std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
return pserver_ptr;
}
protected:
int rank;
std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr;
std::thread* server_thread;
};
class GraphPyClient : public GraphPyService {
public:
void set_up(std::string ips_str,
int shard_num,
std::vector<std::string> node_types,
std::vector<std::string> edge_types,
int client_id) {
set_client_id(client_id);
GraphPyService::set_up(ips_str, shard_num, node_types, edge_types);
}
std::shared_ptr<paddle::distributed::GraphBrpcClient> get_ps_client() {
return worker_ptr;
}
void bind_local_server(int local_channel_index, GraphPyServer& server) {
worker_ptr->set_local_channel(local_channel_index);
worker_ptr->set_local_graph_service(
(paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service());
}
void StopServer();
void FinalizeWorker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
void clear_nodes(std::string name);
void add_graph_node(std::string name,
std::vector<int64_t>& node_ids,
std::vector<bool>& weight_list);
void remove_graph_node(std::string name, std::vector<int64_t>& node_ids);
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::pair<std::vector<std::vector<int64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name,
std::vector<int64_t> node_ids,
int sample_size,
bool return_weight,
bool return_edges);
std::vector<int64_t> random_sample_nodes(std::string name,
int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
std::string name,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names);
void set_node_feat(std::string node_type,
std::vector<int64_t> node_ids,
std::vector<std::string> feature_names,
const std::vector<std::vector<std::string>> features);
std::vector<FeatureNode> pull_graph_list(
std::string name, int server_index, int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();
protected:
mutable std::mutex mutex_;
int client_id;
std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
std::thread* client_thread;
bool stoped_ = false;
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include <fcntl.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#include "paddle/fluid/string/string_helper.h"
using namespace std; // NOLINT
namespace paddle {
namespace distributed {
paddle::distributed::PSParameter load_from_prototxt(
const std::string& filename) {
paddle::distributed::PSParameter param;
int file_descriptor = open(filename.c_str(), O_RDONLY);
if (file_descriptor == -1) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
google::protobuf::io::FileInputStream fileInput(file_descriptor);
if (!google::protobuf::TextFormat::Parse(&fileInput, &param)) {
VLOG(3) << "FATAL: fail to parse " << filename;
exit(-1);
}
close(file_descriptor);
return param;
}
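// Illustrative note (not part of the original source): a minimal usage sketch,
// assuming "ps_param.prototxt" is a text-format PSParameter file on disk:
//
//   paddle::distributed::PSParameter param =
//       paddle::distributed::load_from_prototxt("ps_param.prototxt");
//   // param.server_param() / param.worker_param() can then be inspected, or
//   // re-serialized with google::protobuf::TextFormat::PrintToString() and
//   // handed to PSCore::InitServer / PSCore::InitWorker as dist_desc.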
void PSCore::InitGFlag(const std::string& gflags) {
VLOG(3) << "Init With Gflags:" << gflags;
std::vector<std::string> flags = paddle::string::split_string(gflags);
if (flags.size() < 1) {
flags.push_back("-max_body_size=314217728");
flags.push_back("-socket_max_unwritten_bytes=2048000000");
flags.push_back("-max_connection_pool_size=1950");
}
auto it = flags.begin();
flags.insert(it, "exe default");
char* flags_ptr[flags.size()];
for (size_t i = 0; i < flags.size(); ++i) {
flags_ptr[i] = (char*)(flags[i].c_str()); // NOLINT
}
int params_cnt = flags.size();
char** params_ptr = &(flags_ptr[0]);
::GFLAGS_NAMESPACE::ParseCommandLineFlags(&params_cnt, &params_ptr, true);
}
int PSCore::InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num,
int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(host_sign_list, node_num);
_ps_env.SetTrainers(trainers);
int ret = 0;
_server_ptr = std::shared_ptr<paddle::distributed::PSServer>(
paddle::distributed::PSServerFactory::Create(_ps_param));
ret = _server_ptr->Configure(_ps_param, _ps_env, index, server_sub_program);
CHECK(ret == 0) << "failed to configure server";
return ret;
}
int PSCore::InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>& regions,
const std::vector<std::string>* host_sign_list,
int node_num,
int index) {
google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param);
InitGFlag(_ps_param.init_gflags());
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.SetPsServers(host_sign_list, node_num);
int ret = 0;
VLOG(1) << "PSCore::InitWorker";
auto* communicator = Communicator::GetInstance();
ret = communicator->GetPsClient()->Configure(
_ps_param, regions, _ps_env, index);
communicator->Start();
return ret;
}
std::vector<uint64_t> PSCore::GetClientInfo() {
return _ps_env.GetClientInfo();
}
int PSCore::CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
int ret = _worker_ptr->CreateClient2ClientConnection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
return ret;
}
uint64_t PSCore::RunServer(const std::string& ip, uint32_t port) {
return _server_ptr->Start(ip, port);
}
int PSCore::FinalizeWorker() {
_worker_ptr->FinalizeWorker();
return 0;
}
int PSCore::StopServer() {
auto stop_status = _worker_ptr->StopServer();
stop_status.wait();
return 0;
}
paddle::distributed::PSParameter* PSCore::GetParam() { return &_ps_param; }
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/ps/service/ps_client.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/service/server.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
namespace paddle {
namespace distributed {
class PSClient;
class PSServer;
class PsRequestMessage;
class PsResponseMessage;
class PsService;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
class PSCore {
public:
explicit PSCore() {}
virtual ~PSCore() {}
virtual int InitServer(
const std::string& dist_desc,
const std::vector<std::string>* host_sign_list,
int node_num,
int index,
int trainers,
const std::vector<framework::ProgramDesc>& server_sub_program = {});
virtual int InitWorker(
const std::string& dist_desc,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
regions,
const std::vector<std::string>* host_sign_list,
int node_num,
int index);
virtual uint64_t RunServer(const std::string& ip, uint32_t port);
virtual int StopServer();
virtual int FinalizeWorker();
virtual std::vector<uint64_t> GetClientInfo();
virtual int CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry);
std::shared_ptr<paddle::distributed::PSServer>
_server_ptr; // pointer to server
std::shared_ptr<paddle::distributed::PSClient>
_worker_ptr; // pointer to worker
virtual paddle::distributed::PSParameter* GetParam();
private:
void InitGFlag(const std::string& gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
PS_PULL_GEO_PARAM = 24;
PS_BARRIER = 25;
PS_PUSH_SPARSE_PARAM = 26;
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE_NEIGHBORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
PEER_ROLE_IS_WORKER = 41;
PEER_ROLE_IS_SWITCH = 42;
PS_SAVE_WITH_SCOPE = 43;
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
PS_REVERT = 47;
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
PUSH_FL_CLIENT_INFO_SYNC = 200;
PUSH_FL_STRATEGY = 201;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message PsResponseMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional bytes data = 3;
};
message CoordinatorReqMessage {
required uint32 cmd_id = 1;
optional int32 client_id = 2;
optional string str_params = 3;
};
message CoordinatorResMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional string str_params = 3;
};
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
optional string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
optional VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
optional Type data_type = 3;
repeated int64 dims = 4;
// lod details:
optional int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
optional int64 slr_height = 7;
// tensor data
optional bytes data = 8;
}
// for SendAndRecv RPC method
message MultiVariableMessage {
// message flags
required string message_name = 1;
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
optional bytes data = 5;
repeated int64 vars_len = 6;
optional int32 group_id = 7;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc FLService(CoordinatorReqMessage) returns (CoordinatorResMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage);
rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage);
};
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_local_server.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PSServer, PsLocalServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
if (!config.has_downpour_server_param()) {
LOG(ERROR) << "miss downpour_server_param in ServerParameter";
return NULL;
}
if (!config.downpour_server_param().has_service_param()) {
LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
return NULL;
}
if (!config.downpour_server_param().service_param().has_server_class()) {
LOG(ERROR) << "miss server_class in "
"ServerParameter.downpour_server_param.service_param";
return NULL;
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
return NULL;
}
TableManager::Instance().Initialize();
return server;
}
int32_t PSServer::Configure(
const PSParameter &config,
PSEnvironment &env,
size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
uint32_t global_step_table = UINT32_MAX;
for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
if (downpour_param.downpour_table_param(i).table_class() ==
"GlobalStepTable") {
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->SetShard(_rank, shard_num);
table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->SetTableMap(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->SetTableMap(&_table_map);
}
return Initialize();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace google {
namespace protobuf {
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
class PSEnvironment;
} // namespace distributed
} // namespace paddle
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
class PSServer {
public:
PSServer() {}
virtual ~PSServer() {}
PSServer(PSServer &&) = delete;
PSServer(const PSServer &) = delete;
virtual int32_t Configure(
const PSParameter &config,
PSEnvironment &env,
size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});
virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
virtual int32_t Stop() = 0;
inline size_t Rank() const { return _rank; }
inline PSEnvironment *Environment() { return _environment; }
inline const ServerParameter *Config() const { return &_config; }
inline Table *GetTable(size_t table_id) {
auto itr = _table_map.find(table_id);
if (itr != _table_map.end()) {
return itr->second.get();
}
return NULL;
}
inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
return &_table_map;
}
// for cache
virtual int32_t StartS2S() { return 0; }
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int RegistePServer2PServerMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int HandlePServer2PServerMsg(int msg_type,
int from_pserver_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
if (msg_type == 101) {
return ReceiveFromPServer(msg_type, from_pserver_id, msg);
} else {
LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
return -1;
}
}
return itr->second(msg_type, from_pserver_id, msg);
}
virtual int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) {
LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
return -1;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t Initialize() = 0;
protected:
size_t _rank;
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected:
std::shared_ptr<framework::Scope> scope_;
platform::Place place_ = platform::CPUPlace();
};
REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
class PServerClosure : public google::protobuf::Closure {
public:
PServerClosure(PServerCallBack callback) : _callback(callback) {}
virtual ~PServerClosure() {}
virtual void set_promise_value(int value) {
for (auto &promise : _promises) {
promise->set_value(value);
}
}
void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
_promises.push_back(promise);
}
protected:
PServerCallBack _callback;
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
class PsBaseService : public PsService {
public:
PsBaseService() : _rank(0), _server(NULL), _config(NULL) {}
virtual ~PsBaseService() {}
virtual size_t GetRank() { return _rank; }
virtual int32_t Configure(PSServer *server) {
_server = server;
_rank = _server->Rank();
_config = _server->Config();
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response,
int err_code,
const char *err_msg) {
response.set_err_msg(err_msg);
response.set_err_code(err_code);
LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg;
}
virtual int32_t Initialize() = 0;
PSServer *GetServer() { return _server; }
protected:
size_t _rank;
PSServer *_server;
const ServerParameter *_config;
};
REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
set(graphDir graph)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(
${graphDir}/graph_edge.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS ${graphDir}/graph_edge.cc)
set_source_files_properties(
${graphDir}/graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
WeightedSampler
SRCS ${graphDir}/graph_weighted_sampler.cc
DEPS graph_edge)
set_source_files_properties(
${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
graph_node
SRCS ${graphDir}/graph_node.cc
DEPS WeightedSampler enforce)
set_source_files_properties(
memory_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/")
include_directories(
${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmct/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
#set(EXTERN_DEP rocksdb)
cc_library(
common_table
SRCS ${TABLE_SRC}
DEPS ${TABLE_DEPS}
${RPC_DEPS}
graph_edge
graph_node
device_context
string_helper
simple_threadpool
xxhash
generator)
set_source_files_properties(
tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
tensor_table
SRCS
DEPS eigen3
ps_framework_proto
executor
scope
device_context
tensor
${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_dymf_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
table
SRCS sparse_sgd_rule.cc
ctr_accessor.cc
ctr_double_accessor.cc
sparse_accessor.cc
ctr_dymf_accessor.cc
tensor_accessor.cc
memory_sparse_table.cc
ssd_sparse_table.cc
memory_sparse_geo_table.cc
table.cc
DEPS ${TABLE_DEPS}
common_table
tensor_table
ps_framework_proto
string_helper
device_context
gflags
glog
fs
afs_wrapper
rocksdb
eigen3)
target_link_libraries(table -fopenmp)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/the_one_ps.pb.h"
namespace paddle {
namespace distributed {
struct Region {
Region() : data(NULL), size(0) {}
Region(char* data, size_t data_num) : data(data), size(data_num) {}
Region(float* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
Region(int16_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 1) {}
Region(int32_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 2) {}
Region(int64_t* data, size_t data_num)
: data(reinterpret_cast<char*>(data)), size(data_num << 3) {}
char* data;
size_t size;
};
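// Illustrative note (not part of the original source): Region stores a raw
// byte view, so the typed constructors convert element counts to byte sizes
// via shifts (int16: <<1, float/int32: <<2, int64: <<3). For example:
//
//   std::vector<float> dense_param(256);
//   paddle::distributed::Region region(dense_param.data(), dense_param.size());
//   // region.size == 256 * sizeof(float) == 1024 bytes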
struct DataConverter {
int param;
std::string converter;
std::string deconverter;
};
struct AccessorInfo {
  // dimension of the value
  size_t dim;
  // size of each value dimension
  size_t size;
  // dimension of the pulled value
  size_t select_dim;
  // total size summed over the pulled value dimensions
  size_t select_size;
  // dimension of the pushed value
  size_t update_dim;
  // size of each pushed value dimension
  size_t update_size;
  // total size of the variable-length mf part of the value;
  // effective for sparse tables
  size_t mf_size;
  // total dimension of the value; effective for dense tables
  size_t fea_dim;
};
class ValueAccessor {
public:
ValueAccessor() {}
virtual ~ValueAccessor() {}
virtual int Configure(const TableAccessorParameter& parameter) {
_config = parameter;
    // initialize the DataConverter structs
if (_config.table_accessor_save_param_size() != 0) {
for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
int param = _config.table_accessor_save_param(i).param();
std::string converter =
_config.table_accessor_save_param(i).converter();
std::string deconverter =
_config.table_accessor_save_param(i).deconverter();
_data_coverter_map[param] = std::make_shared<DataConverter>();
*(_data_coverter_map[param]) = {param, converter, deconverter};
}
}
return 0;
}
virtual int Initialize() = 0;
virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
virtual bool NeedExtendMF(float* value) { return false; }
virtual bool HasMF(size_t size) { return false; }
// converter for save
virtual std::string GetConverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->converter;
}
}
// deconverter for load
virtual std::string GetDeconverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->deconverter;
}
}
  // decide whether this value should be shrunk
virtual bool Shrink(float* value) = 0;
  // decide whether this value is dumped during the save stage;
  // param identifies the save stage, e.g. downpour's xbox vs. batch_model
virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) {}
  // decide whether this value is saved to SSD
virtual bool SaveSSD(float* value) = 0;
//
virtual bool SaveCache(float* value,
int param,
double global_cache_threshold) = 0;
  // when keys are missing, generate random values for them
virtual int32_t Create(float** value, size_t num) = 0;
virtual bool CreateValue(int type, const float* value) { return true; }
  // select from values into select_values
virtual int32_t Select(float** select_values,
const float** values,
size_t num) = 0;
  // merge update_values together
virtual int32_t Merge(float** update_values,
const float** other_update_values,
size_t num) = 0;
  // merge update_values together; it.next decides whether to move to the next key
// virtual int32_t Merge(float** update_values, iterator it);
  // apply update_values onto values
virtual int32_t Update(float** values,
const float** update_values,
size_t num) = 0;
// used to save model, will filter feature
virtual std::string ParseToString(const float* value, int param) = 0;
// parse value from string, used to load model
virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
virtual FsDataConverter Converter(int param) {
FsDataConverter data_convert;
data_convert.converter = this->GetConverter(param);
data_convert.deconverter = this->GetDeconverter(param);
return data_convert;
}
virtual int SetWeight(float** values,
const float** update_values,
size_t num) {
return 0;
}
virtual float GetField(float* value, const std::string& name) { return 0.0; }
#define DEFINE_GET_INDEX(class, field) \
virtual int get_##field##_index() override { return class ::field##_index(); }
protected:
size_t _value_size;
size_t _select_value_size;
size_t _update_value_size;
TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
AccessorInfo _accessor_info;
};
REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_table.h"
namespace paddle {
namespace distributed {
int32_t BarrierTable::Initialize() {
auto trainers = _config.common().trainer_num();
trigger_.store(trainers);
for (int x = 0; x < trainers; ++x) {
trainer_all_.insert(x);
}
VLOG(1) << "BarrierTable init trigger: " << trigger_.load();
return 0;
}
// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_);
if (barrier_type == "2") {
trigger_.fetch_sub(1, std::memory_order::memory_order_relaxed);
VLOG(1) << "trigger sub to : " << trigger_.load();
} else {
trainer_ids_.insert(trainer_id);
VLOG(1) << "barrier type: " << barrier_type
<< " add trainer id: " << trainer_id;
}
if (static_cast<int>(trainer_ids_.size()) < trigger_.load()) {
std::vector<uint32_t> diffs(trainer_all_.size());
auto iter = std::set_difference(trainer_all_.begin(),
trainer_all_.end(),
trainer_ids_.begin(),
trainer_ids_.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
auto diff = to_string<uint32_t>(diffs);
VLOG(1) << "still need trainers: " << diff;
trainer_wait_.wait(lock, [&] { return trainer_ids_.size() == 0; });
} else {
VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) {
auto table = x.second;
table->Pour();
}
VLOG(1) << "barrier table optimize done";
trainer_ids_.clear();
trainer_wait_.notify_all();
}
return 0;
}
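// Descriptive note (added): sketch of the barrier flow implemented above.
// Each trainer calls Barrier(trainer_id, "0") or Barrier(trainer_id, "1") and
// blocks until trainer_ids_ reaches the trigger count; the last arrival calls
// Pour() on every registered table and wakes the waiters. A trainer that has
// finished sends barrier_type "2", which lowers the trigger count instead of
// registering an id.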
int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map;
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(graph_load_in_parallel);
namespace paddle {
namespace distributed {
#ifdef PADDLE_WITH_HETERPS
int32_t GraphTable::Load_to_ssd(const std::string &path,
const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
bool reverse_edge = (param[1] == '<');
std::string edge_type = param.substr(2);
return this->load_edges_to_ssd(path, reverse_edge, edge_type);
}
if (load_node) {
std::string node_type = param.substr(1);
return this->load_nodes(path, node_type);
}
return 0;
}
paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
std::vector<uint64_t> &node_ids, int slot_num) {
std::vector<std::vector<uint64_t>> bags(task_pool_size_);
for (int i = 0; i < task_pool_size_; i++) {
auto predsize = node_ids.size() / task_pool_size_;
bags[i].reserve(predsize * 1.2);
}
for (auto x : node_ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<uint64_t> feature_array[task_pool_size_];
std::vector<uint8_t> slot_id_array[task_pool_size_];
std::vector<uint64_t> node_id_array[task_pool_size_];
std::vector<paddle::framework::GpuPsFeaInfo>
node_fea_info_array[task_pool_size_];
for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
uint64_t node_id;
paddle::framework::GpuPsFeaInfo x;
std::vector<uint64_t> feature_ids;
for (size_t j = 0; j < bags[i].size(); j++) {
// TODO use FEATURE_TABLE instead
Node *v = find_node(1, bags[i][j]);
node_id = bags[i][j];
if (v == NULL) {
x.feature_size = 0;
x.feature_offset = 0;
node_fea_info_array[i].push_back(x);
} else {
// x <- v
x.feature_offset = feature_array[i].size();
int total_feature_size = 0;
for (int k = 0; k < slot_num; ++k) {
v->get_feature_ids(k, &feature_ids);
total_feature_size += feature_ids.size();
if (!feature_ids.empty()) {
feature_array[i].insert(feature_array[i].end(),
feature_ids.begin(),
feature_ids.end());
slot_id_array[i].insert(
slot_id_array[i].end(), feature_ids.size(), k);
}
}
x.feature_size = total_feature_size;
node_fea_info_array[i].push_back(x);
}
node_id_array[i].push_back(node_id);
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
paddle::framework::GpuPsCommGraphFea res;
uint64_t tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
tot_len += feature_array[i].size();
}
VLOG(0) << "Loaded feature table on cpu, feature_list_size[" << tot_len
<< "] node_ids_size[" << node_ids.size() << "]";
res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num);
unsigned int offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
for (int j = 0; j < (int)node_id_array[i].size(); j++) {
res.node_list[ind] = node_id_array[i][j];
res.fea_info_list[ind] = node_fea_info_array[i][j];
res.fea_info_list[ind++].feature_offset += offset;
}
for (size_t j = 0; j < feature_array[i].size(); j++) {
res.feature_list[offset + j] = feature_array[i][j];
res.slot_id_list[offset + j] = slot_id_array[i][j];
}
offset += feature_array[i].size();
}
return res;
}
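// Descriptive note (added): make_gpu_ps_graph builds a CSR-like view of one
// edge table for the given ids: res.node_list holds the node ids,
// res.node_info_list holds {neighbor_offset, neighbor_size} per node, and
// res.neighbor_list is the concatenated adjacency of all nodes, so the
// neighbors of the k-th node occupy
// neighbor_list[offset_k .. offset_k + size_k).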
paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
int idx, std::vector<uint64_t> ids) {
std::vector<std::vector<uint64_t>> bags(task_pool_size_);
for (int i = 0; i < task_pool_size_; i++) {
auto predsize = ids.size() / task_pool_size_;
bags[i].reserve(predsize * 1.2);
}
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<uint64_t> node_array[task_pool_size_]; // node id list
std::vector<paddle::framework::GpuPsNodeInfo> info_array[task_pool_size_];
std::vector<uint64_t> edge_array[task_pool_size_]; // edge id list
for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
node_array[i].resize(bags[i].size());
info_array[i].resize(bags[i].size());
edge_array[i].reserve(bags[i].size());
for (size_t j = 0; j < bags[i].size(); j++) {
auto node_id = bags[i][j];
node_array[i][j] = node_id;
Node *v = find_node(0, idx, node_id);
if (v != nullptr) {
info_array[i][j].neighbor_offset = edge_array[i].size();
info_array[i][j].neighbor_size = v->get_neighbor_size();
for (size_t k = 0; k < v->get_neighbor_size(); k++) {
edge_array[i].push_back(v->get_neighbor_id(k));
}
} else {
info_array[i][j].neighbor_offset = 0;
info_array[i][j].neighbor_size = 0;
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
int64_t tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
tot_len += edge_array[i].size();
}
paddle::framework::GpuPsCommGraph res;
res.init_on_cpu(tot_len, ids.size());
int64_t offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
for (int j = 0; j < (int)node_array[i].size(); j++) {
res.node_list[ind] = node_array[i][j];
res.node_info_list[ind] = info_array[i][j];
res.node_info_list[ind++].neighbor_offset += offset;
}
for (size_t j = 0; j < edge_array[i].size(); j++) {
res.neighbor_list[offset + j] = edge_array[i][j];
}
offset += edge_array[i].size();
}
return res;
}
int32_t GraphTable::add_node_to_ssd(
int type_id, int idx, uint64_t src_id, char *data, int len) {
if (_db != NULL) {
char ch[sizeof(int) * 2 + sizeof(uint64_t)];
memcpy(ch, &type_id, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
memcpy(ch + sizeof(int) * 2, &src_id, sizeof(uint64_t));
std::string str;
if (_db->get(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(uint64_t),
str) == 0) {
uint64_t *stored_data = ((uint64_t *)str.c_str());
int n = str.size() / sizeof(uint64_t);
char *new_data = new char[n * sizeof(uint64_t) + len];
memcpy(new_data, stored_data, n * sizeof(uint64_t));
memcpy(new_data + n * sizeof(uint64_t), data, len);
_db->put(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(uint64_t),
(char *)new_data,
n * sizeof(uint64_t) + len);
delete[] new_data;
} else {
_db->put(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(uint64_t),
(char *)data,
len);
}
}
return 0;
}
char *GraphTable::random_sample_neighbor_from_ssd(
int idx,
uint64_t id,
int sample_size,
const std::shared_ptr<std::mt19937_64> rng,
int &actual_size) {
if (_db == NULL) {
actual_size = 0;
return NULL;
}
std::string str;
VLOG(2) << "sample ssd for key " << id;
char ch[sizeof(int) * 2 + sizeof(uint64_t)];
memset(ch, 0, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
memcpy(ch + sizeof(int) * 2, &id, sizeof(uint64_t));
if (_db->get(id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(uint64_t),
str) == 0) {
uint64_t *data = ((uint64_t *)str.c_str());
int n = str.size() / sizeof(uint64_t);
std::unordered_map<int, int> m;
// std::vector<uint64_t> res;
int sm_size = std::min(n, sample_size);
actual_size = sm_size * Node::id_size;
char *buff = new char[actual_size];
for (int i = 0; i < sm_size; i++) {
std::uniform_int_distribution<int> distrib(0, n - i - 1);
int t = distrib(*rng);
// int t = rand() % (n-i);
int pos = 0;
auto iter = m.find(t);
if (iter != m.end()) {
pos = iter->second;
} else {
pos = t;
}
auto iter2 = m.find(n - i - 1);
int key2 = iter2 == m.end() ? n - i - 1 : iter2->second;
m[t] = key2;
m.erase(n - i - 1);
memcpy(buff + i * Node::id_size, &data[pos], Node::id_size);
// res.push_back(data[pos]);
}
for (int i = 0; i < actual_size; i += 8) {
VLOG(2) << "sampled an neighbor " << *(uint64_t *)&buff[i];
}
return buff;
}
actual_size = 0;
return NULL;
}
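// Reads the SSD adjacency lists of the given ids (edge type idx) back into
// the in-memory edge shards through add_comm_edge, one task-pool worker per
// shard bucket. Returns the total number of bytes fetched from RocksDB.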
int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
std::vector<uint64_t> &ids) {
std::vector<std::vector<uint64_t>> bags(task_pool_size_);
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<int64_t> count(task_pool_size_, 0);
for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, idx, this]() -> int {
char ch[sizeof(int) * 2 + sizeof(uint64_t)];
memset(ch, 0, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
for (size_t k = 0; k < bags[i].size(); k++) {
auto v = bags[i][k];
memcpy(ch + sizeof(int) * 2, &v, sizeof(uint64_t));
std::string str;
if (_db->get(i, ch, sizeof(int) * 2 + sizeof(uint64_t), str) == 0) {
count[i] += (int64_t)str.size();
for (size_t j = 0; j < str.size(); j += sizeof(uint64_t)) {
uint64_t id = *(uint64_t *)(str.c_str() + j);
add_comm_edge(idx, v, id);
}
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
int64_t tot = 0;
for (auto x : count) tot += x;
return tot;
}
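// Greedily splits the SSD-resident graph of edge type idx into partitions
// that each fit into roughly byte_size * 0.8 * device_len bytes. For every
// node each candidate partition is scored: +1 per already-assigned neighbor
// (locality), minus a size penalty a * y * base^(y-1) (a = 2.0, y = 1.25,
// base = bytes already placed there plus this value) and an optional
// node-weight penalty; the node goes to the best-scoring partition that
// still has memory left. Empty partitions are dropped afterwards and
// next_partition is reset to 0.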
void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
VLOG(2) << "start to make graph partitions , byte_size = " << byte_size
<< " total memory cost = " << total_memory_cost;
if (total_memory_cost == 0) {
VLOG(0) << "no edges are detected,make partitions exits";
return;
}
auto &weight_map = node_weight[0][idx];
const double a = 2.0, y = 1.25, weight_param = 1.0;
int64_t gb_size_by_discount = byte_size * 0.8 * device_len;
if (gb_size_by_discount <= 0) gb_size_by_discount = 1;
int part_len = total_memory_cost / gb_size_by_discount;
if (part_len == 0) part_len = 1;
VLOG(2) << "part_len = " << part_len
<< " byte size = " << gb_size_by_discount;
partitions[idx].clear();
partitions[idx].resize(part_len);
std::vector<double> weight_cost(part_len, 0);
std::vector<int64_t> memory_remaining(part_len, gb_size_by_discount);
std::vector<double> score(part_len, 0);
std::unordered_map<uint64_t, int> id_map;
std::vector<rocksdb::Iterator *> iters;
for (int i = 0; i < task_pool_size_; i++) {
iters.push_back(_db->get_iterator(i));
iters[i]->SeekToFirst();
}
int next = 0;
while (iters.size()) {
if (next >= (int)iters.size()) {
next = 0;
}
if (!iters[next]->Valid()) {
iters.erase(iters.begin() + next);
continue;
}
std::string key = iters[next]->key().ToString();
int type_idx = *(int *)key.c_str();
int temp_idx = *(int *)(key.c_str() + sizeof(int));
if (type_idx != 0 || temp_idx != idx) {
iters[next]->Next();
next++;
continue;
}
std::string value = iters[next]->value().ToString();
std::uint64_t i_key = *(uint64_t *)(key.c_str() + sizeof(int) * 2);
for (int i = 0; i < part_len; i++) {
if (memory_remaining[i] < (int64_t)value.size()) {
score[i] = -100000.0;
} else {
score[i] = 0;
}
}
for (size_t j = 0; j < value.size(); j += sizeof(uint64_t)) {
uint64_t v = *((uint64_t *)(value.c_str() + j));
int index = -1;
if (id_map.find(v) != id_map.end()) {
index = id_map[v];
score[index]++;
}
}
double base, weight_base = 0;
double w = 0;
bool has_weight = false;
if (weight_map.find(i_key) != weight_map.end()) {
w = weight_map[i_key];
has_weight = true;
}
int index = 0;
for (int i = 0; i < part_len; i++) {
base = gb_size_by_discount - memory_remaining[i] + value.size();
if (has_weight)
weight_base = weight_cost[i] + w * weight_param;
else {
weight_base = 0;
}
score[i] -= a * y * std::pow(1.0 * base, y - 1) + weight_base;
if (score[i] > score[index]) index = i;
VLOG(2) << "score" << i << " = " << score[i] << " memory left "
<< memory_remaining[i];
}
id_map[i_key] = index;
partitions[idx][index].push_back(i_key);
memory_remaining[index] -= (int64_t)value.size();
if (has_weight) weight_cost[index] += w;
iters[next]->Next();
next++;
}
for (int i = 0; i < part_len; i++) {
if (partitions[idx][i].size() == 0) {
partitions[idx].erase(partitions[idx].begin() + i);
i--;
part_len--;
continue;
}
VLOG(2) << " partition " << i << " size = " << partitions[idx][i].size();
for (auto x : partitions[idx][i]) {
VLOG(2) << "find a id " << x;
}
}
next_partition = 0;
}
void GraphTable::export_partition_files(int idx, std::string file_path) {
int part_len = partitions[idx].size();
if (part_len == 0) return;
if (file_path == "") file_path = ".";
if (file_path[(int)file_path.size() - 1] != '/') {
file_path += "/";
}
std::vector<std::future<int>> tasks;
for (int i = 0; i < part_len; i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&, i, idx, this]() -> int {
std::string output_path =
file_path + "partition_" + std::to_string(i);
std::ofstream ofs(output_path);
if (ofs.fail()) {
VLOG(0) << "creating " << output_path << " failed";
return 0;
}
for (auto x : partitions[idx][i]) {
auto str = std::to_string(x);
ofs.write(str.c_str(), str.size());
ofs.write("\n", 1);
}
ofs.close();
return 0;
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
}
void GraphTable::clear_graph(int idx) {
for (auto p : edge_shards[idx]) {
delete p;
}
edge_shards[idx].clear();
for (size_t i = 0; i < shard_num_per_server; i++) {
edge_shards[idx].push_back(new GraphShard());
}
}
int32_t GraphTable::load_next_partition(int idx) {
if (next_partition >= (int)partitions[idx].size()) {
VLOG(0) << "partition iteration is done";
return -1;
}
clear_graph(idx);
load_graph_to_memory_from_ssd(idx, partitions[idx][next_partition]);
next_partition++;
return 0;
}
int32_t GraphTable::load_edges_to_ssd(const std::string &path,
bool reverse_edge,
const std::string &edge_type) {
int idx = 0;
if (edge_type == "") {
VLOG(0) << "edge_type not specified, loading edges to " << id_to_edge[0]
<< " part";
} else {
if (edge_to_id.find(edge_type) == edge_to_id.end()) {
VLOG(0) << "edge_type " << edge_type
<< " is not defined, nothing will be loaded";
return 0;
}
idx = edge_to_id[edge_type];
}
total_memory_cost = 0;
auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0;
std::string sample_type = "random";
for (auto path : paths) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
VLOG(0) << "get a line from file " << line;
auto values = paddle::string::split_string<std::string>(line, "\t");
count++;
if (values.size() < 2) continue;
auto src_id = std::stoll(values[0]);
auto dist_ids = paddle::string::split_string<std::string>(values[1], ";");
std::vector<uint64_t> dist_data;
for (auto x : dist_ids) {
dist_data.push_back(std::stoll(x));
total_memory_cost += sizeof(uint64_t);
}
add_node_to_ssd(0,
idx,
src_id,
(char *)dist_data.data(),
(int)(dist_data.size() * sizeof(uint64_t)));
}
}
VLOG(0) << "total memory cost = " << total_memory_cost << " bytes";
return 0;
}
int32_t GraphTable::dump_edges_to_ssd(int idx) {
VLOG(2) << "calling dump edges to ssd";
std::vector<std::future<int64_t>> tasks;
auto &shards = edge_shards[idx];
for (size_t i = 0; i < shards.size(); ++i) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&, i, this]() -> int64_t {
int64_t cost = 0;
std::vector<Node *> &v = shards[i]->get_bucket();
for (size_t j = 0; j < v.size(); j++) {
std::vector<uint64_t> s;
for (size_t k = 0; k < v[j]->get_neighbor_size(); k++) {
s.push_back(v[j]->get_neighbor_id(k));
}
cost += v[j]->get_neighbor_size() * sizeof(uint64_t);
add_node_to_ssd(0,
idx,
v[j]->get_id(),
(char *)s.data(),
s.size() * sizeof(uint64_t));
}
return cost;
}));
}
for (size_t i = 0; i < tasks.size(); i++) total_memory_cost += tasks[i].get();
return 0;
}
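// Counts, over all in-memory edges of type idx, how often each node appears
// as a neighbor, clears the in-memory shards, then reloads the most
// frequently referenced nodes from SSD in batches of byte_size / 8 ids until
// the byte budget is exhausted, and finally rebuilds random samplers on the
// refreshed shards.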
int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) {
VLOG(0) << "make_complementary_graph";
const int64_t fixed_size = byte_size / 8;
// std::vector<int64_t> edge_array[task_pool_size_];
std::vector<std::unordered_map<uint64_t, int>> count(task_pool_size_);
std::vector<std::future<int>> tasks;
auto &shards = edge_shards[idx];
for (size_t i = 0; i < shards.size(); ++i) {
tasks.push_back(
_shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int {
std::vector<Node *> &v = shards[i]->get_bucket();
size_t ind = i % this->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
// size_t location = v[j]->get_id();
for (size_t k = 0; k < v[j]->get_neighbor_size(); k++) {
count[ind][v[j]->get_neighbor_id(k)]++;
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
std::unordered_map<uint64_t, int> final_count;
std::map<int, std::vector<uint64_t>> count_to_id;
std::vector<uint64_t> buffer;
clear_graph(idx);
for (int i = 0; i < task_pool_size_; i++) {
for (auto &p : count[i]) {
final_count[p.first] = final_count[p.first] + p.second;
}
count[i].clear();
}
for (auto &p : final_count) {
count_to_id[p.second].push_back(p.first);
VLOG(2) << p.first << " appear " << p.second << " times";
}
auto iter = count_to_id.rbegin();
while (iter != count_to_id.rend() && byte_size > 0) {
for (auto x : iter->second) {
buffer.push_back(x);
if (buffer.size() >= fixed_size) {
int64_t res = load_graph_to_memory_from_ssd(idx, buffer);
buffer.clear();
byte_size -= res;
}
if (byte_size <= 0) break;
}
iter++;
}
if (byte_size > 0 && buffer.size() > 0) {
int64_t res = load_graph_to_memory_from_ssd(idx, buffer);
byte_size -= res;
}
std::string sample_type = "random";
for (auto &shard : edge_shards[idx]) {
auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
return 0;
}
#endif
/*
int CompleteGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
std::cout << "in graph sampling" << std::endl;
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node;
std::vector<Node *> &v =
this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
size_t location = v[j]->get_id() % this->gpu_num;
node.node_id = v[j]->get_id();
node.neighbor_size = v[j]->get_neighbor_size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (int k = 0; k < node.neighbor_size; k++)
sample_neighbors_ex[ind][location].push_back(
v[j]->get_neighbor_id(k));
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
tasks.clear();
for (int i = 0; i < gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
int total_offset = 0;
size_t ind = i % this->graph_table->task_pool_size_;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) {
sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]);
sample_nodes[ind].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][ind].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[ind].push_back(
sample_neighbors_ex[j][ind][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
return 0;
}
void CompleteGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
}
int BasicBfsGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
while (rounds > 0 && status == GraphSamplerStatus::running) {
for (size_t i = 0; i < sample_neighbors_map.size(); i++) {
sample_neighbors_map[i].clear();
}
sample_neighbors_map.clear();
std::vector<int> nodes_left(graph_table->shards.size(),
node_num_for_each_shard);
std::promise<int> prom;
std::future<int> fut = prom.get_future();
sample_neighbors_map.resize(graph_table->task_pool_size_);
int task_size = 0;
std::vector<std::future<int>> tasks;
int init_size = 0;
//__sync_fetch_and_add
std::function<int(int, int64_t)> bfs = [&, this](int i, int id) -> int {
if (this->status == GraphSamplerStatus::terminating) {
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
}
size_t ind = i % this->graph_table->task_pool_size_;
if (nodes_left[i] > 0) {
auto iter = sample_neighbors_map[ind].find(id);
if (iter == sample_neighbors_map[ind].end()) {
Node *node = graph_table->shards[i]->find_node(id);
if (node != NULL) {
nodes_left[i]--;
sample_neighbors_map[ind][id] = std::vector<int64_t>();
iter = sample_neighbors_map[ind].find(id);
size_t edge_fetch_size =
std::min((size_t) this->edge_num_for_each_node,
node->get_neighbor_size());
for (size_t k = 0; k < edge_fetch_size; k++) {
int64_t neighbor_id = node->get_neighbor_id(k);
int node_location = neighbor_id % this->graph_table->shard_num %
this->graph_table->task_pool_size_;
__sync_add_and_fetch(&task_size, 1);
graph_table->_shards_task_pool[node_location]->enqueue(
bfs, neighbor_id % this->graph_table->shard_num, neighbor_id);
iter->second.push_back(neighbor_id);
}
}
}
}
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
};
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
std::vector<Node *> &v = graph_table->shards[i]->get_bucket();
if (v.size() > 0) {
int search_size = std::min(init_search_size, (int)v.size());
for (int k = 0; k < search_size; k++) {
init_size++;
__sync_add_and_fetch(&task_size, 1);
int64_t id = v[k]->get_id();
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue(bfs, i, id);
}
} // if
}
if (init_size == 0) {
prom.set_value(0);
}
fut.get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
VLOG(0) << "BasicBfsGraphSampler finishes the graph searching task";
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
tasks.clear();
for (size_t i = 0; i < (size_t)graph_table->task_pool_size_; ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i]->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
paddle::framework::GpuPsGraphNode node;
auto iter = sample_neighbors_map[i].begin();
size_t ind = i;
for (; iter != sample_neighbors_map[i].end(); iter++) {
size_t location = iter->first % this->gpu_num;
node.node_id = iter->first;
node.neighbor_size = iter->second.size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (auto k : iter->second)
sample_neighbors_ex[ind][location].push_back(k);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) {
tasks[i].get();
sample_neighbors_map[i].clear();
}
tasks.clear();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (size_t i = 0; i < (size_t)gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
int total_offset = 0;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][i].size(); k++) {
sample_nodes[i].push_back(sample_nodes_ex[j][i][k]);
sample_nodes[i].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][i].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[i].push_back(sample_neighbors_ex[j][i][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
rounds--;
if (rounds > 0) {
for (int i = 0;
i < interval && this->status == GraphSamplerStatus::running; i++) {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
VLOG(0)<<"bfs returning";
}
return 0;
}
void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
init_search_size = args.size() > 0 ? std::stoi(args[0]) : 10;
node_num_for_each_shard = args.size() > 1 ? std::stoi(args[1]) : 10;
edge_num_for_each_node = args.size() > 2 ? std::stoi(args[2]) : 10;
rounds = args.size() > 3 ? std::stoi(args[3]) : 1;
interval = args.size() > 4 ? std::stoi(args[4]) : 60;
}
#endif
*/
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<Node *> res;
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
return res;
}
size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id) {
size_t src_shard_id = src_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
return -1;
}
size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(false);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, 1.0);
return 0;
}
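// Batch-inserts the given node ids into the edge shards of type idx. Ids
// outside [shard_start, shard_end) are skipped; the rest are grouped by
// owning worker and inserted concurrently, with is_weight_list deciding
// whether each node builds weighted edges (defaults to false when the list
// is shorter than id_list).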
int32_t GraphTable::add_graph_node(int idx,
std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list) {
auto &shards = edge_shards[idx];
size_t node_size = id_list.size();
std::vector<std::vector<std::pair<uint64_t, bool>>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
continue;
}
batch[get_thread_pool_index(id_list[i])].push_back(
{id_list[i], i < is_weight_list.size() ? is_weight_list[i] : false});
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(
_shards_task_pool[i]->enqueue([&shards, &batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p.first % this->shard_num - this->shard_start;
shards[index]->add_graph_node(p.first)->build_edges(p.second);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
int32_t GraphTable::remove_graph_node(int idx, std::vector<uint64_t> &id_list) {
size_t node_size = id_list.size();
std::vector<std::vector<uint64_t>> batch(task_pool_size_);
for (size_t i = 0; i < node_size; i++) {
size_t shard_id = id_list[i] % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) continue;
batch[get_thread_pool_index(id_list[i])].push_back(id_list[i]);
}
auto &shards = edge_shards[idx];
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < batch.size(); ++i) {
if (!batch[i].size()) continue;
tasks.push_back(
_shards_task_pool[i]->enqueue([&shards, &batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p % this->shard_num - this->shard_start;
shards[index]->delete_node(p);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
}
void GraphShard::clear() {
for (size_t i = 0; i < bucket.size(); i++) {
delete bucket[i];
}
bucket.clear();
node_location.clear();
}
GraphShard::~GraphShard() { clear(); }
void GraphShard::delete_node(uint64_t id) {
auto iter = node_location.find(id);
if (iter == node_location.end()) return;
int pos = iter->second;
delete bucket[pos];
if (pos != (int)bucket.size() - 1) {
bucket[pos] = bucket.back();
node_location[bucket.back()->get_id()] = pos;
}
node_location.erase(id);
bucket.pop_back();
}
GraphNode *GraphShard::add_graph_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(new GraphNode(id));
}
return (GraphNode *)bucket[node_location[id]];
}
GraphNode *GraphShard::add_graph_node(Node *node) {
auto id = node->get_id();
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(node);
}
return (GraphNode *)bucket[node_location[id]];
}
FeatureNode *GraphShard::add_feature_node(uint64_t id, bool is_overlap) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(new FeatureNode(id));
return (FeatureNode *)bucket[node_location[id]];
}
if (is_overlap) {
return (FeatureNode *)bucket[node_location[id]];
}
return NULL;
}
void GraphShard::add_neighbor(uint64_t id, uint64_t dst_id, float weight) {
find_node(id)->add_edge(dst_id, weight);
}
Node *GraphShard::find_node(uint64_t id) {
auto iter = node_location.find(id);
return iter == node_location.end() ? nullptr : bucket[iter->second];
}
GraphTable::~GraphTable() {
for (int i = 0; i < (int)edge_shards.size(); i++) {
for (auto p : edge_shards[i]) {
delete p;
}
edge_shards[i].clear();
}
for (int i = 0; i < (int)feature_shards.size(); i++) {
for (auto p : feature_shards[i]) {
delete p;
}
feature_shards[i].clear();
}
}
int32_t GraphTable::Load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
bool reverse_edge = (param[1] == '<');
std::string edge_type = param.substr(2);
return this->load_edges(path, reverse_edge, edge_type);
}
if (load_node) {
std::string node_type = param.substr(1);
return this->load_nodes(path, node_type);
}
return 0;
}
std::string GraphTable::get_inverse_etype(std::string &etype) {
auto etype_split = paddle::string::split_string<std::string>(etype, "2");
std::string res;
if ((int)etype_split.size() == 3) {
res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
} else {
res = etype_split[1] + "2" + etype_split[0];
}
return res;
}
int32_t GraphTable::load_node_and_edge_file(std::string etype,
std::string ntype,
std::string epath,
std::string npath,
int part_num,
bool reverse) {
auto etypes = paddle::string::split_string<std::string>(etype, ",");
auto ntypes = paddle::string::split_string<std::string>(ntype, ",");
VLOG(0) << "etypes size: " << etypes.size();
VLOG(0) << "whether reverse: " << reverse;
std::string delim = ";";
size_t total_len = etypes.size() + 1; // 1 is for node
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < total_len; i++) {
tasks.push_back(
_shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int {
if (i < etypes.size()) {
std::string etype_path = epath + "/" + etypes[i];
auto etype_path_list = paddle::framework::localfs_list(etype_path);
std::string etype_path_str;
if (part_num > 0 && part_num < (int)etype_path_list.size()) {
std::vector<std::string> sub_etype_path_list(
etype_path_list.begin(), etype_path_list.begin() + part_num);
etype_path_str =
paddle::string::join_strings(sub_etype_path_list, delim);
} else {
etype_path_str =
paddle::string::join_strings(etype_path_list, delim);
}
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
}
} else {
auto npath_list = paddle::framework::localfs_list(npath);
std::string npath_str;
if (part_num > 0 && part_num < (int)npath_list.size()) {
std::vector<std::string> sub_npath_list(
npath_list.begin(), npath_list.begin() + part_num);
npath_str = paddle::string::join_strings(sub_npath_list, delim);
} else {
npath_str = paddle::string::join_strings(npath_list, delim);
}
if (ntypes.size() == 0) {
VLOG(0) << "node_type not specified, nothing will be loaded ";
return 0;
}
if (FLAGS_graph_load_in_parallel) {
this->load_nodes(npath_str, "");
} else {
for (size_t j = 0; j < ntypes.size(); j++) {
this->load_nodes(npath_str, ntypes[j]);
}
}
}
return 0;
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
return 0;
}
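// Collects node ids whose global positions fall inside the given
// [first, second) ranges, where positions are counted across the edge or
// feature shards of idx in order. Each overlapping shard slice is fetched on
// its own worker and appended to res under a mutex with a light random
// shuffle.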
int32_t GraphTable::get_nodes_ids_by_ranges(
int type_id,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res) {
std::mutex mutex;
int start = 0, end, index = 0, total_size = 0;
res.clear();
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
end = total_size + shards[i]->get_size();
start = total_size;
while (start < end && index < (int)ranges.size()) {
if (ranges[index].second <= start)
index++;
else if (ranges[index].first >= end) {
break;
} else {
int first = std::max(ranges[index].first, start);
int second = std::min(ranges[index].second, end);
start = second;
first -= total_size;
second -= total_size;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&shards, this, first, second, i, &res, &mutex]() -> size_t {
std::vector<uint64_t> keys;
shards[i]->get_ids_by_range(first, second, &keys);
size_t num = keys.size();
mutex.lock();
res.reserve(res.size() + num);
for (auto &id : keys) {
res.push_back(id);
std::swap(res[rand() % res.size()], res[(int)res.size() - 1]);
}
mutex.unlock();
return num;
}));
}
}
total_size += shards[i]->get_size();
}
for (size_t i = 0; i < tasks.size(); i++) {
tasks[i].get();
}
return 0;
}
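// Parses one node file in the "<node_type>\t<id>\t<feat>..." layout: lines
// whose prefix does not match node_type are skipped, ids outside
// [shard_start, shard_end) are dropped, and every remaining column is fed to
// parse_feature. Returns {matched node count, stored node count}.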
std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
const std::string &path, const std::string &node_type, int idx) {
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
int num = 0;
std::vector<paddle::string::str_ptr> vals;
size_t n = node_type.length();
while (std::getline(file, line)) {
if (strncmp(line.c_str(), node_type.c_str(), n) != 0) {
continue;
}
vals.clear();
num = paddle::string::split_string_ptr(
line.c_str() + n + 1, line.length() - n - 1, '\t', &vals);
if (num == 0) {
continue;
}
uint64_t id = std::strtoul(vals[0].ptr, NULL, 10);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}
local_count++;
size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (int i = 1; i < num; ++i) {
auto &v = vals[i];
parse_feature(idx, v.ptr, v.len, node);
}
}
local_valid_count++;
}
VLOG(2) << "node_type[" << node_type << "] loads " << local_count
<< " nodes from filepath->" << path;
return {local_count, local_valid_count};
}
std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
const std::string &path) {
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
int idx = 0;
auto path_split = paddle::string::split_string<std::string>(path, "/");
auto path_name = path_split[path_split.size() - 1];
int num = 0;
std::vector<paddle::string::str_ptr> vals;
while (std::getline(file, line)) {
vals.clear();
num = paddle::string::split_string_ptr(
line.c_str(), line.length(), '\t', &vals);
if (vals.empty()) {
continue;
}
std::string parse_node_type = vals[0].to_string();
auto it = feature_to_id.find(parse_node_type);
if (it == feature_to_id.end()) {
VLOG(0) << parse_node_type << " type error, please check";
continue;
}
idx = it->second;
uint64_t id = std::strtoul(vals[1].ptr, NULL, 10);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}
local_count++;
size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
for (int i = 2; i < num; ++i) {
auto &v = vals[i];
parse_feature(idx, v.ptr, v.len, node);
}
}
local_valid_count++;
}
VLOG(2) << local_valid_count << "/" << local_count << " nodes from filepath->"
<< path;
return {local_count, local_valid_count};
}
// TODO: optimize to load all node_types in a single pass
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0;
uint64_t valid_count = 0;
int idx = 0;
if (FLAGS_graph_load_in_parallel) {
if (node_type == "") {
VLOG(0) << "Begin GraphTable::load_nodes(), will load all node_type once";
}
std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, this]() -> std::pair<uint64_t, uint64_t> {
return parse_node_file(paths[i]);
}));
}
for (int i = 0; i < (int)tasks.size(); i++) {
auto res = tasks[i].get();
count += res.first;
valid_count += res.second;
}
} else {
VLOG(0) << "Begin GraphTable::load_nodes() node_type[" << node_type << "]";
if (node_type == "") {
VLOG(0) << "node_type not specified, loading edges to "
<< id_to_feature[0] << " part";
} else {
if (feature_to_id.find(node_type) == feature_to_id.end()) {
VLOG(0) << "node_type " << node_type
<< " is not defined, nothing will be loaded";
return 0;
}
idx = feature_to_id[node_type];
}
for (auto path : paths) {
VLOG(2) << "Begin GraphTable::load_nodes(), path[" << path << "]";
auto res = parse_node_file(path, node_type, idx);
count += res.first;
valid_count += res.second;
}
}
VLOG(0) << valid_count << "/" << count << " nodes in node_type[" << node_type
<< "] are loaded successfully!";
return 0;
}
int32_t GraphTable::build_sampler(int idx, std::string sample_type) {
for (auto &shard : edge_shards[idx]) {
auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
return 0;
}
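// Parses one edge file with lines "src\tdst[\tweight]". When a third column
// is present the edge is treated as weighted and the sampler type becomes
// "weighted". With FLAGS_graph_load_in_parallel the trailing "-<part>" of
// the file name selects which src shard this file is allowed to fill.
// Returns {lines seen, edges stored}.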
std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file(
const std::string &path, int idx, bool reverse) {
std::string sample_type = "random";
bool is_weighted = false;
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
uint64_t part_num = 0;
if (FLAGS_graph_load_in_parallel) {
auto path_split = paddle::string::split_string<std::string>(path, "/");
auto part_name_split = paddle::string::split_string<std::string>(
path_split[path_split.size() - 1], "-");
part_num = std::stoull(part_name_split[part_name_split.size() - 1]);
}
while (std::getline(file, line)) {
size_t start = line.find_first_of('\t');
if (start == std::string::npos) continue;
local_count++;
uint64_t src_id = std::stoull(&line[0]);
uint64_t dst_id = std::stoull(&line[start + 1]);
if (reverse) {
std::swap(src_id, dst_id);
}
size_t src_shard_id = src_id % shard_num;
if (FLAGS_graph_load_in_parallel) {
if (src_shard_id != (part_num % shard_num)) {
continue;
}
}
float weight = 1;
size_t last = line.find_last_of('\t');
if (start != last) {
weight = std::stof(&line[last + 1]);
sample_type = "weighted";
is_weighted = true;
}
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}
size_t index = src_shard_id - shard_start;
auto node = edge_shards[idx][index]->add_graph_node(src_id);
if (node != NULL) {
node->build_edges(is_weighted);
node->add_edge(dst_id, weight);
}
local_valid_count++;
}
VLOG(2) << local_count << " edges are loaded from filepath->" << path;
return {local_count, local_valid_count};
}
int32_t GraphTable::load_edges(const std::string &path,
bool reverse_edge,
const std::string &edge_type) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) total_memory_cost = 0;
const uint64_t fixed_load_edges = 1000000;
#endif
int idx = 0;
if (edge_type == "") {
VLOG(0) << "edge_type not specified, loading edges to " << id_to_edge[0]
<< " part";
} else {
if (edge_to_id.find(edge_type) == edge_to_id.end()) {
VLOG(0) << "edge_type " << edge_type
<< " is not defined, nothing will be loaded";
return 0;
}
idx = edge_to_id[edge_type];
}
auto paths = paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0;
uint64_t valid_count = 0;
VLOG(0) << "Begin GraphTable::load_edges() edge_type[" << edge_type << "]";
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
for (int i = 0; i < (int)paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> std::pair<uint64_t, uint64_t> {
return parse_edge_file(paths[i], idx, reverse_edge);
}));
}
for (int j = 0; j < (int)tasks.size(); j++) {
auto res = tasks[j].get();
count += res.first;
valid_count += res.second;
}
} else {
for (auto path : paths) {
auto res = parse_edge_file(path, idx, reverse_edge);
count += res.first;
valid_count += res.second;
}
}
VLOG(0) << valid_count << "/" << count << " edge_type[" << edge_type
<< "] edges are loaded successfully";
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
if (count > 0) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
return 0;
}
#endif
if (!build_sampler_on_cpu) {
// To reduce memory overhead, CPU samplers won't be created in gpugraph.
// In order not to affect the sampler function of other scenario,
// this optimization is only performed in load_edges function.
VLOG(0) << "run in gpugraph mode!";
} else {
std::string sample_type = "random";
VLOG(0) << "build sampler ... ";
for (auto &shard : edge_shards[idx]) {
auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
}
return 0;
}
Node *GraphTable::find_node(int type_id, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
Node *node = nullptr;
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
for (auto &search_shard : search_shards) {
PADDLE_ENFORCE_NOT_NULL(search_shard[index],
paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index));
node = search_shard[index]->find_node(id);
if (node != nullptr) {
break;
}
}
return node;
}
Node *GraphTable::find_node(int type_id, int idx, uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
}
size_t index = shard_id - shard_start;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
PADDLE_ENFORCE_NOT_NULL(search_shards[index],
paddle::platform::errors::InvalidArgument(
"search_shard[%d] should not be null.", index));
Node *node = search_shards[index]->find_node(id);
return node;
}
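// Maps a node id to the worker thread that owns its shard. As a hypothetical
// example (values not taken from this file): with shard_num = 1000,
// shard_num_per_server = 125 and task_pool_size_ = 24, id 2023 falls in
// shard 23, local shard 23, worker 23 % 24 = 23; all operations on that node
// are serialized on that worker.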
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_server % task_pool_size_;
}
uint32_t GraphTable::get_thread_pool_index_by_shard_index(
uint64_t shard_index) {
return shard_index % shard_num_per_server % task_pool_size_;
}
int32_t GraphTable::clear_nodes(int type_id, int idx) {
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
for (size_t i = 0; i < search_shards.size(); i++) {
search_shards[i]->clear();
}
return 0;
}
int32_t GraphTable::random_sample_nodes(int type_id,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
int total_size = 0;
auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
for (int i = 0; i < (int)shards.size(); i++) {
total_size += shards[i]->get_size();
}
if (sample_size > total_size) sample_size = total_size;
int range_num = random_sample_nodes_ranges;
if (range_num > sample_size) range_num = sample_size;
if (sample_size == 0 || range_num == 0) return 0;
std::vector<int> ranges_len, ranges_pos;
int remain = sample_size, last_pos = -1, num;
std::set<int> separator_set;
for (int i = 0; i < range_num - 1; i++) {
while (separator_set.find(num = rand() % (sample_size - 1)) !=
separator_set.end())
;
separator_set.insert(num);
}
for (auto p : separator_set) {
ranges_len.push_back(p - last_pos);
last_pos = p;
}
ranges_len.push_back(sample_size - 1 - last_pos);
remain = total_size - sample_size + range_num;
separator_set.clear();
for (int i = 0; i < range_num; i++) {
while (separator_set.find(num = rand() % remain) != separator_set.end())
;
separator_set.insert(num);
}
int used = 0, index = 0;
last_pos = -1;
for (auto p : separator_set) {
used += p - last_pos - 1;
last_pos = p;
ranges_pos.push_back(used);
used += ranges_len[index++];
}
std::vector<std::pair<int, int>> first_half, second_half;
int start_index = rand() % total_size;
for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
first_half.push_back({ranges_pos[i] + start_index,
ranges_pos[i] + ranges_len[i] + start_index});
else if (ranges_pos[i] + start_index >= total_size) {
second_half.push_back(
{ranges_pos[i] + start_index - total_size,
ranges_pos[i] + ranges_len[i] + start_index - total_size});
} else {
first_half.push_back({ranges_pos[i] + start_index, total_size});
second_half.push_back(
{0, ranges_pos[i] + ranges_len[i] + start_index - total_size});
}
}
for (auto &pair : first_half) second_half.push_back(pair);
std::vector<uint64_t> res;
get_nodes_ids_by_ranges(type_id, idx, second_half, res);
actual_size = res.size() * sizeof(uint64_t);
buffer.reset(new char[actual_size]);
char *pointer = buffer.get();
memcpy(pointer, res.data(), actual_size);
return 0;
}
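// Samples up to sample_size neighbors for every id in node_ids. Requests are
// grouped by owning worker; each worker first consults the LRU cache
// (scaled_lru) when use_cache is set, falls back to the in-memory node (or,
// under PADDLE_WITH_HETERPS with search_level == 2, to the SSD copy), packs
// neighbor ids (plus weights if need_weight) into per-request buffers, and
// inserts freshly computed results back into the cache.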
int32_t GraphTable::random_sample_neighbors(
int idx,
uint64_t *node_ids,
int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes,
bool need_weight) {
size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks;
std::vector<std::vector<uint32_t>> seq_id(task_pool_size_);
std::vector<std::vector<SampleKey>> id_list(task_pool_size_);
size_t index;
for (size_t idy = 0; idy < node_num; ++idy) {
index = get_thread_pool_index(node_ids[idy]);
seq_id[index].emplace_back(idy);
id_list[index].emplace_back(idx, node_ids[idy], sample_size, need_weight);
}
for (int i = 0; i < (int)seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
uint64_t node_id;
std::vector<std::pair<SampleKey, SampleResult>> r;
LRUResponse response = LRUResponse::blocked;
if (use_cache) {
response =
scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
}
int index = 0;
std::vector<SampleResult> sample_res;
std::vector<SampleKey> sample_keys;
auto &rng = _shards_task_rng_pool[i];
for (size_t k = 0; k < id_list[i].size(); k++) {
if (index < (int)r.size() &&
r[index].first.node_key == id_list[i][k].node_key) {
int idy = seq_id[i][k];
actual_sizes[idy] = r[index].second.actual_size;
buffers[idy] = r[index].second.buffer;
index++;
} else {
node_id = id_list[i][k].node_key;
Node *node = find_node(0, idx, node_id);
int idy = seq_id[i][k];
int &actual_size = actual_sizes[idy];
if (node == nullptr) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
VLOG(2) << "enter sample from ssd for node_id " << node_id;
char *buffer_addr = random_sample_neighbor_from_ssd(
idx, node_id, sample_size, rng, actual_size);
if (actual_size != 0) {
std::shared_ptr<char> &buffer = buffers[idy];
buffer.reset(buffer_addr, char_del);
}
VLOG(2) << "actual sampled size from ssd = " << actual_sizes[idy];
continue;
}
#endif
actual_size = 0;
continue;
}
std::shared_ptr<char> &buffer = buffers[idy];
std::vector<int> res = node->sample_k(sample_size, rng);
actual_size =
res.size() * (need_weight ? (Node::id_size + Node::weight_size)
: Node::id_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
if (response == LRUResponse::ok) {
sample_keys.emplace_back(idx, node_id, sample_size, need_weight);
sample_res.emplace_back(actual_size, buffer_addr);
buffer = sample_res.back().buffer;
} else {
buffer.reset(buffer_addr, char_del);
}
for (int &x : res) {
id = node->get_neighbor_id(x);
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
if (need_weight) {
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
}
}
}
if (sample_res.size()) {
scaled_lru->insert(
i, sample_keys.data(), sample_res.data(), sample_keys.size());
}
return 0;
}));
}
for (auto &t : tasks) {
t.get();
}
return 0;
}
int32_t GraphTable::get_node_feat(int idx,
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idy = 0; idy < node_num; ++idy) {
uint64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
Node *node = find_node(1, idx, node_id);
if (node == nullptr) {
return 0;
}
for (int feat_idx = 0; feat_idx < (int)feature_names.size();
++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
// res[feat_idx][idx] =
// node->get_feature(feat_id_map[feature_name]);
auto feat = node->get_feature(feat_id_map[idx][feature_name]);
res[feat_idx][idy] = feat;
}
}
return 0;
}));
}
for (size_t idy = 0; idy < node_num; ++idy) {
tasks[idy].get();
}
return 0;
}
int32_t GraphTable::set_node_feat(
int idx,
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idy = 0; idy < node_num; ++idy) {
uint64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = feature_shards[idx][index]->add_feature_node(node_id);
node->set_feature_size(this->feat_name[idx].size());
for (int feat_idx = 0; feat_idx < (int)feature_names.size();
++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
node->set_feature(feat_id_map[idx][feature_name],
res[feat_idx][idy]);
}
}
return 0;
}));
}
for (size_t idy = 0; idy < node_num; ++idy) {
tasks[idy].get();
}
return 0;
}
void string_vector_2_string(std::vector<std::string>::iterator strs_begin,
std::vector<std::string>::iterator strs_end,
char delim,
std::string *output) {
size_t i = 0;
for (std::vector<std::string>::iterator iter = strs_begin; iter != strs_end;
++iter) {
if (i > 0) {
*output += delim;
}
*output += *iter;
++i;
}
}
void string_vector_2_string(
std::vector<paddle::string::str_ptr>::iterator strs_begin,
std::vector<paddle::string::str_ptr>::iterator strs_end,
char delim,
std::string *output) {
size_t i = 0;
for (auto iter = strs_begin; iter != strs_end; ++iter) {
if (i > 0) {
output->append(&delim, 1);
}
output->append((*iter).ptr, (*iter).len);
++i;
}
}
int GraphTable::parse_feature(int idx,
const char *feat_str,
size_t len,
FeatureNode *node) {
// Return (feat_id, bytes) if name is in this->feat_name, else return (-1, "")
thread_local std::vector<paddle::string::str_ptr> fields;
fields.clear();
const char c = feature_separator_.at(0);
paddle::string::split_string_ptr(feat_str, len, c, &fields);
std::string name = fields[0].to_string();
auto it = feat_id_map[idx].find(name);
if (it != feat_id_map[idx].end()) {
int32_t id = it->second;
std::string *fea_ptr = node->mutable_feature(id);
std::string dtype = this->feat_dtype[idx][id];
if (dtype == "feasign") {
// string_vector_2_string(fields.begin() + 1, fields.end(), ' ',
// fea_ptr);
FeatureNode::parse_value_to_bytes<uint64_t>(
fields.begin() + 1, fields.end(), fea_ptr);
return 0;
} else if (dtype == "string") {
string_vector_2_string(fields.begin() + 1, fields.end(), ' ', fea_ptr);
return 0;
} else if (dtype == "float32") {
FeatureNode::parse_value_to_bytes<float>(
fields.begin() + 1, fields.end(), fea_ptr);
return 0;
} else if (dtype == "float64") {
FeatureNode::parse_value_to_bytes<double>(
fields.begin() + 1, fields.end(), fea_ptr);
return 0;
} else if (dtype == "int32") {
FeatureNode::parse_value_to_bytes<int32_t>(
fields.begin() + 1, fields.end(), fea_ptr);
return 0;
} else if (dtype == "int64") {
FeatureNode::parse_value_to_bytes<uint64_t>(
fields.begin() + 1, fields.end(), fea_ptr);
return 0;
}
} else {
VLOG(2) << "feature_name[" << name << "] is not in feat_id_map, ntype_id["
<< idx << "] feat_id_map_size[" << feat_id_map.size() << "]";
}
return -1;
}
// thread safe shard vector merge
class MergeShardVector {
public:
MergeShardVector(std::vector<std::vector<uint64_t>> *output, int slice_num) {
_slice_num = slice_num;
_shard_keys = output;
_shard_keys->resize(slice_num);
_mutexs = new std::mutex[slice_num];
}
~MergeShardVector() {
if (_mutexs != nullptr) {
delete[] _mutexs;
_mutexs = nullptr;
}
}
// merge shard keys
void merge(const std::vector<std::vector<uint64_t>> &shard_keys) {
// add to shard
for (int shard_id = 0; shard_id < _slice_num; ++shard_id) {
auto &dest = (*_shard_keys)[shard_id];
auto &src = shard_keys[shard_id];
_mutexs[shard_id].lock();
dest.insert(dest.end(), src.begin(), src.end());
_mutexs[shard_id].unlock();
}
}
private:
int _slice_num = 0;
std::mutex *_mutexs = nullptr;
std::vector<std::vector<uint64_t>> *_shard_keys;
};
int GraphTable::get_all_id(int type_id,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (int idx = 0; idx < (int)search_shards.size(); idx++) {
for (int j = 0; j < (int)search_shards[idx].size(); j++) {
tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
[&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
std::vector<std::vector<uint64_t>> shard_keys;
size_t num =
search_shards[idx][j]->get_all_id(&shard_keys, slice_num);
// add to shard
shard_merge.merge(shard_keys);
return num;
}));
}
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
return 0;
}
int GraphTable::get_all_neighbor_id(
int type_id, int slice_num, std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
std::vector<std::future<size_t>> tasks;
for (int idx = 0; idx < (int)search_shards.size(); idx++) {
for (int j = 0; j < (int)search_shards[idx].size(); j++) {
tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
[&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
std::vector<std::vector<uint64_t>> shard_keys;
size_t num = search_shards[idx][j]->get_all_neighbor_id(&shard_keys,
slice_num);
// add to shard
shard_merge.merge(shard_keys);
return num;
}));
}
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
return 0;
}
int GraphTable::get_all_id(int type_id,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (size_t i = 0; i < search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, i, slice_num, &shard_merge]() -> size_t {
std::vector<std::vector<uint64_t>> shard_keys;
size_t num = search_shards[i]->get_all_id(&shard_keys, slice_num);
// add to shard
shard_merge.merge(shard_keys);
return num;
}));
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
VLOG(3) << "end task, task_pool_size_[" << task_pool_size_ << "]";
return 0;
}
int GraphTable::get_all_neighbor_id(
int type_id,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
for (int i = 0; i < (int)search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, i, slice_num, &shard_merge]() -> size_t {
std::vector<std::vector<uint64_t>> shard_keys;
size_t num =
search_shards[i]->get_all_neighbor_id(&shard_keys, slice_num);
// add to shard
shard_merge.merge(shard_keys);
return num;
}));
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
VLOG(3) << "end task, task_pool_size_[" << task_pool_size_ << "]";
return 0;
}
int GraphTable::get_all_feature_ids(
int type_id,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output) {
MergeShardVector shard_merge(output, slice_num);
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<size_t>> tasks;
for (int i = 0; i < (int)search_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, i, slice_num, &shard_merge]() -> size_t {
std::vector<std::vector<uint64_t>> shard_keys;
size_t num =
search_shards[i]->get_all_feature_ids(&shard_keys, slice_num);
// add to shard
shard_merge.merge(shard_keys);
return num;
}));
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
return 0;
}
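// Returns total_size nodes starting at global position `start` with the
// given stride, serialized into one contiguous buffer. The request is split
// per shard, each slice is fetched on its own worker via get_batch, and the
// nodes are written back-to-back with to_buffer; actual_size reports the
// number of bytes produced.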
int32_t GraphTable::pull_graph_list(int type_id,
int idx,
int start,
int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size,
bool need_feature,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < search_shards.size() && total_size > 0; i++) {
cur_size = search_shards[i]->get_size();
if (size + cur_size <= start) {
size += cur_size;
continue;
}
int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&search_shards, this, i, start, end, step, size]()
-> std::vector<Node *> {
return search_shards[i]->get_batch(start - size, end - size, step);
}));
start += count * step;
total_size -= count;
size += cur_size;
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
}
size = 0;
std::vector<std::vector<Node *>> res;
for (size_t i = 0; i < tasks.size(); i++) {
res.push_back(tasks[i].get());
for (size_t j = 0; j < res.back().size(); j++) {
size += res.back()[j]->get_size(need_feature);
}
}
char *buffer_addr = new char[size];
buffer.reset(buffer_addr);
int index = 0;
for (size_t i = 0; i < res.size(); i++) {
for (size_t j = 0; j < res[i].size(); j++) {
res[i][j]->to_buffer(buffer_addr + index, need_feature);
index += res[i][j]->get_size(need_feature);
}
}
actual_size = size;
return 0;
}
void GraphTable::set_feature_separator(const std::string &ch) {
feature_separator_ = ch;
}
int32_t GraphTable::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_num_per_server;
}
int32_t GraphTable::Initialize(const TableParameter &config,
const FsClientParameter &fs_config) {
LOG(INFO) << "in graphTable initialize";
_config = config;
if (InitializeAccessor() != 0) {
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
if (_afs_client.initialize(fs_config) != 0) {
LOG(WARNING) << "Table fs_client initialize failed";
// return -1;
}
auto graph = config.graph_parameter();
shard_num = _config.shard_num();
LOG(INFO) << "in graphTable initialize over";
return Initialize(graph);
}
void GraphTable::load_node_weight(int type_id, int idx, std::string path) {
auto paths = paddle::string::split_string<std::string>(path, ";");
int64_t count = 0;
auto &weight_map = node_weight[type_id][idx];
for (auto path : paths) {
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
double weight = std::stod(values[1]);
weight_map[src_id] = weight;
}
}
}
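// Configures the table from the GraphParameter proto: task pool size and the
// RocksDB backend (when search_level >= 2 under PADDLE_WITH_HETERPS), the
// optional LRU sample cache, the edge/node type registries with their
// feature name/shape/dtype metadata, and finally the per-server shard layout
// plus empty edge and feature shards.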
int32_t GraphTable::Initialize(const GraphParameter &graph) {
task_pool_size_ = graph.task_pool_size();
build_sampler_on_cpu = graph.build_sampler_on_cpu();
#ifdef PADDLE_WITH_HETERPS
_db = NULL;
search_level = graph.search_level();
if (search_level >= 2) {
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize("./temp_gpups_db", task_pool_size_);
}
// gpups_mode = true;
// auto *sampler =
// CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
// auto slices =
// string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
// std::cout << "slices" << std::endl;
// for (auto x : slices) std::cout << x << std::endl;
// sampler->init(graph.gpu_num(), this, slices);
// graph_sampler.reset(sampler);
#endif
if (shard_num == 0) {
server_num = 1;
_shard_idx = 0;
shard_num = graph.shard_num();
}
use_cache = graph.use_cache();
if (use_cache) {
cache_size_limit = graph.cache_size_limit();
cache_ttl = graph.cache_ttl();
make_neighbor_sample_cache((size_t)cache_size_limit, (size_t)cache_ttl);
}
_shards_task_pool.resize(task_pool_size_);
for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
}
load_node_edge_task_pool.reset(new ::ThreadPool(load_thread_num));
auto graph_feature = graph.graph_feature();
auto node_types = graph.node_types();
auto edge_types = graph.edge_types();
VLOG(0) << "got " << edge_types.size() << "edge types in total";
feat_id_map.resize(node_types.size());
for (int k = 0; k < edge_types.size(); k++) {
VLOG(0) << "in initialize: get a edge_type " << edge_types[k];
edge_to_id[edge_types[k]] = k;
id_to_edge.push_back(edge_types[k]);
}
feat_name.resize(node_types.size());
feat_shape.resize(node_types.size());
feat_dtype.resize(node_types.size());
VLOG(0) << "got " << node_types.size() << "node types in total";
for (int k = 0; k < node_types.size(); k++) {
feature_to_id[node_types[k]] = k;
auto node_type = node_types[k];
auto feature = graph_feature[k];
id_to_feature.push_back(node_type);
int feat_conf_size = static_cast<int>(feature.name().size());
for (int i = 0; i < feat_conf_size; i++) {
// auto &f_name = common.attributes()[i];
// auto &f_shape = common.dims()[i];
// auto &f_dtype = common.params()[i];
auto &f_name = feature.name()[i];
auto &f_shape = feature.shape()[i];
auto &f_dtype = feature.dtype()[i];
feat_name[k].push_back(f_name);
feat_shape[k].push_back(f_shape);
feat_dtype[k].push_back(f_dtype);
feat_id_map[k][f_name] = i;
VLOG(0) << "init graph table feat conf name:" << f_name
<< " shape:" << f_shape << " dtype:" << f_dtype;
}
}
// this->table_name = common.table_name();
// this->table_type = common.name();
this->table_name = graph.table_name();
this->table_type = graph.table_type();
VLOG(0) << " init graph table type " << this->table_type << " table name "
<< this->table_name;
// int feat_conf_size = static_cast<int>(common.attributes().size());
// int feat_conf_size = static_cast<int>(graph_feature.name().size());
VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
<< _shard_idx;
shard_num_per_server = sparse_local_shard_num(shard_num, server_num);
shard_start = _shard_idx * shard_num_per_server;
shard_end = shard_start + shard_num_per_server;
VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
<< shard_start << " shard_end " << shard_end;
edge_shards.resize(id_to_edge.size());
node_weight.resize(2);
node_weight[0].resize(id_to_edge.size());
#ifdef PADDLE_WITH_HETERPS
partitions.resize(id_to_edge.size());
#endif
for (int k = 0; k < (int)edge_shards.size(); k++) {
for (size_t i = 0; i < shard_num_per_server; i++) {
edge_shards[k].push_back(new GraphShard());
}
}
node_weight[1].resize(id_to_feature.size());
feature_shards.resize(id_to_feature.size());
for (int k = 0; k < (int)feature_shards.size(); k++) {
for (size_t i = 0; i < shard_num_per_server; i++) {
feature_shards[k].push_back(new GraphShard());
}
}
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
namespace distributed {
class GraphShard {
public:
size_t get_size();
GraphShard() {}
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
void get_ids_by_range(int start, int end, std::vector<uint64_t> *res) {
res->reserve(res->size() + end - start);
for (int i = start; i < end && i < (int)bucket.size(); i++) {
res->emplace_back(bucket[i]->get_id());
}
}
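// Splits every id held by this shard into slice_num buckets by id % slice_num
// and returns the shard's bucket size.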
size_t get_all_id(std::vector<std::vector<uint64_t>> *shard_keys,
int slice_num) {
int bucket_num = bucket.size();
shard_keys->resize(slice_num);
for (int i = 0; i < slice_num; ++i) {
(*shard_keys)[i].reserve(bucket_num / slice_num);
}
for (int i = 0; i < bucket_num; i++) {
uint64_t k = bucket[i]->get_id();
(*shard_keys)[k % slice_num].emplace_back(k);
}
return bucket_num;
}
size_t get_all_neighbor_id(std::vector<std::vector<uint64_t>> *total_res,
int slice_num) {
std::vector<uint64_t> keys;
for (size_t i = 0; i < bucket.size(); i++) {
size_t neighbor_size = bucket[i]->get_neighbor_size();
size_t n = keys.size();
keys.resize(n + neighbor_size);
for (size_t j = 0; j < neighbor_size; j++) {
keys[n + j] = bucket[i]->get_neighbor_id(j);
}
}
return dedup2shard_keys(&keys, total_res, slice_num);
}
size_t get_all_feature_ids(std::vector<std::vector<uint64_t>> *total_res,
int slice_num) {
std::vector<uint64_t> keys;
for (int i = 0; i < (int)bucket.size(); i++) {
bucket[i]->get_feature_ids(&keys);
}
return dedup2shard_keys(&keys, total_res, slice_num);
}
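// Sorts *keys, drops duplicates, and distributes the unique keys into
// slice_num buckets by key % slice_num. Note that the return value is the key
// count before deduplication.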
size_t dedup2shard_keys(std::vector<uint64_t> *keys,
std::vector<std::vector<uint64_t>> *total_res,
int slice_num) {
size_t num = keys->size();
uint64_t last_key = 0;
// sort key insert to vector
std::sort(keys->begin(), keys->end());
total_res->resize(slice_num);
for (int shard_id = 0; shard_id < slice_num; ++shard_id) {
(*total_res)[shard_id].reserve(num / slice_num);
}
for (size_t i = 0; i < num; ++i) {
const uint64_t &k = (*keys)[i];
if (i > 0 && last_key == k) {
continue;
}
last_key = k;
(*total_res)[k % slice_num].push_back(k);
}
return num;
}
GraphNode *add_graph_node(uint64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id, bool is_overlap = true);
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
std::unordered_map<uint64_t, int> &get_node_location() {
return node_location;
}
private:
std::unordered_map<uint64_t, int> node_location;
std::vector<Node *> bucket;
};
enum LRUResponse { ok = 0, blocked = 1, err = 2 };
struct SampleKey {
int idx;
uint64_t node_key;
size_t sample_size;
bool is_weighted;
SampleKey(int _idx,
uint64_t _node_key,
size_t _sample_size,
bool _is_weighted) {
idx = _idx;
node_key = _node_key;
sample_size = _sample_size;
is_weighted = _is_weighted;
}
bool operator==(const SampleKey &s) const {
return idx == s.idx && node_key == s.node_key &&
sample_size == s.sample_size && is_weighted == s.is_weighted;
}
};
class SampleResult {
public:
size_t actual_size;
std::shared_ptr<char> buffer;
SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
: actual_size(_actual_size), buffer(_buffer) {}
SampleResult(size_t _actual_size, char *_buffer)
: actual_size(_actual_size),
buffer(_buffer, [](char *p) { delete[] p; }) {}
~SampleResult() {}
};
template <typename K, typename V>
class LRUNode {
public:
LRUNode(K _key, V _data, size_t _ttl) : key(_key), data(_data), ttl(_ttl) {
next = pre = NULL;
}
K key;
V data;
size_t ttl;
// time to live
LRUNode<K, V> *pre, *next;
};
template <typename K, typename V>
class ScaledLRU;
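// Single-shard LRU with per-node TTL, owned by a ScaledLRU. query()/insert()
// only try to take the parent's read lock (pthread_rwlock_tryrdlock) and
// return LRUResponse::blocked instead of waiting while ScaledLRU::Shrink
// holds the write lock. Eviction is lazy: Shrink raises remove_count and
// stale head nodes are dropped in process_redundant on later calls.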
template <typename K, typename V>
class RandomSampleLRU {
public:
RandomSampleLRU(ScaledLRU<K, V> *_father) {
father = _father;
remove_count = 0;
node_size = 0;
node_head = node_end = NULL;
global_ttl = father->ttl;
total_diff = 0;
}
~RandomSampleLRU() {
LRUNode<K, V> *p;
while (node_head != NULL) {
p = node_head->next;
delete node_head;
node_head = p;
}
}
LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
res.emplace_back(keys[i], iter->second->data);
iter->second->ttl--;
if (iter->second->ttl == 0) {
remove(iter->second);
if (remove_count != 0) remove_count--;
} else {
move_to_tail(iter->second);
}
}
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}
pthread_rwlock_unlock(&father->rwlock);
return LRUResponse::ok;
}
LRUResponse insert(K *keys, V *data, size_t length) {
if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
return LRUResponse::blocked;
// pthread_rwlock_rdlock(&father->rwlock);
int init_size = node_size - remove_count;
process_redundant(length * 3);
for (size_t i = 0; i < length; i++) {
auto iter = key_map.find(keys[i]);
if (iter != key_map.end()) {
move_to_tail(iter->second);
iter->second->ttl = global_ttl;
iter->second->data = data[i];
} else {
LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
add_new(temp);
}
}
total_diff += node_size - remove_count - init_size;
if (total_diff >= 500 || total_diff < -500) {
father->handle_size_diff(total_diff);
total_diff = 0;
}
pthread_rwlock_unlock(&father->rwlock);
return LRUResponse::ok;
}
void remove(LRUNode<K, V> *node) {
fetch(node);
node_size--;
key_map.erase(node->key);
delete node;
}
void process_redundant(int process_size) {
int length = std::min(remove_count, process_size);
while (length--) {
remove(node_head);
remove_count--;
}
// std::cerr<<"after remove_count = "<<remove_count<<std::endl;
}
void move_to_tail(LRUNode<K, V> *node) {
fetch(node);
place_at_tail(node);
}
void add_new(LRUNode<K, V> *node) {
node->ttl = global_ttl;
place_at_tail(node);
node_size++;
key_map[node->key] = node;
}
void place_at_tail(LRUNode<K, V> *node) {
if (node_end == NULL) {
node_head = node_end = node;
node->next = node->pre = NULL;
} else {
node_end->next = node;
node->pre = node_end;
node->next = NULL;
node_end = node;
}
}
void fetch(LRUNode<K, V> *node) {
if (node->pre) {
node->pre->next = node->next;
} else {
node_head = node->next;
}
if (node->next) {
node->next->pre = node->pre;
} else {
node_end = node->pre;
}
}
private:
std::unordered_map<K, LRUNode<K, V> *> key_map;
ScaledLRU<K, V> *father;
size_t global_ttl, size_limit;
int node_size, total_diff;
LRUNode<K, V> *node_head, *node_end;
friend class ScaledLRU<K, V>;
int remove_count;
};
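// Sharded LRU cache: holds shard_num independent RandomSampleLRU instances so
// callers on different task-pool threads rarely contend. A detached
// background thread wakes every 20 seconds and runs Shrink(), which, when the
// total node count exceeds roughly 1.1 * size_limit, spreads the required
// evictions across shards in proportion to their current size.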
template <typename K, typename V>
class ScaledLRU {
public:
ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
: size_limit(size_limit), ttl(_ttl) {
shard_num = _shard_num;
pthread_rwlock_init(&rwlock, NULL);
stop = false;
thread_pool.reset(new ::ThreadPool(1));
global_count = 0;
lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
RandomSampleLRU<K, V>(this));
shrink_job = std::thread([this]() -> void {
while (true) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait_for(lock, std::chrono::milliseconds(20000));
if (stop) {
return;
}
}
auto status =
thread_pool->enqueue([this]() -> int { return Shrink(); });
status.wait();
}
});
shrink_job.detach();
}
~ScaledLRU() {
std::unique_lock<std::mutex> lock(mutex_);
stop = true;
cv_.notify_one();
}
LRUResponse query(size_t index,
K *keys,
size_t length,
std::vector<std::pair<K, V>> &res) {
return lru_pool[index].query(keys, length, res);
}
LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
return lru_pool[index].insert(keys, data, length);
}
int Shrink() {
int node_size = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
}
if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
if (pthread_rwlock_wrlock(&rwlock) == 0) {
global_count = 0;
for (size_t i = 0; i < lru_pool.size(); i++) {
global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
}
if ((size_t)global_count > size_limit) {
size_t remove = global_count - size_limit;
for (size_t i = 0; i < lru_pool.size(); i++) {
lru_pool[i].total_diff = 0;
lru_pool[i].remove_count +=
1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
global_count * remove;
}
}
pthread_rwlock_unlock(&rwlock);
return 0;
}
return 0;
}
void handle_size_diff(int diff) {
if (diff != 0) {
__sync_fetch_and_add(&global_count, diff);
if (global_count > int(1.25 * size_limit)) {
thread_pool->enqueue([this]() -> int { return Shrink(); });
}
}
}
size_t get_ttl() { return ttl; }
private:
pthread_rwlock_t rwlock;
size_t shard_num;
int global_count;
size_t size_limit, total, hit;
size_t ttl;
bool stop;
std::thread shrink_job;
std::vector<RandomSampleLRU<K, V>> lru_pool;
mutable std::mutex mutex_;
std::condition_variable cv_;
std::shared_ptr<::ThreadPool> thread_pool;
friend class RandomSampleLRU<K, V>;
};
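// Minimal usage sketch (illustrative only; the key/value types and sizes are
// hypothetical; GraphTable instantiates this as
// ScaledLRU<SampleKey, SampleResult> in make_neighbor_sample_cache):
//
//   ScaledLRU<uint64_t, int> cache(/*shard_num=*/4, /*size_limit=*/1000,
//                                  /*ttl=*/5);
//   uint64_t keys[2] = {7, 9};
//   int vals[2] = {70, 90};
//   cache.insert(/*index=*/0, keys, vals, /*length=*/2);
//   std::vector<std::pair<uint64_t, int>> hits;
//   if (cache.query(/*index=*/0, keys, /*length=*/2, hits) ==
//       LRUResponse::blocked) {
//     // a Shrink() holds the write lock; retry later or skip the cache
//   }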
/*
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int loadData(const std::string &path){
return 0;
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
*/
class GraphTable : public Table {
public:
GraphTable() {
use_cache = false;
shard_num = 0;
rw_lock.reset(new pthread_rwlock_t());
#ifdef PADDLE_WITH_HETERPS
next_partition = 0;
total_memory_cost = 0;
#endif
}
virtual ~GraphTable();
virtual void *GetShard(size_t shard_idx) { return 0; }
static int32_t sparse_local_shard_num(uint32_t shard_num,
uint32_t server_num) {
if (shard_num % server_num == 0) {
return shard_num / server_num;
}
size_t local_shard_num = shard_num / server_num + 1;
return local_shard_num;
}
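// Maps a key to its owning server. Example: with shard_num = 10 and
// server_num = 3, each server owns ceil(10 / 3) = 4 shards, so a key with
// key % 10 == 7 lands on server 7 / 4 = 1.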
static size_t get_sparse_shard(uint32_t shard_num,
uint32_t server_num,
uint64_t key) {
return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
}
virtual int32_t pull_graph_list(int type_id,
int idx,
int start,
int size,
std::unique_ptr<char[]> &buffer,
int &actual_size,
bool need_feature,
int step);
virtual int32_t random_sample_neighbors(
int idx,
uint64_t *node_ids,
int sample_size,
std::vector<std::shared_ptr<char>> &buffers,
std::vector<int> &actual_sizes,
bool need_weight);
int32_t random_sample_nodes(int type_id,
int idx,
int sample_size,
std::unique_ptr<char[]> &buffers,
int &actual_sizes);
virtual int32_t get_nodes_ids_by_ranges(
int type_id,
int idx,
std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res);
virtual int32_t Initialize() { return 0; }
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t Initialize(const GraphParameter &config);
int32_t Load(const std::string &path, const std::string &param);
int32_t load_node_and_edge_file(std::string etype,
std::string ntype,
std::string epath,
std::string npath,
int part_num,
bool reverse);
std::string get_inverse_etype(std::string &etype);
int32_t load_edges(const std::string &path,
bool reverse,
const std::string &edge_type);
int get_all_id(int type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_id(int type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_neighbor_id(int type_id,
int id,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int get_all_feature_ids(int type,
int idx,
int slice_num,
std::vector<std::vector<uint64_t>> *output);
int32_t load_nodes(const std::string &path,
std::string node_type = std::string());
std::pair<uint64_t, uint64_t> parse_edge_file(const std::string &path,
int idx,
bool reverse);
std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path,
const std::string &node_type,
int idx);
std::pair<uint64_t, uint64_t> parse_node_file(const std::string &path);
int32_t add_graph_node(int idx,
std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);
int32_t remove_graph_node(int idx, std::vector<uint64_t> &id_list);
int32_t get_server_index_by_id(uint64_t id);
Node *find_node(int type_id, int idx, uint64_t id);
Node *find_node(int type_id, uint64_t id);
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t clear_nodes(int type, int idx);
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Shrink(const std::string &param) { return 0; }
// Specify the save path.
virtual int32_t Save(const std::string &path, const std::string &converter) {
return 0;
}
virtual int32_t InitializeShard() { return 0; }
virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
_shard_idx = shard_idx;
/*
_shard_num is not used in GraphTable; the following assignment only keeps
this class compatible with the base Table class.
*/
_shard_num = server_num;
this->server_num = server_num;
return 0;
}
virtual uint32_t get_thread_pool_index_by_shard_index(uint64_t shard_index);
virtual uint32_t get_thread_pool_index(uint64_t node_id);
virtual int parse_feature(int idx,
const char *feat_str,
size_t len,
FeatureNode *node);
virtual int32_t get_node_feat(int idx,
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res);
virtual int32_t set_node_feat(
int idx,
const std::vector<uint64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res);
size_t get_server_num() { return server_num; }
void clear_graph(int idx);
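// Lazily builds the shared neighbor-sampling cache. Guarded by mutex_ and
// idempotent: the ScaledLRU is only created the first time, while use_cache
// is still false.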
virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
{
std::unique_lock<std::mutex> lock(mutex_);
if (use_cache == false) {
scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
task_pool_size_, size_limit, ttl));
use_cache = true;
}
}
return 0;
}
virtual void load_node_weight(int type_id, int idx, std::string path);
#ifdef PADDLE_WITH_HETERPS
// virtual int32_t start_graph_sampling() {
// return this->graph_sampler->start_graph_sampling();
// }
// virtual int32_t end_graph_sampling() {
// return this->graph_sampler->end_graph_sampling();
// }
// virtual int32_t set_graph_sample_callback(
// std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
// callback) {
// graph_sampler->set_graph_sample_callback(callback);
// return 0;
// }
virtual void make_partitions(int idx, int64_t gb_size, int device_len);
virtual void export_partition_files(int idx, std::string file_path);
virtual char *random_sample_neighbor_from_ssd(
int idx,
uint64_t id,
int sample_size,
const std::shared_ptr<std::mt19937_64> rng,
int &actual_size);
virtual int32_t add_node_to_ssd(
int type_id, int idx, uint64_t src_id, char *data, int len);
virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
int idx, std::vector<uint64_t> ids);
virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea(
std::vector<uint64_t> &node_ids, int slot_num);
int32_t Load_to_ssd(const std::string &path, const std::string &param);
int64_t load_graph_to_memory_from_ssd(int idx, std::vector<uint64_t> &ids);
int32_t make_complementary_graph(int idx, int64_t byte_size);
int32_t dump_edges_to_ssd(int idx);
int32_t get_partition_num(int idx) { return partitions[idx].size(); }
std::vector<uint64_t> get_partition(int idx, int index) {
if (idx >= (int)partitions.size() || index >= (int)partitions[idx].size())
return std::vector<uint64_t>();
return partitions[idx][index];
}
int32_t load_edges_to_ssd(const std::string &path,
bool reverse_edge,
const std::string &edge_type);
int32_t load_next_partition(int idx);
void set_search_level(int search_level) { this->search_level = search_level; }
int search_level;
int64_t total_memory_cost;
std::vector<std::vector<std::vector<uint64_t>>> partitions;
int next_partition;
#endif
virtual int32_t add_comm_edge(int idx, uint64_t src_id, uint64_t dst_id);
virtual int32_t build_sampler(int idx, std::string sample_type = "random");
void set_feature_separator(const std::string &ch);
std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
int task_pool_size_ = 24;
int load_thread_num = 160;
const int random_sample_nodes_ranges = 3;
std::vector<std::vector<std::unordered_map<uint64_t, double>>> node_weight;
std::vector<std::vector<std::string>> feat_name;
std::vector<std::vector<std::string>> feat_dtype;
std::vector<std::vector<int32_t>> feat_shape;
std::vector<std::unordered_map<std::string, int32_t>> feat_id_map;
std::unordered_map<std::string, int> feature_to_id, edge_to_id;
std::vector<std::string> id_to_feature, id_to_edge;
std::string table_name;
std::string table_type;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes;
int cache_size_limit;
int cache_ttl;
mutable std::mutex mutex_;
bool build_sampler_on_cpu;
std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
// paddle::framework::GpuPsGraphTable gpu_graph_table;
paddle::distributed::RocksDBHandler *_db;
// std::shared_ptr<::ThreadPool> graph_sample_pool;
// std::shared_ptr<GraphSampler> graph_sampler;
// REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
std::string feature_separator_ = std::string(" ");
};
/*
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<uint64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<uint64_t>> sample_neighbors;
size_t gpu_num;
int init_search_size, node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<uint64_t, std::vector<uint64_t>>>
sample_neighbors_map;
};
#endif
*/
} // namespace distributed
} // namespace paddle
namespace std {
template <>
struct hash<paddle::distributed::SampleKey> {
size_t operator()(const paddle::distributed::SampleKey &s) const {
return s.idx ^ s.node_key ^ s.sample_size;
}
};
} // namespace std
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <set>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
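// Accumulator for dense values: add() does an element-wise VADD into values,
// counter tracks how many updates were merged, and avg() rescales by
// 1 / counter. A minimal sketch (the dimension below is hypothetical):
//   ReservoirValue<float> acc(/*dim=*/8);
//   float grad[8] = {0};
//   acc.add(grad, 8);   // merge one update
//   acc.avg();          // average over merged updates
//   acc.reset();        // clear for the next round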
template <typename T>
struct ReservoirValue {
std::vector<T> values;
uint32_t counter;
uint32_t dim;
ReservoirValue() {
dim = 0;
values.resize(dim);
counter = 0;
}
ReservoirValue(uint32_t dim) {
this->dim = dim;
values.resize(dim);
counter = 0;
}
void add(const T *value, int numel) {
GetBlas<T>().VADD(numel, values.data(), value, values.data());
counter++;
}
void add(T *value, int numel) {
GetBlas<T>().VADD(numel, values.data(), value, values.data());
counter++;
}
void avg() {
if (counter == 0) return;
auto scale = 1 / static_cast<T>(counter);
GetBlas<T>().SCAL(values.size(), scale, values.data());
}
void reset() {
std::fill(values.begin(), values.end(), 0);
counter = 0;
}
};
class BarrierTable : public Table {
public:
BarrierTable() {}
virtual ~BarrierTable() {}
virtual void *GetShard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t Shrink(const std::string &param) override { return 0; }
virtual void Clear() {}
virtual int32_t Flush() { return 0; }
virtual int32_t Load(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t Save(const std::string &path, const std::string &param) {
return 0;
}
virtual int32_t InitializeShard() { return 0; }
virtual int32_t Initialize() override;
// only for barrier
// 0: send_barrier 1: recv_barrier 2: complete
virtual int32_t Barrier(const uint32_t trainer_id,
const std::string barrier_type) override;
virtual int32_t SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
private:
std::mutex mutex_;
std::condition_variable trainer_wait_;
std::set<uint64_t> trainer_ids_;
std::set<uint64_t> trainer_all_;
std::atomic<int> trigger_;
std::atomic<bool> exit_;
std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
int CtrCommonAccessor::Initialize() {
auto name = _config.embed_sgd_param().name();
_embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
name = _config.embedx_sgd_param().name();
_embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
_embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
_config.embedx_dim());
common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
common_feature_value.embedx_dim = _config.embedx_dim();
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
}
InitAccessorInfo();
return 0;
}
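// Derives the accessor layout from the configured dims. select_dim covers the
// pull value (show, click, embed_w plus embedx_dim embedding floats) and
// update_dim covers the push value (slot, show, click, embed_g plus
// embedx_dim gradient floats), as laid out by Select() and Update() below;
// hence 3 + embedx_dim and 4 + embedx_dim.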
void CtrCommonAccessor::InitAccessorInfo() {
_accessor_info.dim = common_feature_value.Dim();
_accessor_info.size = common_feature_value.Size();
auto embedx_dim = _config.embedx_dim();
_accessor_info.select_dim = 3 + embedx_dim;
_accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
_accessor_info.update_dim = 4 + embedx_dim;
_accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
_accessor_info.mf_size =
(embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
bool CtrCommonAccessor::Shrink(float* value) {
auto delete_after_unseen_days =
_config.ctr_accessor_param().delete_after_unseen_days();
auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
// time_decay first
common_feature_value.Show(value) *= _show_click_decay_rate;
common_feature_value.Click(value) *= _show_click_decay_rate;
// shrink after
auto score = ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value));
auto unseen_days = common_feature_value.UnseenDays(value);
if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
return true;
}
return false;
}
bool CtrCommonAccessor::SaveCache(float* value,
int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
return common_feature_value.Show(value) > global_cache_threshold;
}
return false;
}
bool CtrCommonAccessor::SaveSSD(float* value) {
if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
bool CtrCommonAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (param == 2) {
delta_threshold = 0;
}
switch (param) {
// save all
case 0: {
return true;
}
// save xbox delta
case 1:
// save xbox base
case 2: {
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
// do this after save, because it must not be modified when retry
if (param == 2) {
common_feature_value.DeltaScore(value) = 0;
}
return true;
} else {
return false;
}
}
// already decayed in shrink
case 3: {
// do this after save, because it must not be modified when retry
// common_feature_value.UnseenDays(value)++;
return true;
}
// save revert batch_model
case 5: {
return true;
}
default:
return true;
}
}
void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (param == 2) {
delta_threshold = 0;
}
switch (param) {
case 1: {
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.DeltaScore(value) >= delta_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
common_feature_value.DeltaScore(value) = 0;
}
}
return;
case 3: {
common_feature_value.UnseenDays(value)++;
}
return;
default:
return;
}
}
int32_t CtrCommonAccessor::Create(float** values, size_t num) {
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
value[common_feature_value.UnseenDaysIndex()] = 0;
value[common_feature_value.DeltaScoreIndex()] = 0;
value[common_feature_value.ShowIndex()] = 0;
value[common_feature_value.ClickIndex()] = 0;
value[common_feature_value.SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
value + common_feature_value.EmbedG2SumIndex(),
zero_init);
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex(),
false);
}
return 0;
}
bool CtrCommonAccessor::NeedExtendMF(float* value) {
float show = value[common_feature_value.ShowIndex()];
float click = value[common_feature_value.ClickIndex()];
float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff();
return score >= _config.embedx_threshold();
}
bool CtrCommonAccessor::HasMF(int size) {
return size > common_feature_value.EmbedxG2SumIndex();
}
// from CommonFeatureValue to CtrCommonPullValue
int32_t CtrCommonAccessor::Select(float** select_values,
const float** values,
size_t num) {
auto embedx_dim = _config.embedx_dim();
for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item];
const float* value = values[value_item];
select_value[CtrCommonPullValue::ShowIndex()] =
value[common_feature_value.ShowIndex()];
select_value[CtrCommonPullValue::ClickIndex()] =
value[common_feature_value.ClickIndex()];
select_value[CtrCommonPullValue::EmbedWIndex()] =
value[common_feature_value.EmbedWIndex()];
memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(),
value + common_feature_value.EmbedxWIndex(),
embedx_dim * sizeof(float));
}
return 0;
}
// from CtrCommonPushValue to CtrCommonPushValue
// first dim: item
// second dim: field num
int32_t CtrCommonAccessor::Merge(float** update_values,
const float** other_update_values,
size_t num) {
auto embedx_dim = _config.embedx_dim();
int total_dim = CtrCommonPushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item];
for (int i = 0; i < total_dim; ++i) {
if (i != CtrCommonPushValue::SlotIndex()) {
update_value[i] += other_update_value[i];
}
}
}
return 0;
}
// from CtrCommonPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
int32_t CtrCommonAccessor::Update(float** update_values,
const float** push_values,
size_t num) {
for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item];
const float* push_value = push_values[value_item];
float push_show = push_value[CtrCommonPushValue::ShowIndex()];
float push_click = push_value[CtrCommonPushValue::ClickIndex()];
float slot = push_value[CtrCommonPushValue::SlotIndex()];
update_value[common_feature_value.ShowIndex()] += push_show;
update_value[common_feature_value.ClickIndex()] += push_click;
update_value[common_feature_value.SlotIndex()] = slot;
update_value[common_feature_value.DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
update_value[common_feature_value.UnseenDaysIndex()] = 0;
// TODO(zhaocaibei123): add configure show_scale
if (!_show_scale) {
push_show = 1;
}
VLOG(3) << "accessor show scale:" << _show_scale
<< ", push_show:" << push_show;
_embed_sgd_rule->UpdateValue(
update_value + common_feature_value.EmbedWIndex(),
update_value + common_feature_value.EmbedG2SumIndex(),
push_value + CtrCommonPushValue::EmbedGIndex(),
push_show);
_embedx_sgd_rule->UpdateValue(
update_value + common_feature_value.EmbedxWIndex(),
update_value + common_feature_value.EmbedxG2SumIndex(),
push_value + CtrCommonPushValue::EmbedxGIndex(),
push_show);
}
return 0;
}
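// Decides whether a key seen only at push time should create a new feature
// entry: pulls (stage 0) always create; pushes (stage 1) create with
// probability equal to the show/click score when it falls in (0, 1).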
bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
// stage == 0, pull
// stage == 1, push
if (stage == 0) {
return true;
} else if (stage == 1) {
// operation
auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
auto score = ShowClickScore(show, click);
if (score <= 0) {
return false;
}
if (score >= 1) {
return true;
}
return local_uniform_real_distribution<float>()(local_random_engine()) <
score;
} else {
return true;
}
}
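// Weighted relevance score used by Shrink/Save/CreateValue. Worked example
// with hypothetical coefficients nonclk_coeff = 0.1 and click_coeff = 1.0:
// show = 10, click = 2 gives (10 - 2) * 0.1 + 2 * 1.0 = 2.8.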
float CtrCommonAccessor::ShowClickScore(float show, float click) {
auto nonclk_coeff = _config.ctr_accessor_param().nonclk_coeff();
auto click_coeff = _config.ctr_accessor_param().click_coeff();
return (show - click) * nonclk_coeff + click * click_coeff;
}
std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
thread_local std::ostringstream os;
os.clear();
os.str("");
os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
<< v[5];
for (int i = common_feature_value.EmbedG2SumIndex();
i < common_feature_value.EmbedxWIndex();
i++) {
os << " " << v[i];
}
auto show = common_feature_value.Show(const_cast<float*>(v));
auto click = common_feature_value.Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() &&
param > common_feature_value.EmbedxWIndex()) {
for (auto i = common_feature_value.EmbedxWIndex();
i < common_feature_value.Dim();
++i) {
os << " " << v[i];
}
}
return os.str();
}
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
_embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
value + common_feature_value.EmbedxG2SumIndex());
auto ret = paddle::string::str_to_float(str.data(), value);
CHECK(ret >= 6) << "expect at least 6 floats, got: " << ret;
return ret;
}
} // namespace distributed
} // namespace paddle