Commit d2d32668 authored by yuguo960516yuguo
Browse files

2.3.0-dtk-22.04.2

parent ad08b8ce
Pipeline #226 failed with stages
in 0 seconds
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22;
PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23;
PS_PULL_GEO_PARAM = 24;
PS_BARRIER = 25;
PS_PUSH_SPARSE_PARAM = 26;
PS_START_PROFILER = 27;
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE_NEIGHBORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
PS_GRAPH_GET_NODE_FEAT = 33;
PS_GRAPH_CLEAR = 34;
PS_GRAPH_ADD_GRAPH_NODE = 35;
PS_GRAPH_REMOVE_GRAPH_NODE = 36;
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
PEER_ROLE_IS_WORKER = 41;
PEER_ROLE_IS_SWITCH = 42;
PS_SAVE_WITH_SCOPE = 43;
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message PsResponseMessage {
required int32 err_code = 1 [ default = 0 ];
required string err_msg = 2 [ default = "" ];
optional bytes data = 3;
};
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
}
message VariableMessage {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
}
message LodData { repeated int64 lod_data = 1; }
optional string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
optional VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
optional Type data_type = 3;
repeated int64 dims = 4;
// lod details:
optional int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
optional int64 slr_height = 7;
// tensor data
optional bytes data = 8;
}
// for SendAndRecv RPC method
message MultiVariableMessage {
// message flags
required string message_name = 1;
repeated string send_var_names = 2;
repeated string recv_var_names = 3;
repeated VariableMessage var_messages = 4;
optional bytes data = 5;
repeated int64 vars_len = 6;
optional int32 group_id = 7;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage);
rpc SendToWorker(MultiVariableMessage) returns (PsResponseMessage);
rpc SendToSwitch(MultiVariableMessage) returns (PsResponseMessage);
rpc SendS2S(MultiVariableMessage) returns (PsResponseMessage);
rpc RecvFromSwitch(MultiVariableMessage) returns (MultiVariableMessage);
};
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/ps/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/ps/service/ps_local_server.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PSServer, PsLocalServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
REGISTER_PSCORE_CLASS(PSServer, GraphBrpcServer);
REGISTER_PSCORE_CLASS(PsBaseService, GraphBrpcService);
// Factory entry point: instantiates the PSServer subclass named in
// ServerParameter.downpour_server_param.service_param.server_class.
// Returns nullptr (after logging an ERROR) when any required config
// section is missing or the class name was never registered. On success
// the global TableManager is initialized before the server is returned.
PSServer *PSServerFactory::Create(const PSParameter &ps_config) {
  const auto &config = ps_config.server_param();
  if (!config.has_downpour_server_param()) {
    LOG(ERROR) << "miss downpour_server_param in ServerParameter";
    return nullptr;
  }
  if (!config.downpour_server_param().has_service_param()) {
    LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param";
    return nullptr;
  }
  if (!config.downpour_server_param().service_param().has_server_class()) {
    LOG(ERROR) << "miss server_class in "
                  "ServerParameter.downpour_server_param.service_param";
    return nullptr;
  }
  const auto &service_param = config.downpour_server_param().service_param();
  PSServer *server =
      CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
  if (server == nullptr) {
    LOG(ERROR) << "server is not registered, server_name:"
               << service_param.server_class();
    return nullptr;
  }
  // Table singleton must be ready before the server configures its tables.
  TableManager::Instance().Initialize();
  return server;
}
int32_t PSServer::Configure(
const PSParameter &config,
PSEnvironment &env,
size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program) {
scope_.reset(new framework::Scope());
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
uint32_t barrier_table = UINT32_MAX;
uint32_t global_step_table = UINT32_MAX;
for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
"BarrierTable") {
barrier_table = downpour_param.downpour_table_param(i).table_id();
}
if (downpour_param.downpour_table_param(i).table_class() ==
"GlobalStepTable") {
global_step_table = downpour_param.downpour_table_param(i).table_id();
}
table->SetProgramEnv(scope_.get(), place_, &server_sub_program);
table->SetShard(_rank, shard_num);
table->Initialize(downpour_param.downpour_table_param(i),
config.fs_client_param());
_table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
}
if (barrier_table != UINT32_MAX) {
_table_map[barrier_table]->SetTableMap(&_table_map);
}
if (global_step_table != UINT32_MAX) {
_table_map[global_step_table]->SetTableMap(&_table_map);
}
return Initialize();
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "butil/endpoint.h"
#include "google/protobuf/service.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/service/env.h"
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace google {
namespace protobuf {
class RpcController;
} // namespace protobuf
} // namespace google
namespace paddle {
namespace distributed {
class PSEnvironment;
} // namespace distributed
} // namespace paddle
namespace paddle {
namespace framework {
class Executor;
class ProgramDesc;
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
// Base class for parameter-server implementations (brpc, local, graph).
// Owns the table map keyed by table_id and routes pserver-to-pserver
// messages; concrete subclasses supply transport (Start/Stop) and
// Initialize(). Non-copyable and non-movable.
class PSServer {
 public:
  PSServer() {}
  virtual ~PSServer() {}
  PSServer(PSServer &&) = delete;
  PSServer(const PSServer &) = delete;
  // Builds all tables from config; defined in server.cc.
  virtual int32_t Configure(
      const PSParameter &config,
      PSEnvironment &env,
      size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {});
  // Starts serving on ip:port; return value is implementation-defined.
  virtual uint64_t Start(const std::string &ip, uint32_t port) = 0;
  virtual int32_t Stop() = 0;
  inline size_t Rank() const { return _rank; }
  inline PSEnvironment *Environment() { return _environment; }
  inline const ServerParameter *Config() const { return &_config; }
  // Looks up a table by id; nullptr when the id is unknown.
  inline Table *GetTable(size_t table_id) {
    auto itr = _table_map.find(table_id);
    if (itr != _table_map.end()) {
      return itr->second.get();
    }
    return nullptr;
  }
  inline std::unordered_map<uint32_t, std::shared_ptr<Table>> *GetTable() {
    return &_table_map;
  }
  // for cache
  virtual int32_t StartS2S() { return 0; }
  // Default stub: logs fatally and returns a future already holding -1.
  virtual ::std::future<int32_t> SendPServer2PServerMsg(
      int msg_type, int to_pserver_id, const std::string &msg) {
    LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
    std::promise<int32_t> promise;
    std::future<int32_t> fut = promise.get_future();
    promise.set_value(-1);
    return fut;
  }
  // Handler signature: (msg_type, from_pserver_id, msg) -> status.
  using MsgHandlerFunc = std::function<int32_t(int, int, const std::string &)>;
  virtual int RegistePServer2PServerMsgHandler(int msg_type,
                                               MsgHandlerFunc handler) {
    _msg_handler_map[msg_type] = handler;
    return 0;
  }
  // Dispatches an incoming pserver2pserver message to its registered
  // handler; unregistered type 101 (PS_S2S_MSG) falls back to the
  // ReceiveFromPServer virtual hook.
  virtual int HandlePServer2PServerMsg(int msg_type,
                                       int from_pserver_id,
                                       const std::string &msg) {
    auto itr = _msg_handler_map.find(msg_type);
    if (itr == _msg_handler_map.end()) {
      if (msg_type == 101) {
        return ReceiveFromPServer(msg_type, from_pserver_id, msg);
      } else {
        LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
        return -1;
      }
    }
    return itr->second(msg_type, from_pserver_id, msg);
  }
  virtual int32_t ReceiveFromPServer(int msg_type,
                                     int pserver_id,
                                     const std::string &msg) {
    LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
    return -1;
  }
  // Channel used during cache shuffle; populated by subclasses.
  paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;

 protected:
  virtual int32_t Initialize() = 0;

 protected:
  size_t _rank;
  ServerParameter _config;
  PSEnvironment *_environment;
  std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
  std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;

 protected:
  std::shared_ptr<framework::Scope> scope_;
  platform::Place place_ = platform::CPUPlace();
};
REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
// protobuf Closure that fans a completion status out to any number of
// attached promises, then invokes the stored callback.
class PServerClosure : public google::protobuf::Closure {
 public:
  // Takes the callback by value and moves it into place, avoiding an
  // extra std::function copy.
  PServerClosure(PServerCallBack callback) : _callback(std::move(callback)) {}
  virtual ~PServerClosure() {}
  // Publishes `value` to every promise added via add_promise().
  virtual void set_promise_value(int value) {
    for (auto &promise : _promises) {
      promise->set_value(value);
    }
  }
  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
    _promises.push_back(promise);
  }

 protected:
  PServerCallBack _callback;
  std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
// Base RPC service bound to a PSServer. Concrete services implement the
// generated `service` dispatcher and Initialize().
class PsBaseService : public PsService {
 public:
  PsBaseService() : _rank(0), _server(nullptr), _config(nullptr) {}
  virtual ~PsBaseService() {}
  virtual size_t GetRank() { return _rank; }
  // Binds this service to its owning server and caches rank/config.
  virtual int32_t Configure(PSServer *server) {
    _server = server;
    _rank = _server->Rank();
    _config = _server->Config();
    return 0;
  }
  virtual void service(::google::protobuf::RpcController *controller,
                       const PsRequestMessage *request,
                       PsResponseMessage *response,
                       ::google::protobuf::Closure *done) override = 0;
  // Fills err_code/err_msg on the response and logs the failure.
  // (Log text fixed from the original typo "Resonse".)
  virtual void set_response_code(PsResponseMessage &response,
                                 int err_code,
                                 const char *err_msg) {
    response.set_err_msg(err_msg);
    response.set_err_code(err_code);
    LOG(WARNING) << "Response err_code:" << err_code << " msg:" << err_msg;
  }
  virtual int32_t Initialize() = 0;
  PSServer *GetServer() { return _server; }

 protected:
  size_t _rank;
  PSServer *_server;
  const ServerParameter *_config;
};
REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:
static PSServer *Create(const PSParameter &config);
};
} // namespace distributed
} // namespace paddle
set_property(GLOBAL PROPERTY TABLE_DEPS string_helper)
set(graphDir graph)
get_property(TABLE_DEPS GLOBAL PROPERTY TABLE_DEPS)
set_source_files_properties(
${graphDir}/graph_edge.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(graph_edge SRCS ${graphDir}/graph_edge.cc)
set_source_files_properties(
${graphDir}/graph_weighted_sampler.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
WeightedSampler
SRCS ${graphDir}/graph_weighted_sampler.cc
DEPS graph_edge)
set_source_files_properties(
${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
graph_node
SRCS ${graphDir}/graph_node.cc
DEPS WeightedSampler)
set_source_files_properties(
memory_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/")
include_directories(
${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmct/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
#set(EXTERN_DEP rocksdb)
cc_library(
common_table
SRCS ${TABLE_SRC}
DEPS ${TABLE_DEPS}
${RPC_DEPS}
graph_edge
graph_node
device_context
string_helper
simple_threadpool
xxhash
generator)
set_source_files_properties(
tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
tensor_accessor
SRCS tensor_accessor.cc
DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context)
cc_library(
tensor_table
SRCS
DEPS eigen3
ps_framework_proto
executor
scope
device_context
tensor
${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ctr_dymf_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
cc_library(
sparse_sgd_rule
SRCS sparse_sgd_rule.cc
DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(
ctr_accessor
SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc
ctr_dymf_accessor.cc
DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(
sparse_table
SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc
DEPS ps_framework_proto
${TABLE_DEPS}
fs
afs_wrapper
ctr_accessor
common_table
rocksdb)
cc_library(
table
SRCS table.cc
DEPS sparse_table
common_table
tensor_accessor
tensor_table
ps_framework_proto
string_helper
device_context
gflags
glog
boost)
target_link_libraries(table -fopenmp)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
namespace paddle {
namespace distributed {
// Non-owning view over a raw buffer of parameter data. The typed
// constructors record the region length in BYTES (element count times
// sizeof(element)); sizeof-based arithmetic replaces the original magic
// bit shifts (data_num << 2 etc.) — the resulting values are identical.
struct Region {
  Region() : data(nullptr), size(0) {}
  Region(char* data, size_t data_num) : data(data), size(data_num) {}
  Region(float* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(float)) {}
  Region(int16_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int16_t)) {}
  Region(int32_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int32_t)) {}
  Region(int64_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int64_t)) {}
  char* data;   // borrowed pointer; Region never frees it
  size_t size;  // length in bytes
};
struct DataConverter {
int param;
std::string converter;
std::string deconverter;
};
struct AccessorInfo {
// value维度
size_t dim;
// value各个维度的size
size_t size;
// pull value维度
size_t select_dim;
// pull value各维度相加总size
size_t select_size;
// push value维度
size_t update_dim;
// push value各个维度的size
size_t update_size;
// value中mf动态长度部分总size大小, sparse下生效
size_t mf_size;
// value总维度,dense下生效
size_t fea_dim;
};
class ValueAccessor {
public:
ValueAccessor() {}
virtual ~ValueAccessor() {}
virtual int Configure(const TableAccessorParameter& parameter) {
_config = parameter;
// data_convert结构体初始化
if (_config.table_accessor_save_param_size() != 0) {
for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
int param = _config.table_accessor_save_param(i).param();
std::string converter =
_config.table_accessor_save_param(i).converter();
std::string deconverter =
_config.table_accessor_save_param(i).deconverter();
_data_coverter_map[param] = std::make_shared<DataConverter>();
*(_data_coverter_map[param]) = {param, converter, deconverter};
}
}
return 0;
}
virtual int Initialize() = 0;
virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
virtual bool NeedExtendMF(float* value) { return false; }
virtual bool HasMF(size_t size) { return false; }
// converter for save
virtual std::string GetConverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->converter;
}
}
// deconverter for load
virtual std::string GetDeconverter(int param) {
auto itr = _data_coverter_map.find(param);
if (itr == _data_coverter_map.end()) {
return "";
} else {
return (*itr).second->deconverter;
}
}
// 判断该value是否进行shrink
virtual bool Shrink(float* value) = 0;
// 判断该value是否在save阶段dump,
// param作为参数用于标识save阶段,如downpour的xbox与batch_model
virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) {}
// 判断该value是否保存到ssd
virtual bool SaveSSD(float* value) = 0;
//
virtual bool SaveCache(float* value,
int param,
double global_cache_threshold) = 0;
// keys不存在时,为values生成随机值
virtual int32_t Create(float** value, size_t num) = 0;
virtual bool CreateValue(int type, const float* value) { return true; }
// 从values中选取到select_values中
virtual int32_t Select(float** select_values,
const float** values,
size_t num) = 0;
// 将update_values聚合到一起
virtual int32_t Merge(float** update_values,
const float** other_update_values,
size_t num) = 0;
// 将update_values聚合到一起,通过it.next判定是否进入下一个key
// virtual int32_t Merge(float** update_values, iterator it);
// 将update_values更新应用到values中
virtual int32_t Update(float** values,
const float** update_values,
size_t num) = 0;
// used to save model, will filter feature
virtual std::string ParseToString(const float* value, int param) = 0;
// parse value from string, used to load model
virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
virtual FsDataConverter Converter(int param) {
FsDataConverter data_convert;
data_convert.converter = this->GetConverter(param);
data_convert.deconverter = this->GetDeconverter(param);
return data_convert;
}
virtual int SetWeight(float** values,
const float** update_values,
size_t num) {
return 0;
}
virtual float GetField(float* value, const std::string& name) { return 0.0; }
#define DEFINE_GET_INDEX(class, field) \
virtual int get_##field##_index() override { return class ::field##_index(); }
protected:
size_t _value_size;
size_t _select_value_size;
size_t _update_value_size;
TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
AccessorInfo _accessor_info;
};
REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_table.h"
namespace paddle {
namespace distributed {
// Reads trainer_num from the table config, arms the trigger counter with
// it, and records the full trainer id set [0, trainer_num).
int32_t BarrierTable::Initialize() {
  const auto trainer_count = _config.common().trainer_num();
  trigger_.store(trainer_count);
  for (int id = 0; id < trainer_count; ++id) {
    trainer_all_.insert(id);
  }
  VLOG(1) << "BarrierTable init trigger: " << trigger_.load();
  return 0;
}
// 0: send_barrier 1: recv_barrier 2: complete
int32_t BarrierTable::Barrier(const uint32_t trainer_id,
const std::string barrier_type) {
std::unique_lock<std::mutex> lock(mutex_);
if (barrier_type == "2") {
trigger_.fetch_sub(1, std::memory_order::memory_order_relaxed);
VLOG(1) << "trigger sub to : " << trigger_.load();
} else {
trainer_ids_.insert(trainer_id);
VLOG(1) << "barrier type: " << barrier_type
<< " add trainer id: " << trainer_id;
}
if (static_cast<int>(trainer_ids_.size()) < trigger_.load()) {
std::vector<uint32_t> diffs(trainer_all_.size());
auto iter = std::set_difference(trainer_all_.begin(),
trainer_all_.end(),
trainer_ids_.begin(),
trainer_ids_.end(),
diffs.begin());
diffs.resize(iter - diffs.begin());
auto diff = to_string<uint32_t>(diffs);
VLOG(1) << "still need trainers: " << diff;
trainer_wait_.wait(lock, [&] { return trainer_ids_.size() == 0; });
} else {
VLOG(1) << "barrier table optimize begin";
for (auto& x : *table_map_) {
auto table = x.second;
table->Pour();
}
VLOG(1) << "barrier table optimize done";
trainer_ids_.clear();
trainer_wait_.notify_all();
}
return 0;
}
int32_t BarrierTable::SetTableMap(
std::unordered_map<uint32_t, std::shared_ptr<Table>>* table_map) {
table_map_ = table_map;
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/common_graph_table.h"
#include <time.h>
#include <algorithm>
#include <chrono>
#include <set>
#include <sstream>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
#ifdef PADDLE_WITH_HETERPS
// Loads graph data from `path` into the SSD-backed store. The first
// character of `param` selects the payload: 'e' = edges (with optional
// '<' in the second position meaning reversed edges, rest is the edge
// type), 'n' = nodes (rest is the node type). Anything else is a no-op.
int32_t GraphTable::Load_to_ssd(const std::string &path,
                                const std::string &param) {
  switch (param[0]) {
    case 'e': {
      const bool reverse_edge = (param[1] == '<');
      return this->load_edges_to_ssd(path, reverse_edge, param.substr(2));
    }
    case 'n':
      return this->load_nodes(path, param.substr(1));
    default:
      return 0;
  }
}
paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
int idx, std::vector<int64_t> ids) {
std::vector<std::vector<int64_t>> bags(task_pool_size_);
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<int64_t> edge_array[task_pool_size_];
std::vector<paddle::framework::GpuPsGraphNode> node_array[task_pool_size_];
for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
paddle::framework::GpuPsGraphNode x;
for (size_t j = 0; j < bags[i].size(); j++) {
Node *v = find_node(0, idx, bags[i][j]);
x.node_id = bags[i][j];
if (v == NULL) {
x.neighbor_size = 0;
x.neighbor_offset = 0;
node_array[i].push_back(x);
} else {
x.neighbor_size = v->get_neighbor_size();
x.neighbor_offset = edge_array[i].size();
node_array[i].push_back(x);
for (size_t k = 0; k < x.neighbor_size; k++) {
edge_array[i].push_back(v->get_neighbor_id(k));
}
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
paddle::framework::GpuPsCommGraph res;
int64_t tot_len = 0;
for (int i = 0; i < task_pool_size_; i++) {
tot_len += edge_array[i].size();
}
// res.neighbor_size = tot_len;
// res.node_size = ids.size();
// res.neighbor_list = new int64_t[tot_len];
// res.node_list = new paddle::framework::GpuPsGraphNode[ids.size()];
res.init_on_cpu(tot_len, ids.size());
int64_t offset = 0, ind = 0;
for (int i = 0; i < task_pool_size_; i++) {
for (int j = 0; j < (int)node_array[i].size(); j++) {
res.node_list[ind] = node_array[i][j];
res.node_list[ind++].neighbor_offset += offset;
}
for (size_t j = 0; j < edge_array[i].size(); j++) {
res.neighbor_list[offset + j] = edge_array[i][j];
}
offset += edge_array[i].size();
}
return res;
}
int32_t GraphTable::add_node_to_ssd(
int type_id, int idx, int64_t src_id, char *data, int len) {
if (_db != NULL) {
char ch[sizeof(int) * 2 + sizeof(int64_t)];
memcpy(ch, &type_id, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
memcpy(ch + sizeof(int) * 2, &src_id, sizeof(int64_t));
std::string str;
if (_db->get(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(int64_t),
str) == 0) {
int64_t *stored_data = ((int64_t *)str.c_str());
int n = str.size() / sizeof(int64_t);
char *new_data = new char[n * sizeof(int64_t) + len];
memcpy(new_data, stored_data, n * sizeof(int64_t));
memcpy(new_data + n * sizeof(int64_t), data, len);
_db->put(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(int64_t),
(char *)new_data,
n * sizeof(int64_t) + len);
delete[] new_data;
} else {
_db->put(src_id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(int64_t),
(char *)data,
len);
}
// _db->flush(src_id % shard_num % task_pool_size_);
// std::string x;
// if (_db->get(src_id % shard_num % task_pool_size_, ch, sizeof(int64_t) +
// 2 * sizeof(int), x) ==0){
// VLOG(0)<<"put result";
// for(int i = 0;i < x.size();i+=8){
// VLOG(0)<<"get an id "<<*((int64_t *)(x.c_str() + i));
// }
//}
// if(src_id == 429){
// str = "";
// _db->get(src_id % shard_num % task_pool_size_, ch,
// sizeof(int) * 2 + sizeof(int64_t), str);
// int64_t *stored_data = ((int64_t *)str.c_str());
// int n = str.size() / sizeof(int64_t);
// VLOG(0)<<"429 has "<<n<<"neighbors";
// for(int i =0;i< n;i++){
// VLOG(0)<<"get an id "<<*((int64_t *)(str.c_str() +
// i*sizeof(int64_t)));
// }
// }
}
return 0;
}
char *GraphTable::random_sample_neighbor_from_ssd(
int idx,
int64_t id,
int sample_size,
const std::shared_ptr<std::mt19937_64> rng,
int &actual_size) {
if (_db == NULL) {
actual_size = 0;
return NULL;
}
std::string str;
VLOG(2) << "sample ssd for key " << id;
char ch[sizeof(int) * 2 + sizeof(int64_t)];
memset(ch, 0, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
memcpy(ch + sizeof(int) * 2, &id, sizeof(int64_t));
if (_db->get(id % shard_num % task_pool_size_,
ch,
sizeof(int) * 2 + sizeof(int64_t),
str) == 0) {
int64_t *data = ((int64_t *)str.c_str());
int n = str.size() / sizeof(int64_t);
std::unordered_map<int, int> m;
// std::vector<int64_t> res;
int sm_size = std::min(n, sample_size);
actual_size = sm_size * Node::id_size;
char *buff = new char[actual_size];
for (int i = 0; i < sm_size; i++) {
std::uniform_int_distribution<int> distrib(0, n - i - 1);
int t = distrib(*rng);
// int t = rand() % (n-i);
int pos = 0;
auto iter = m.find(t);
if (iter != m.end()) {
pos = iter->second;
} else {
pos = t;
}
auto iter2 = m.find(n - i - 1);
int key2 = iter2 == m.end() ? n - i - 1 : iter2->second;
m[t] = key2;
m.erase(n - i - 1);
memcpy(buff + i * Node::id_size, &data[pos], Node::id_size);
// res.push_back(data[pos]);
}
for (int i = 0; i < actual_size; i += 8) {
VLOG(2) << "sampled an neighbor " << *(int64_t *)&buff[i];
}
return buff;
}
actual_size = 0;
return NULL;
}
int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
std::vector<int64_t> &ids) {
std::vector<std::vector<int64_t>> bags(task_pool_size_);
for (auto x : ids) {
int location = x % shard_num % task_pool_size_;
bags[location].push_back(x);
}
std::vector<std::future<int>> tasks;
std::vector<int64_t> count(task_pool_size_, 0);
for (size_t i = 0; i < bags.size(); i++) {
if (bags[i].size() > 0) {
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, idx, this]() -> int {
char ch[sizeof(int) * 2 + sizeof(int64_t)];
memset(ch, 0, sizeof(int));
memcpy(ch + sizeof(int), &idx, sizeof(int));
for (size_t k = 0; k < bags[i].size(); k++) {
auto v = bags[i][k];
memcpy(ch + sizeof(int) * 2, &v, sizeof(int64_t));
std::string str;
if (_db->get(i, ch, sizeof(int) * 2 + sizeof(int64_t), str) == 0) {
count[i] += (int64_t)str.size();
for (int j = 0; j < str.size(); j += sizeof(int64_t)) {
int64_t id = *(int64_t *)(str.c_str() + j);
add_comm_edge(idx, v, id);
}
}
}
return 0;
}));
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
int64_t tot = 0;
for (auto x : count) tot += x;
return tot;
}
void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
  // Greedily partitions the SSD-resident edge lists of table `idx` into
  // groups whose serialized size fits byte_size * 0.8 * device_len, scoring
  // each candidate partition by how many neighbors of the node it already
  // holds, minus a fullness/weight penalty.
  VLOG(2) << "start to make graph partitions , byte_size = " << byte_size
          << " total memory cost = " << total_memory_cost;
  if (total_memory_cost == 0) {
    VLOG(0) << "no edges are detected,make partitions exits";
    return;
  }
  auto &weight_map = node_weight[0][idx];
  const double a = 2.0, y = 1.25, weight_param = 1.0;
  int64_t gb_size_by_discount = byte_size * 0.8 * device_len;
  if (gb_size_by_discount <= 0) gb_size_by_discount = 1;
  int part_len = total_memory_cost / gb_size_by_discount;
  if (part_len == 0) part_len = 1;
  VLOG(2) << "part_len = " << part_len
          << " byte size = " << gb_size_by_discount;
  partitions[idx].clear();
  partitions[idx].resize(part_len);
  std::vector<double> weight_cost(part_len, 0);
  std::vector<int64_t> memory_remaining(part_len, gb_size_by_discount);
  std::vector<double> score(part_len, 0);
  std::unordered_map<int64_t, int> id_map;
  std::vector<rocksdb::Iterator *> iters;
  for (int i = 0; i < task_pool_size_; i++) {
    iters.push_back(_db->get_iterator(i));
    iters[i]->SeekToFirst();
  }
  int next = 0;
  // Round-robin over the per-shard rocksdb iterators until all are drained.
  while (iters.size()) {
    if (next >= (int)iters.size()) {
      next = 0;
    }
    if (!iters[next]->Valid()) {
      // rocksdb iterators are heap-allocated and owned by the caller; free
      // the exhausted iterator before dropping the pointer to avoid a leak.
      delete iters[next];
      iters.erase(iters.begin() + next);
      continue;
    }
    // Key layout: [type_id(int)][edge_idx(int)][node_id(int64)].
    std::string key = iters[next]->key().ToString();
    int type_idx = *(int *)key.c_str();
    int temp_idx = *(int *)(key.c_str() + sizeof(int));
    if (type_idx != 0 || temp_idx != idx) {
      iters[next]->Next();
      next++;
      continue;
    }
    std::string value = iters[next]->value().ToString();
    std::int64_t i_key = *(int64_t *)(key.c_str() + sizeof(int) * 2);
    for (int i = 0; i < part_len; i++) {
      // Disqualify partitions that cannot hold this adjacency list.
      if (memory_remaining[i] < (int64_t)value.size()) {
        score[i] = -100000.0;
      } else {
        score[i] = 0;
      }
    }
    // Reward partitions that already contain neighbors of this node
    // (single lookup instead of find + operator[]).
    for (size_t j = 0; j < value.size(); j += sizeof(int64_t)) {
      int64_t v = *((int64_t *)(value.c_str() + j));
      auto pos = id_map.find(v);
      if (pos != id_map.end()) {
        score[pos->second]++;
      }
    }
    double base, weight_base = 0;
    double w = 0;
    bool has_weight = false;
    if (weight_map.find(i_key) != weight_map.end()) {
      w = weight_map[i_key];
      has_weight = true;
    }
    int index = 0;
    for (int i = 0; i < part_len; i++) {
      // Penalize fuller / heavier partitions: a * y * base^(y-1) grows with
      // bytes already assigned; weight_base with the accumulated node weight.
      base = gb_size_by_discount - memory_remaining[i] + value.size();
      if (has_weight)
        weight_base = weight_cost[i] + w * weight_param;
      else {
        weight_base = 0;
      }
      score[i] -= a * y * std::pow(1.0 * base, y - 1) + weight_base;
      if (score[i] > score[index]) index = i;
      VLOG(2) << "score" << i << " = " << score[i] << " memory left "
              << memory_remaining[i];
    }
    id_map[i_key] = index;
    partitions[idx][index].push_back(i_key);
    memory_remaining[index] -= (int64_t)value.size();
    if (has_weight) weight_cost[index] += w;
    iters[next]->Next();
    next++;
  }
  // Drop empty partitions so partition indices stay dense.
  for (int i = 0; i < part_len; i++) {
    if (partitions[idx][i].size() == 0) {
      partitions[idx].erase(partitions[idx].begin() + i);
      i--;
      part_len--;
      continue;
    }
    VLOG(2) << " partition " << i << " size = " << partitions[idx][i].size();
    for (auto x : partitions[idx][i]) {
      VLOG(2) << "find a id " << x;
    }
  }
  next_partition = 0;
}
void GraphTable::export_partition_files(int idx, std::string file_path) {
  // Writes each partition of edge table `idx` to its own text file
  // ("partition_<k>", one node id per line) under file_path.
  int partition_count = partitions[idx].size();
  if (partition_count == 0) return;
  if (file_path.empty()) file_path = ".";
  if (file_path.back() != '/') file_path += "/";
  std::vector<std::future<int>> tasks;
  for (int k = 0; k < partition_count; k++) {
    tasks.push_back(_shards_task_pool[k % task_pool_size_]->enqueue(
        [&, k, idx, this]() -> int {
          std::string output_path =
              file_path + "partition_" + std::to_string(k);
          std::ofstream ofs(output_path);
          if (ofs.fail()) {
            VLOG(0) << "creating " << output_path << " failed";
            return 0;
          }
          for (auto node_id : partitions[idx][k]) {
            ofs << node_id << "\n";
          }
          ofs.close();
          return 0;
        }));
  }
  for (auto &task : tasks) task.get();
}
void GraphTable::clear_graph(int idx) {
  // Drops every edge shard of table `idx` and replaces it with a fresh,
  // empty shard per server slot.
  auto &shards = edge_shards[idx];
  for (auto shard : shards) delete shard;
  shards.clear();
  shards.reserve(shard_num_per_server);
  for (size_t i = 0; i < shard_num_per_server; i++) {
    shards.push_back(new GraphShard());
  }
}
int32_t GraphTable::load_next_partition(int idx) {
  // Advances the partition cursor: rebuilds the in-memory graph of table
  // `idx` from the next SSD partition. Returns -1 once all partitions are
  // consumed.
  auto &parts = partitions[idx];
  if (next_partition >= parts.size()) {
    VLOG(0) << "partition iteration is done";
    return -1;
  }
  clear_graph(idx);
  load_graph_to_memory_from_ssd(idx, parts[next_partition]);
  ++next_partition;
  return 0;
}
int32_t GraphTable::load_edges_to_ssd(const std::string &path,
                                      bool reverse_edge,
                                      const std::string &edge_type) {
  // Parses edge files ("src\tdst1;dst2;..." per line) from the ';'-separated
  // path list and writes each adjacency list directly to SSD for the edge
  // table selected by edge_type (table 0 when edge_type is empty).
  // NOTE(review): reverse_edge is currently unused here — edges are stored
  // as given; confirm whether SSD loading should honor it like load_edges.
  int idx = 0;
  if (edge_type == "") {
    VLOG(0) << "edge_type not specified, loading edges to " << id_to_edge[0]
            << " part";
  } else {
    if (edge_to_id.find(edge_type) == edge_to_id.end()) {
      VLOG(0) << "edge_type " << edge_type
              << " is not defined, nothing will be loaded";
      return 0;
    }
    idx = edge_to_id[edge_type];
  }
  total_memory_cost = 0;
  auto paths = paddle::string::split_string<std::string>(path, ";");
  // Renamed from the parameter-shadowing `path`; dropped the locals that were
  // never read (count, sample_type, is_weighted, valid_count).
  for (auto &file_path : paths) {
    std::ifstream file(file_path);
    std::string line;
    while (std::getline(file, line)) {
      VLOG(0) << "get a line from file " << line;
      auto values = paddle::string::split_string<std::string>(line, "\t");
      if (values.size() < 2) continue;
      auto src_id = std::stoll(values[0]);
      auto dist_ids = paddle::string::split_string<std::string>(values[1], ";");
      std::vector<int64_t> dist_data;
      dist_data.reserve(dist_ids.size());
      for (auto &x : dist_ids) {
        dist_data.push_back(std::stoll(x));
        total_memory_cost += sizeof(int64_t);
      }
      add_node_to_ssd(0,
                      idx,
                      src_id,
                      (char *)dist_data.data(),
                      (int)(dist_data.size() * sizeof(int64_t)));
    }
  }
  VLOG(0) << "total memory cost = " << total_memory_cost << " bytes";
  return 0;
}
int32_t GraphTable::dump_edges_to_ssd(int idx) {
  // Persists every in-memory adjacency list of edge table `idx` to the SSD
  // store, one task per shard, and adds the serialized byte count to
  // total_memory_cost. Dropped the locals the original never read
  // (fixed_size, count, ind).
  VLOG(2) << "calling dump edges to ssd";
  std::vector<std::future<int64_t>> tasks;
  auto &shards = edge_shards[idx];
  for (size_t i = 0; i < shards.size(); ++i) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&, i, this]() -> int64_t {
          int64_t cost = 0;
          std::vector<Node *> &v = shards[i]->get_bucket();
          for (size_t j = 0; j < v.size(); j++) {
            // Serialize the neighbor ids as a packed int64 array.
            std::vector<int64_t> s;
            for (size_t k = 0; k < v[j]->get_neighbor_size(); k++) {
              s.push_back(v[j]->get_neighbor_id(k));
            }
            cost += v[j]->get_neighbor_size() * sizeof(int64_t);
            add_node_to_ssd(0,
                            idx,
                            v[j]->get_id(),
                            (char *)s.data(),
                            s.size() * sizeof(int64_t));
          }
          return cost;
        }));
  }
  for (size_t i = 0; i < tasks.size(); i++) total_memory_cost += tasks[i].get();
  return 0;
}
// Builds a "complementary" in-memory graph of at most byte_size bytes:
// counts how often each node id appears as a neighbor across all shards,
// then loads the most-referenced nodes' adjacency lists back from SSD
// (batched by fixed_size ids) until the byte budget is spent, and finally
// rebuilds random samplers for the loaded nodes.
int32_t GraphTable::make_complementary_graph(int idx, int64_t byte_size) {
  VLOG(0) << "make_complementary_graph";
  const int64_t fixed_size = byte_size / 8;
  // std::vector<int64_t> edge_array[task_pool_size_];
  // Per-thread neighbor-frequency maps, merged after the parallel pass.
  std::vector<std::unordered_map<int64_t, int>> count(task_pool_size_);
  std::vector<std::future<int>> tasks;
  auto &shards = edge_shards[idx];
  for (size_t i = 0; i < shards.size(); ++i) {
    tasks.push_back(
        _shards_task_pool[i % task_pool_size_]->enqueue([&, i, this]() -> int {
          std::vector<Node *> &v = shards[i]->get_bucket();
          size_t ind = i % this->task_pool_size_;
          for (size_t j = 0; j < v.size(); j++) {
            // size_t location = v[j]->get_id();
            for (int k = 0; k < v[j]->get_neighbor_size(); k++) {
              count[ind][v[j]->get_neighbor_id(k)]++;
            }
          }
          return 0;
        }));
  }
  for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
  std::unordered_map<int64_t, int> final_count;
  std::map<int, std::vector<int64_t>> count_to_id;
  std::vector<int64_t> buffer;
  // The current in-memory graph is discarded before reloading from SSD.
  clear_graph(idx);
  // Merge the per-thread frequency maps into one.
  for (int i = 0; i < task_pool_size_; i++) {
    for (auto &p : count[i]) {
      final_count[p.first] = final_count[p.first] + p.second;
    }
    count[i].clear();
  }
  // Bucket node ids by frequency so we can walk them most-frequent first.
  for (auto &p : final_count) {
    count_to_id[p.second].push_back(p.first);
    VLOG(2) << p.first << " appear " << p.second << " times";
  }
  // rbegin(): highest appearance count first; stop once byte_size is spent.
  auto iter = count_to_id.rbegin();
  while (iter != count_to_id.rend() && byte_size > 0) {
    for (auto x : iter->second) {
      buffer.push_back(x);
      if (buffer.size() >= fixed_size) {
        int64_t res = load_graph_to_memory_from_ssd(idx, buffer);
        buffer.clear();
        byte_size -= res;
      }
      if (byte_size <= 0) break;
    }
    iter++;
  }
  // Flush the final partial batch if there is budget left.
  if (byte_size > 0 && buffer.size() > 0) {
    int64_t res = load_graph_to_memory_from_ssd(idx, buffer);
    byte_size -= res;
  }
  std::string sample_type = "random";
  for (auto &shard : edge_shards[idx]) {
    auto bucket = shard->get_bucket();
    for (size_t i = 0; i < bucket.size(); i++) {
      bucket[i]->build_sampler(sample_type);
    }
  }
  return 0;
}
#endif
/*
int CompleteGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
std::cout << "in graph sampling" << std::endl;
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
paddle::framework::GpuPsGraphNode node;
std::vector<Node *> &v =
this->graph_table->shards[i]->get_bucket();
size_t ind = i % this->graph_table->task_pool_size_;
for (size_t j = 0; j < v.size(); j++) {
size_t location = v[j]->get_id() % this->gpu_num;
node.node_id = v[j]->get_id();
node.neighbor_size = v[j]->get_neighbor_size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (int k = 0; k < node.neighbor_size; k++)
sample_neighbors_ex[ind][location].push_back(
v[j]->get_neighbor_id(k));
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
tasks.clear();
for (int i = 0; i < gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) return 0;
int total_offset = 0;
size_t ind = i % this->graph_table->task_pool_size_;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][ind].size(); k++) {
sample_nodes[ind].push_back(sample_nodes_ex[j][ind][k]);
sample_nodes[ind].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][ind].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[ind].push_back(
sample_neighbors_ex[j][ind][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
return 0;
}
void CompleteGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
}
int BasicBfsGraphSampler::run_graph_sampling() {
pthread_rwlock_t *rw_lock = graph_table->rw_lock.get();
pthread_rwlock_rdlock(rw_lock);
while (rounds > 0 && status == GraphSamplerStatus::running) {
for (size_t i = 0; i < sample_neighbors_map.size(); i++) {
sample_neighbors_map[i].clear();
}
sample_neighbors_map.clear();
std::vector<int> nodes_left(graph_table->shards.size(),
node_num_for_each_shard);
std::promise<int> prom;
std::future<int> fut = prom.get_future();
sample_neighbors_map.resize(graph_table->task_pool_size_);
int task_size = 0;
std::vector<std::future<int>> tasks;
int init_size = 0;
//__sync_fetch_and_add
std::function<int(int, int64_t)> bfs = [&, this](int i, int id) -> int {
if (this->status == GraphSamplerStatus::terminating) {
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
}
size_t ind = i % this->graph_table->task_pool_size_;
if (nodes_left[i] > 0) {
auto iter = sample_neighbors_map[ind].find(id);
if (iter == sample_neighbors_map[ind].end()) {
Node *node = graph_table->shards[i]->find_node(id);
if (node != NULL) {
nodes_left[i]--;
sample_neighbors_map[ind][id] = std::vector<int64_t>();
iter = sample_neighbors_map[ind].find(id);
size_t edge_fetch_size =
std::min((size_t) this->edge_num_for_each_node,
node->get_neighbor_size());
for (size_t k = 0; k < edge_fetch_size; k++) {
int64_t neighbor_id = node->get_neighbor_id(k);
int node_location = neighbor_id % this->graph_table->shard_num %
this->graph_table->task_pool_size_;
__sync_add_and_fetch(&task_size, 1);
graph_table->_shards_task_pool[node_location]->enqueue(
bfs, neighbor_id % this->graph_table->shard_num, neighbor_id);
iter->second.push_back(neighbor_id);
}
}
}
}
int task_left = __sync_sub_and_fetch(&task_size, 1);
if (task_left == 0) {
prom.set_value(0);
}
return 0;
};
for (size_t i = 0; i < graph_table->shards.size(); ++i) {
std::vector<Node *> &v = graph_table->shards[i]->get_bucket();
if (v.size() > 0) {
int search_size = std::min(init_search_size, (int)v.size());
for (int k = 0; k < search_size; k++) {
init_size++;
__sync_add_and_fetch(&task_size, 1);
int64_t id = v[k]->get_id();
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue(bfs, i, id);
}
} // if
}
if (init_size == 0) {
prom.set_value(0);
}
fut.get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
VLOG(0) << "BasicBfsGraphSampler finishes the graph searching task";
sample_nodes.clear();
sample_neighbors.clear();
sample_res.clear();
sample_nodes.resize(gpu_num);
sample_neighbors.resize(gpu_num);
sample_res.resize(gpu_num);
std::vector<std::vector<std::vector<paddle::framework::GpuPsGraphNode>>>
sample_nodes_ex(graph_table->task_pool_size_);
std::vector<std::vector<std::vector<int64_t>>> sample_neighbors_ex(
graph_table->task_pool_size_);
for (int i = 0; i < graph_table->task_pool_size_; i++) {
sample_nodes_ex[i].resize(gpu_num);
sample_neighbors_ex[i].resize(gpu_num);
}
tasks.clear();
for (size_t i = 0; i < (size_t)graph_table->task_pool_size_; ++i) {
tasks.push_back(
graph_table->_shards_task_pool[i]->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
paddle::framework::GpuPsGraphNode node;
auto iter = sample_neighbors_map[i].begin();
size_t ind = i;
for (; iter != sample_neighbors_map[i].end(); iter++) {
size_t location = iter->first % this->gpu_num;
node.node_id = iter->first;
node.neighbor_size = iter->second.size();
node.neighbor_offset =
(int)sample_neighbors_ex[ind][location].size();
sample_nodes_ex[ind][location].emplace_back(node);
for (auto k : iter->second)
sample_neighbors_ex[ind][location].push_back(k);
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) {
tasks[i].get();
sample_neighbors_map[i].clear();
}
tasks.clear();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (size_t i = 0; i < (size_t)gpu_num; i++) {
tasks.push_back(
graph_table->_shards_task_pool[i % graph_table->task_pool_size_]
->enqueue([&, i, this]() -> int {
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
int total_offset = 0;
for (int j = 0; j < this->graph_table->task_pool_size_; j++) {
for (size_t k = 0; k < sample_nodes_ex[j][i].size(); k++) {
sample_nodes[i].push_back(sample_nodes_ex[j][i][k]);
sample_nodes[i].back().neighbor_offset += total_offset;
}
size_t neighbor_size = sample_neighbors_ex[j][i].size();
total_offset += neighbor_size;
for (size_t k = 0; k < neighbor_size; k++) {
sample_neighbors[i].push_back(sample_neighbors_ex[j][i][k]);
}
}
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
if (this->status == GraphSamplerStatus::terminating) {
pthread_rwlock_unlock(rw_lock);
return 0;
}
for (int i = 0; i < gpu_num; i++) {
sample_res[i].node_list = sample_nodes[i].data();
sample_res[i].neighbor_list = sample_neighbors[i].data();
sample_res[i].node_size = sample_nodes[i].size();
sample_res[i].neighbor_size = sample_neighbors[i].size();
}
pthread_rwlock_unlock(rw_lock);
if (this->status == GraphSamplerStatus::terminating) {
return 0;
}
callback(sample_res);
rounds--;
if (rounds > 0) {
for (int i = 0;
i < interval && this->status == GraphSamplerStatus::running; i++) {
std::this_thread::sleep_for(std::chrono::seconds(1));
}
}
VLOG(0)<<"bfs returning";
}
return 0;
}
void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) {
this->gpu_num = gpu_num;
this->graph_table = graph_table;
init_search_size = args.size() > 0 ? std::stoi(args[0]) : 10;
node_num_for_each_shard = args.size() > 1 ? std::stoi(args[1]) : 10;
edge_num_for_each_node = args.size() > 2 ? std::stoi(args[2]) : 10;
rounds = args.size() > 3 ? std::stoi(args[3]) : 1;
interval = args.size() > 4 ? std::stoi(args[4]) : 60;
}
#endif
*/
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<Node *> res;
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
return res;
}
// Number of nodes currently stored in this shard.
size_t GraphShard::get_size() { return bucket.size(); }
int32_t GraphTable::add_comm_edge(int idx, int64_t src_id, int64_t dst_id) {
  // Inserts edge src->dst (unweighted, weight 1.0) into edge table `idx`;
  // returns -1 when src does not belong to this server's shard range.
  size_t src_shard_id = src_id % shard_num;
  if (src_shard_id < shard_start || src_shard_id >= shard_end) {
    return -1;
  }
  auto *shard = edge_shards[idx][src_shard_id - shard_start];
  shard->add_graph_node(src_id)->build_edges(false);
  shard->add_neighbor(src_id, dst_id, 1.0);
  return 0;
}
int32_t GraphTable::add_graph_node(int idx,
                                   std::vector<int64_t> &id_list,
                                   std::vector<bool> &is_weight_list) {
  // Adds the given node ids to edge table `idx`, grouped and executed in
  // parallel per thread-pool slot. is_weight_list[i] (false when absent)
  // selects weighted edge storage; ids outside this server's shard range
  // are silently skipped.
  auto &shards = edge_shards[idx];
  std::vector<std::vector<std::pair<int64_t, bool>>> grouped(task_pool_size_);
  for (size_t i = 0; i < id_list.size(); i++) {
    size_t shard_id = id_list[i] % shard_num;
    if (shard_id < shard_start && true) {
      // fallthrough to range check below
    }
    if (shard_id < shard_start || shard_id >= shard_end) {
      continue;
    }
    bool weighted = i < is_weight_list.size() ? is_weight_list[i] : false;
    grouped[get_thread_pool_index(id_list[i])].emplace_back(id_list[i],
                                                            weighted);
  }
  std::vector<std::future<int>> tasks;
  for (size_t i = 0; i < grouped.size(); ++i) {
    if (grouped[i].empty()) continue;
    tasks.push_back(
        _shards_task_pool[i]->enqueue([&shards, &grouped, i, this]() -> int {
          for (auto &entry : grouped[i]) {
            size_t index = entry.first % this->shard_num - this->shard_start;
            shards[index]->add_graph_node(entry.first)->build_edges(
                entry.second);
          }
          return 0;
        }));
  }
  for (auto &t : tasks) t.get();
  return 0;
}
int32_t GraphTable::remove_graph_node(int idx, std::vector<int64_t> &id_list) {
  // Deletes the listed nodes from edge table `idx`, one task per thread-pool
  // slot; ids outside this server's shard range are silently ignored.
  std::vector<std::vector<int64_t>> grouped(task_pool_size_);
  for (auto node_id : id_list) {
    size_t shard_id = node_id % shard_num;
    if (shard_id < shard_start || shard_id >= shard_end) continue;
    grouped[get_thread_pool_index(node_id)].push_back(node_id);
  }
  auto &shards = edge_shards[idx];
  std::vector<std::future<int>> tasks;
  for (size_t i = 0; i < grouped.size(); ++i) {
    if (grouped[i].empty()) continue;
    tasks.push_back(
        _shards_task_pool[i]->enqueue([&shards, &grouped, i, this]() -> int {
          for (auto node_id : grouped[i]) {
            size_t index = node_id % this->shard_num - this->shard_start;
            shards[index]->delete_node(node_id);
          }
          return 0;
        }));
  }
  for (auto &t : tasks) t.get();
  return 0;
}
void GraphShard::clear() {
for (size_t i = 0; i < bucket.size(); i++) {
delete bucket[i];
}
bucket.clear();
node_location.clear();
}
// The shard owns its Node pointers; release them on destruction.
GraphShard::~GraphShard() { clear(); }
// Removes node `id` from the shard with a swap-with-last so the bucket stays
// dense; node_location is updated for the element that was moved. No-op when
// the id is absent.
void GraphShard::delete_node(int64_t id) {
  auto iter = node_location.find(id);
  if (iter == node_location.end()) return;
  int pos = iter->second;
  delete bucket[pos];
  // Move the last node into the vacated slot (unless we deleted the last one).
  if (pos != (int)bucket.size() - 1) {
    bucket[pos] = bucket.back();
    node_location[bucket.back()->get_id()] = pos;
  }
  node_location.erase(id);
  bucket.pop_back();
}
GraphNode *GraphShard::add_graph_node(int64_t id) {
  // Returns the existing node for `id`, creating a new GraphNode if absent.
  auto iter = node_location.find(id);
  if (iter != node_location.end()) {
    return (GraphNode *)bucket[iter->second];
  }
  node_location[id] = bucket.size();
  bucket.push_back(new GraphNode(id));
  return (GraphNode *)bucket.back();
}
GraphNode *GraphShard::add_graph_node(Node *node) {
  // Inserts an externally-created node; when the id already exists the
  // stored node is returned and `node` is NOT adopted (caller keeps it).
  auto id = node->get_id();
  auto iter = node_location.find(id);
  if (iter != node_location.end()) {
    return (GraphNode *)bucket[iter->second];
  }
  node_location[id] = bucket.size();
  bucket.push_back(node);
  return (GraphNode *)bucket.back();
}
FeatureNode *GraphShard::add_feature_node(int64_t id) {
  // Returns the feature node for `id`, creating one if it is not present.
  auto iter = node_location.find(id);
  if (iter != node_location.end()) {
    return (FeatureNode *)bucket[iter->second];
  }
  node_location[id] = bucket.size();
  bucket.push_back(new FeatureNode(id));
  return (FeatureNode *)bucket.back();
}
// Appends the edge (dst_id, weight) to node `id`.
// Precondition: `id` must already be in this shard (callers add it via
// add_graph_node first) — find_node() is dereferenced without a null check.
void GraphShard::add_neighbor(int64_t id, int64_t dst_id, float weight) {
  find_node(id)->add_edge(dst_id, weight);
}
Node *GraphShard::find_node(int64_t id) {
  // Looks the node up by id; nullptr when it is not in this shard.
  auto iter = node_location.find(id);
  if (iter == node_location.end()) {
    return nullptr;
  }
  return bucket[iter->second];
}
GraphTable::~GraphTable() {
  // The table owns every shard pointer in both the edge and feature tables.
  for (auto &table : edge_shards) {
    for (auto *shard : table) delete shard;
    table.clear();
  }
  for (auto &table : feature_shards) {
    for (auto *shard : table) delete shard;
    table.clear();
  }
}
int32_t GraphTable::Load(const std::string &path, const std::string &param) {
  // Dispatches a load request by param prefix:
  //   "e<type" / "e>type" -> load edges (reversed when '<'),
  //   "ntype"             -> load nodes of the given type.
  // Anything else is a no-op returning 0.
  if (param.empty()) {
    // Guard: param[0] below would otherwise read the terminating '\0'.
    return 0;
  }
  if (param[0] == 'e') {
    // Length guards: the original substr(2) threw std::out_of_range for a
    // bare "e"; a missing direction char now means forward, empty edge type.
    bool reverse_edge = param.size() > 1 && param[1] == '<';
    std::string edge_type = param.size() > 2 ? param.substr(2) : "";
    return this->load_edges(path, reverse_edge, edge_type);
  }
  if (param[0] == 'n') {
    std::string node_type = param.substr(1);
    return this->load_nodes(path, node_type);
  }
  return 0;
}
// Collects into `res` the node ids at the global positions covered by
// `ranges` (half-open [first, second) intervals over the concatenation of
// all local shards of the edge (type_id == 0) or feature table `idx`).
// Each intersecting shard sub-range is fetched in parallel; results are
// appended in a lightly shuffled order (rand()-based swap per id).
// Precondition (from the sequential scan below): `ranges` is expected to be
// sorted and non-overlapping — TODO confirm with callers.
int32_t GraphTable::get_nodes_ids_by_ranges(
    int type_id,
    int idx,
    std::vector<std::pair<int, int>> ranges,
    std::vector<int64_t> &res) {
  int start = 0, end, index = 0, total_size = 0;
  res.clear();
  auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<std::vector<int64_t>>> tasks;
  for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
    // [start, end) is shard i's slice of the global position space.
    end = total_size + shards[i]->get_size();
    start = total_size;
    while (start < end && index < (int)ranges.size()) {
      if (ranges[index].second <= start)
        index++;  // range entirely before this shard: move to next range
      else if (ranges[index].first >= end) {
        break;  // range entirely after this shard: move to next shard
      } else {
        // Clip the range to this shard, convert to shard-local positions,
        // and fetch that slice asynchronously.
        int first = std::max(ranges[index].first, start);
        int second = std::min(ranges[index].second, end);
        start = second;
        first -= total_size;
        second -= total_size;
        tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
            [&shards, this, first, second, i]() -> std::vector<int64_t> {
              return shards[i]->get_ids_by_range(first, second);
            }));
      }
    }
    total_size += shards[i]->get_size();
  }
  for (size_t i = 0; i < tasks.size(); i++) {
    auto vec = tasks[i].get();
    for (auto &id : vec) {
      // Append, then swap with a random earlier element: cheap shuffle.
      res.push_back(id);
      std::swap(res[rand() % res.size()], res[(int)res.size() - 1]);
    }
  }
  return 0;
}
// Loads feature nodes from ';'-separated files of tab-separated lines:
//   node_type \t node_id \t feature1 \t feature2 ...
// Only lines whose first column equals `node_type` and whose id hashes into
// this server's shard range are kept.
// NOTE(review): when node_type is "" the code logs that it will load into
// feature table 0, but the `nt != node_type` filter below then skips every
// line whose first column is non-empty — confirm this is intended upstream.
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
  auto paths = paddle::string::split_string<std::string>(path, ";");
  int64_t count = 0;
  int64_t valid_count = 0;
  int idx = 0;
  if (node_type == "") {
    VLOG(0) << "node_type not specified, loading edges to " << id_to_feature[0]
            << " part";
  } else {
    if (feature_to_id.find(node_type) == feature_to_id.end()) {
      VLOG(0) << "node_type " << node_type
              << " is not defined, nothing will be loaded";
      return 0;
    }
    idx = feature_to_id[node_type];
  }
  for (auto path : paths) {
    std::ifstream file(path);
    std::string line;
    while (std::getline(file, line)) {
      auto values = paddle::string::split_string<std::string>(line, "\t");
      if (values.size() < 2) continue;
      auto id = std::stoull(values[1]);
      size_t shard_id = id % shard_num;
      if (shard_id >= shard_end || shard_id < shard_start) {
        VLOG(4) << "will not load " << id << " from " << path
                << ", please check id distribution";
        continue;
      }
      // Progress log every million accepted-shard lines.
      if (count % 1000000 == 0) {
        VLOG(0) << count << " nodes are loaded from filepath";
        VLOG(0) << line;
      }
      count++;
      std::string nt = values[0];
      if (nt != node_type) {
        continue;
      }
      size_t index = shard_id - shard_start;
      // auto node = shards[index]->add_feature_node(id);
      auto node = feature_shards[idx][index]->add_feature_node(id);
      node->set_feature_size(feat_name[idx].size());
      // Columns 2.. are individual feature strings parsed into slot values.
      for (size_t slice = 2; slice < values.size(); slice++) {
        auto feat = this->parse_feature(idx, values[slice]);
        if (feat.first >= 0) {
          node->set_feature(feat.first, feat.second);
        } else {
          VLOG(4) << "Node feature: " << values[slice]
                  << " not in feature_map.";
        }
      }
      valid_count++;
    }
  }
  VLOG(0) << valid_count << "/" << count << " nodes in type " << node_type
          << " are loaded successfully in " << path;
  return 0;
}
int32_t GraphTable::build_sampler(int idx, std::string sample_type) {
  // Rebuilds each node's neighbor sampler (e.g. "random" / "weighted") for
  // every shard of edge table `idx`.
  for (auto &shard : edge_shards[idx]) {
    // Bind by reference: get_bucket() returns the shard's vector (see its use
    // as std::vector<Node *> & elsewhere in this file); `auto bucket` copied
    // the whole vector per shard for no benefit.
    auto &bucket = shard->get_bucket();
    for (size_t i = 0; i < bucket.size(); i++) {
      bucket[i]->build_sampler(sample_type);
    }
  }
  return 0;
}
// Loads tab-separated edge files ("src\tdst[\tweight]") from the
// ';'-separated path list into the in-memory shards of the edge table
// selected by edge_type (table 0 when edge_type is empty). A third column
// switches the whole load to weighted sampling. With PADDLE_WITH_HETERPS and
// search_level == 2, edges are periodically spilled to SSD instead of being
// kept in memory, and no samplers are built.
int32_t GraphTable::load_edges(const std::string &path,
                               bool reverse_edge,
                               const std::string &edge_type) {
#ifdef PADDLE_WITH_HETERPS
  // if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
  if (search_level == 2) total_memory_cost = 0;
  // Spill threshold: dump to SSD every million in-memory edges.
  const int64_t fixed_load_edges = 1000000;
#endif
  int idx = 0;
  if (edge_type == "") {
    VLOG(0) << "edge_type not specified, loading edges to " << id_to_edge[0]
            << " part";
  } else {
    if (edge_to_id.find(edge_type) == edge_to_id.end()) {
      VLOG(0) << "edge_type " << edge_type
              << " is not defined, nothing will be loaded";
      return 0;
    }
    idx = edge_to_id[edge_type];
  }
  auto paths = paddle::string::split_string<std::string>(path, ";");
  int64_t count = 0;
  std::string sample_type = "random";
  bool is_weighted = false;
  int valid_count = 0;
  for (auto path : paths) {
    std::ifstream file(path);
    std::string line;
    while (std::getline(file, line)) {
      auto values = paddle::string::split_string<std::string>(line, "\t");
      count++;
      if (values.size() < 2) continue;
      auto src_id = std::stoull(values[0]);
      auto dst_id = std::stoull(values[1]);
      if (reverse_edge) {
        std::swap(src_id, dst_id);
      }
      float weight = 1;
      // An optional third column carries the edge weight and flips the
      // sampler type for the whole table to "weighted".
      if (values.size() == 3) {
        weight = std::stof(values[2]);
        sample_type = "weighted";
        is_weighted = true;
      }
      size_t src_shard_id = src_id % shard_num;
      if (src_shard_id >= shard_end || src_shard_id < shard_start) {
        VLOG(4) << "will not load " << src_id << " from " << path
                << ", please check id distribution";
        continue;
      }
      if (count % 1000000 == 0) {
        VLOG(0) << count << " edges are loaded from filepath";
        VLOG(0) << line;
      }
      size_t index = src_shard_id - shard_start;
      edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
      edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
      valid_count++;
#ifdef PADDLE_WITH_HETERPS
      // if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
      // SSD mode: flush the in-memory batch once it exceeds the threshold.
      if (count > fixed_load_edges && search_level == 2) {
        dump_edges_to_ssd(idx);
        VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
        clear_graph(idx);
        count = 0;
      }
#endif
    }
  }
  VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in "
          << path;
  // Build Sampler j
#ifdef PADDLE_WITH_HETERPS
  // if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
  // SSD mode: flush the remaining edges and skip sampler construction.
  if (search_level == 2) {
    if (count > 0) {
      dump_edges_to_ssd(idx);
      VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
      clear_graph(idx);
      count = 0;
    }
    return 0;
  }
#endif
  for (auto &shard : edge_shards[idx]) {
    auto bucket = shard->get_bucket();
    for (size_t i = 0; i < bucket.size(); i++) {
      bucket[i]->build_sampler(sample_type);
    }
  }
  return 0;
}
Node *GraphTable::find_node(int type_id, int idx, int64_t id) {
  // Locates a node in the local shards of the edge (type_id == 0) or feature
  // table `idx`; nullptr when the id is not served by this instance.
  size_t shard_id = id % shard_num;
  if (shard_id < shard_start || shard_id >= shard_end) {
    return nullptr;
  }
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  return search_shards[shard_id - shard_start]->find_node(id);
}
// Maps a node id to the worker-thread slot that owns its shard.
uint32_t GraphTable::get_thread_pool_index(int64_t node_id) {
  return node_id % shard_num % shard_num_per_server % task_pool_size_;
}
// Maps a local shard index to the worker-thread slot responsible for it.
uint32_t GraphTable::get_thread_pool_index_by_shard_index(int64_t shard_index) {
  return shard_index % shard_num_per_server % task_pool_size_;
}
int32_t GraphTable::clear_nodes(int type_id, int idx) {
  // Empties every local shard of the selected edge (type_id == 0) or
  // feature table.
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  for (auto *shard : search_shards) {
    shard->clear();
  }
  return 0;
}
// Samples `sample_size` node ids (without replacement) from the chosen table
// by constructing random_sample_nodes_ranges contiguous index ranges whose
// lengths sum to sample_size, rotating them by a random start offset, and
// fetching the covered positions via get_nodes_ids_by_ranges. The packed
// int64 ids are returned in `buffer` with their byte count in `actual_size`.
int32_t GraphTable::random_sample_nodes(int type_id,
                                        int idx,
                                        int sample_size,
                                        std::unique_ptr<char[]> &buffer,
                                        int &actual_size) {
  int total_size = 0;
  auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  for (int i = 0; i < (int)shards.size(); i++) {
    total_size += shards[i]->get_size();
  }
  if (sample_size > total_size) sample_size = total_size;
  int range_num = random_sample_nodes_ranges;
  if (range_num > sample_size) range_num = sample_size;
  if (sample_size == 0 || range_num == 0) return 0;
  std::vector<int> ranges_len, ranges_pos;
  int remain = sample_size, last_pos = -1, num;
  std::set<int> separator_set;
  // Pick range_num - 1 distinct cut points to split sample_size into
  // range_num positive segment lengths.
  for (int i = 0; i < range_num - 1; i++) {
    while (separator_set.find(num = rand() % (sample_size - 1)) !=
           separator_set.end())
      ;
    separator_set.insert(num);
  }
  for (auto p : separator_set) {
    ranges_len.push_back(p - last_pos);
    last_pos = p;
  }
  ranges_len.push_back(sample_size - 1 - last_pos);
  // Now pick where each segment starts inside the unsampled remainder.
  remain = total_size - sample_size + range_num;
  separator_set.clear();
  for (int i = 0; i < range_num; i++) {
    while (separator_set.find(num = rand() % remain) != separator_set.end())
      ;
    separator_set.insert(num);
  }
  int used = 0, index = 0;
  last_pos = -1;
  // Convert the gap positions into absolute start offsets for each segment.
  for (auto p : separator_set) {
    used += p - last_pos - 1;
    last_pos = p;
    ranges_pos.push_back(used);
    used += ranges_len[index++];
  }
  // Rotate all ranges by a random start index, splitting any range that
  // wraps around total_size into a tail piece and a head piece.
  std::vector<std::pair<int, int>> first_half, second_half;
  int start_index = rand() % total_size;
  for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
    if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
      first_half.push_back({ranges_pos[i] + start_index,
                            ranges_pos[i] + ranges_len[i] + start_index});
    else if (ranges_pos[i] + start_index >= total_size) {
      second_half.push_back(
          {ranges_pos[i] + start_index - total_size,
           ranges_pos[i] + ranges_len[i] + start_index - total_size});
    } else {
      first_half.push_back({ranges_pos[i] + start_index, total_size});
      second_half.push_back(
          {0, ranges_pos[i] + ranges_len[i] + start_index - total_size});
    }
  }
  for (auto &pair : first_half) second_half.push_back(pair);
  std::vector<int64_t> res;
  get_nodes_ids_by_ranges(type_id, idx, second_half, res);
  actual_size = res.size() * sizeof(int64_t);
  buffer.reset(new char[actual_size]);
  char *pointer = buffer.get();
  memcpy(pointer, res.data(), actual_size);
  return 0;
}
// For each node id, samples up to `sample_size` neighbors from edge table
// `idx` and packs them into per-node buffers:
//   [id(int64)][weight(float)?] repeated, weight present iff need_weight.
// actual_sizes[i] receives the byte count for node i (0 when the node is
// missing). Work is grouped per thread-pool slot; results may be served
// from / inserted into the scaled LRU cache, and (HETERPS, search_level 2)
// missing nodes fall back to sampling straight from SSD.
int32_t GraphTable::random_sample_neighbors(
    int idx,
    int64_t *node_ids,
    int sample_size,
    std::vector<std::shared_ptr<char>> &buffers,
    std::vector<int> &actual_sizes,
    bool need_weight) {
  size_t node_num = buffers.size();
  std::function<void(char *)> char_del = [](char *c) { delete[] c; };
  std::vector<std::future<int>> tasks;
  // seq_id remembers each request's original slot so results can be written
  // back to the right buffers/actual_sizes entry.
  std::vector<std::vector<uint32_t>> seq_id(task_pool_size_);
  std::vector<std::vector<SampleKey>> id_list(task_pool_size_);
  size_t index;
  for (size_t idy = 0; idy < node_num; ++idy) {
    index = get_thread_pool_index(node_ids[idy]);
    seq_id[index].emplace_back(idy);
    id_list[index].emplace_back(idx, node_ids[idy], sample_size, need_weight);
  }
  for (int i = 0; i < (int)seq_id.size(); i++) {
    if (seq_id[i].size() == 0) continue;
    tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
      int64_t node_id;
      std::vector<std::pair<SampleKey, SampleResult>> r;
      LRUResponse response = LRUResponse::blocked;
      // Try to satisfy requests from the LRU cache first; `r` comes back
      // with hits matching id_list's order.
      if (use_cache) {
        response =
            scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
      }
      int index = 0;
      std::vector<SampleResult> sample_res;
      std::vector<SampleKey> sample_keys;
      auto &rng = _shards_task_rng_pool[i];
      for (size_t k = 0; k < id_list[i].size(); k++) {
        if (index < (int)r.size() &&
            r[index].first.node_key == id_list[i][k].node_key) {
          // Cache hit: reuse the cached buffer directly.
          int idy = seq_id[i][k];
          actual_sizes[idy] = r[index].second.actual_size;
          buffers[idy] = r[index].second.buffer;
          index++;
        } else {
          node_id = id_list[i][k].node_key;
          Node *node = find_node(0, idx, node_id);
          int idy = seq_id[i][k];
          int &actual_size = actual_sizes[idy];
          if (node == nullptr) {
#ifdef PADDLE_WITH_HETERPS
            // SSD mode: the node may only exist on disk.
            if (search_level == 2) {
              VLOG(2) << "enter sample from ssd for node_id " << node_id;
              char *buffer_addr = random_sample_neighbor_from_ssd(
                  idx, node_id, sample_size, rng, actual_size);
              if (actual_size != 0) {
                std::shared_ptr<char> &buffer = buffers[idy];
                buffer.reset(buffer_addr, char_del);
              }
              VLOG(2) << "actual sampled size from ssd = " << actual_sizes[idy];
              continue;
            }
#endif
            actual_size = 0;
            continue;
          }
          std::shared_ptr<char> &buffer = buffers[idy];
          std::vector<int> res = node->sample_k(sample_size, rng);
          actual_size =
              res.size() * (need_weight ? (Node::id_size + Node::weight_size)
                                        : Node::id_size);
          int offset = 0;
          int64_t id;
          float weight;
          char *buffer_addr = new char[actual_size];
          if (response == LRUResponse::ok) {
            // Cache is usable: stage the freshly sampled result for insert.
            sample_keys.emplace_back(idx, node_id, sample_size, need_weight);
            sample_res.emplace_back(actual_size, buffer_addr);
            buffer = sample_res.back().buffer;
          } else {
            buffer.reset(buffer_addr, char_del);
          }
          // Pack [id][weight?] pairs for each sampled neighbor index.
          for (int &x : res) {
            id = node->get_neighbor_id(x);
            memcpy(buffer_addr + offset, &id, Node::id_size);
            offset += Node::id_size;
            if (need_weight) {
              weight = node->get_neighbor_weight(x);
              memcpy(buffer_addr + offset, &weight, Node::weight_size);
              offset += Node::weight_size;
            }
          }
        }
      }
      if (sample_res.size()) {
        scaled_lru->insert(
            i, sample_keys.data(), sample_res.data(), sample_keys.size());
      }
      return 0;
    }));
  }
  for (auto &t : tasks) {
    t.get();
  }
  return 0;
}
int32_t GraphTable::get_node_feat(int idx,
const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idy = 0; idy < node_num; ++idy) {
int64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
Node *node = find_node(1, idx, node_id);
if (node == nullptr) {
return 0;
}
for (int feat_idx = 0; feat_idx < (int)feature_names.size();
++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
// res[feat_idx][idx] =
// node->get_feature(feat_id_map[feature_name]);
auto feat = node->get_feature(feat_id_map[idx][feature_name]);
res[feat_idx][idy] = feat;
}
}
return 0;
}));
}
for (size_t idy = 0; idy < node_num; ++idy) {
tasks[idy].get();
}
return 0;
}
int32_t GraphTable::set_node_feat(
int idx,
const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &res) {
size_t node_num = node_ids.size();
std::vector<std::future<int>> tasks;
for (size_t idy = 0; idy < node_num; ++idy) {
int64_t node_id = node_ids[idy];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, idy, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = feature_shards[idx][index]->add_feature_node(node_id);
node->set_feature_size(this->feat_name[idx].size());
for (int feat_idx = 0; feat_idx < (int)feature_names.size();
++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
node->set_feature(feat_id_map[idx][feature_name],
res[feat_idx][idy]);
}
}
return 0;
}));
}
for (size_t idy = 0; idy < node_num; ++idy) {
tasks[idy].get();
}
return 0;
}
// Parses one "name v1 v2 ..." feature line for node type `idx`.
// Returns (feat_id, bytes) when the name is registered in feat_id_map[idx],
// otherwise (-1, "").
std::pair<int32_t, std::string> GraphTable::parse_feature(
    int idx, std::string feat_str) {
  auto fields = paddle::string::split_string<std::string>(feat_str, " ");
  // Guard against an empty/blank line: fields[0] below would be UB.
  if (fields.empty()) {
    return std::make_pair<int32_t, std::string>(-1, "");
  }
  // One lookup instead of find + operator[].
  auto it = feat_id_map[idx].find(fields[0]);
  if (it != feat_id_map[idx].end()) {
    int32_t id = it->second;
    const std::string &dtype = this->feat_dtype[idx][id];
    std::vector<std::string> values(fields.begin() + 1, fields.end());
    // "feasign" and "string" features keep the raw joined tokens; numeric
    // dtypes are packed into their binary representation.
    if (dtype == "feasign" || dtype == "string") {
      return std::make_pair<int32_t, std::string>(
          int32_t(id), paddle::string::join_strings(values, ' '));
    } else if (dtype == "float32") {
      return std::make_pair<int32_t, std::string>(
          int32_t(id), FeatureNode::parse_value_to_bytes<float>(values));
    } else if (dtype == "float64") {
      return std::make_pair<int32_t, std::string>(
          int32_t(id), FeatureNode::parse_value_to_bytes<double>(values));
    } else if (dtype == "int32") {
      return std::make_pair<int32_t, std::string>(
          int32_t(id), FeatureNode::parse_value_to_bytes<int32_t>(values));
    } else if (dtype == "int64") {
      return std::make_pair<int32_t, std::string>(
          int32_t(id), FeatureNode::parse_value_to_bytes<int64_t>(values));
    }
  }
  return std::make_pair<int32_t, std::string>(-1, "");
}
// Gathers every node id (edge table when type_id == 0, feature table
// otherwise) and buckets the ids into slice_num slices by id % slice_num.
std::vector<std::vector<int64_t>> GraphTable::get_all_id(int type_id,
                                                         int idx,
                                                         int slice_num) {
  std::vector<std::vector<int64_t>> res(slice_num);
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<std::vector<int64_t>>> tasks;
  tasks.reserve(search_shards.size());
  for (size_t i = 0; i < search_shards.size(); i++) {
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, i]() -> std::vector<int64_t> {
          return search_shards[i]->get_all_id();
        }));
  }
  for (auto &task : tasks) {
    // get() blocks until the shard scan finishes, so no separate wait pass
    // is needed.
    auto ids = task.get();
    for (auto &id : ids) {
      res[(uint64_t)(id) % slice_num].push_back(id);
    }
  }
  return res;
}
// Serializes `total_size` nodes -- every `step`-th node starting at global
// position `start` -- from the edge or feature table into `buffer`;
// `actual_size` receives the bytes written. Shards are walked in order,
// each contributing a get_batch() slice computed from its running global
// offset (`size`).
int32_t GraphTable::pull_graph_list(int type_id,
                                    int idx,
                                    int start,
                                    int total_size,
                                    std::unique_ptr<char[]> &buffer,
                                    int &actual_size,
                                    bool need_feature,
                                    int step) {
  if (start < 0) start = 0;
  int size = 0, cur_size;
  auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
  std::vector<std::future<std::vector<Node *>>> tasks;
  for (size_t i = 0; i < search_shards.size() && total_size > 0; i++) {
    cur_size = search_shards[i]->get_size();
    if (size + cur_size <= start) {
      // This shard ends before the requested start position; skip it.
      size += cur_size;
      continue;
    }
    // Number of sampled positions that fall inside this shard.
    int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
    int end = start + (count - 1) * step + 1;
    tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
        [&search_shards, this, i, start, end, step, size]()
            -> std::vector<Node *> {
          // Positions are shard-local: subtract this shard's global offset.
          return search_shards[i]->get_batch(start - size, end - size, step);
        }));
    start += count * step;
    total_size -= count;
    size += cur_size;
  }
  for (size_t i = 0; i < tasks.size(); ++i) {
    tasks[i].wait();
  }
  // First pass: total byte size; second pass: serialize into the buffer.
  size = 0;
  std::vector<std::vector<Node *>> res;
  for (size_t i = 0; i < tasks.size(); i++) {
    res.push_back(tasks[i].get());
    for (size_t j = 0; j < res.back().size(); j++) {
      size += res.back()[j]->get_size(need_feature);
    }
  }
  char *buffer_addr = new char[size];
  buffer.reset(buffer_addr);
  int index = 0;
  for (size_t i = 0; i < res.size(); i++) {
    for (size_t j = 0; j < res[i].size(); j++) {
      res[i][j]->to_buffer(buffer_addr + index, need_feature);
      index += res[i][j]->get_size(need_feature);
    }
  }
  actual_size = size;
  return 0;
}
// Keys map to shards by id % shard_num; each server owns a contiguous run
// of shard_num_per_server shards, so integer division yields the server.
int32_t GraphTable::get_server_index_by_id(int64_t id) {
  auto shard = id % shard_num;
  return shard / shard_num_per_server;
}
// Full parameter-server initialization path: set up the accessor and the
// afs client from `config`, then delegate to the GraphParameter overload.
int32_t GraphTable::Initialize(const TableParameter &config,
                               const FsClientParameter &fs_config) {
  LOG(INFO) << "in graphTable initialize";
  _config = config;
  if (InitializeAccessor() != 0) {
    LOG(WARNING) << "Table accessor initialize failed";
    return -1;
  }
  if (_afs_client.initialize(fs_config) != 0) {
    LOG(WARNING) << "Table fs_client initialize failed";
    // fs-client failure is deliberately non-fatal (see commented return)
    // return -1;
  }
  auto graph = config.graph_parameter();
  shard_num = _config.shard_num();
  LOG(INFO) << "in graphTable initialize over";
  return Initialize(graph);
}
// Loads "src_id \t weight" lines from every ';'-separated file in `path`
// into node_weight[type_id][idx]; later lines overwrite earlier ones.
// Malformed lines (too few fields or non-numeric values) are skipped.
void GraphTable::load_node_weight(int type_id, int idx, std::string path) {
  auto paths = paddle::string::split_string<std::string>(path, ";");
  auto &weight_map = node_weight[type_id][idx];
  // const-ref loop: avoid copying each path string (and the original loop
  // variable shadowed the `path` parameter).
  for (const auto &file_path : paths) {
    std::ifstream file(file_path);
    std::string line;
    while (std::getline(file, line)) {
      auto values = paddle::string::split_string<std::string>(line, "\t");
      if (values.size() < 2) continue;
      try {
        auto src_id = std::stoull(values[0]);
        double weight = std::stod(values[1]);
        weight_map[src_id] = weight;
      } catch (...) {
        // stoull/stod throw on non-numeric or out-of-range fields; skip the
        // bad line instead of aborting the whole load.
        continue;
      }
    }
  }
}
// Graph-specific initialization: thread pools, optional SSD backing store,
// neighbor-sample LRU cache, edge/node type registries, feature metadata,
// and this server's shard layout.
int32_t GraphTable::Initialize(const GraphParameter &graph) {
  task_pool_size_ = graph.task_pool_size();
#ifdef PADDLE_WITH_HETERPS
  _db = NULL;
  search_level = graph.search_level();
  if (search_level >= 2) {
    // Levels >= 2 spill graph data to a local RocksDB instance.
    _db = paddle::distributed::RocksDBHandler::GetInstance();
    _db->initialize("./temp_gpups_db", task_pool_size_);
  }
  // gpups_mode = true;
  // auto *sampler =
  //     CREATE_PSCORE_CLASS(GraphSampler, graph.gpups_graph_sample_class());
  // auto slices =
  //     string::split_string<std::string>(graph.gpups_graph_sample_args(), ",");
  // std::cout << "slices" << std::endl;
  // for (auto x : slices) std::cout << x << std::endl;
  // sampler->init(graph.gpu_num(), this, slices);
  // graph_sampler.reset(sampler);
#endif
  if (shard_num == 0) {
    // Standalone (non-PS) mode: act as the single server.
    server_num = 1;
    _shard_idx = 0;
    shard_num = graph.shard_num();
  }
  use_cache = graph.use_cache();
  if (use_cache) {
    cache_size_limit = graph.cache_size_limit();
    cache_ttl = graph.cache_ttl();
    make_neighbor_sample_cache((size_t)cache_size_limit, (size_t)cache_ttl);
  }
  _shards_task_pool.resize(task_pool_size_);
  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
    // One single-threaded pool per shard group serializes shard access.
    _shards_task_pool[i].reset(new ::ThreadPool(1));
    _shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
  }
  auto graph_feature = graph.graph_feature();
  auto node_types = graph.node_types();
  auto edge_types = graph.edge_types();
  VLOG(0) << "got " << edge_types.size() << "edge types in total";
  feat_id_map.resize(node_types.size());
  for (int k = 0; k < edge_types.size(); k++) {
    VLOG(0) << "in initialize: get a edge_type " << edge_types[k];
    edge_to_id[edge_types[k]] = k;
    id_to_edge.push_back(edge_types[k]);
  }
  feat_name.resize(node_types.size());
  feat_shape.resize(node_types.size());
  feat_dtype.resize(node_types.size());
  VLOG(0) << "got " << node_types.size() << "node types in total";
  for (int k = 0; k < node_types.size(); k++) {
    feature_to_id[node_types[k]] = k;
    auto node_type = node_types[k];
    auto feature = graph_feature[k];
    id_to_feature.push_back(node_type);
    int feat_conf_size = static_cast<int>(feature.name().size());
    // Register each configured feature (name/shape/dtype) for node type k.
    for (int i = 0; i < feat_conf_size; i++) {
      // auto &f_name = common.attributes()[i];
      // auto &f_shape = common.dims()[i];
      // auto &f_dtype = common.params()[i];
      auto &f_name = feature.name()[i];
      auto &f_shape = feature.shape()[i];
      auto &f_dtype = feature.dtype()[i];
      feat_name[k].push_back(f_name);
      feat_shape[k].push_back(f_shape);
      feat_dtype[k].push_back(f_dtype);
      feat_id_map[k][f_name] = i;
      VLOG(0) << "init graph table feat conf name:" << f_name
              << " shape:" << f_shape << " dtype:" << f_dtype;
    }
  }
  // this->table_name = common.table_name();
  // this->table_type = common.name();
  this->table_name = graph.table_name();
  this->table_type = graph.table_type();
  VLOG(0) << " init graph table type " << this->table_type << " table name "
          << this->table_name;
  // int feat_conf_size = static_cast<int>(common.attributes().size());
  // int feat_conf_size = static_cast<int>(graph_feature.name().size());
  VLOG(0) << "in init graph table shard num = " << shard_num << " shard_idx"
          << _shard_idx;
  // This server owns the contiguous shard range [shard_start, shard_end).
  shard_num_per_server = sparse_local_shard_num(shard_num, server_num);
  shard_start = _shard_idx * shard_num_per_server;
  shard_end = shard_start + shard_num_per_server;
  VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
          << shard_start << " shard_end " << shard_end;
  edge_shards.resize(id_to_edge.size());
  node_weight.resize(2);
  node_weight[0].resize(id_to_edge.size());
#ifdef PADDLE_WITH_HETERPS
  partitions.resize(id_to_edge.size());
#endif
  for (int k = 0; k < (int)edge_shards.size(); k++) {
    for (size_t i = 0; i < shard_num_per_server; i++) {
      edge_shards[k].push_back(new GraphShard());
    }
  }
  node_weight[1].resize(id_to_feature.size());
  feature_shards.resize(id_to_feature.size());
  for (int k = 0; k < (int)feature_shards.size(); k++) {
    for (size_t i = 0; i < shard_num_per_server; i++) {
      feature_shards[k].push_back(new GraphShard());
    }
  }
  return 0;
}
} // namespace distributed
}; // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <ctime>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
namespace distributed {
class GraphShard {
public:
size_t get_size();
GraphShard() {}
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
std::vector<int64_t> get_ids_by_range(int start, int end) {
std::vector<int64_t> res;
for (int i = start; i < end && i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
}
std::vector<int64_t> get_all_id() {
std::vector<int64_t> res;
for (int i = 0; i < (int)bucket.size(); i++) {
res.push_back(bucket[i]->get_id());
}
return res;
}
GraphNode *add_graph_node(int64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(int64_t id);
Node *find_node(int64_t id);
void delete_node(int64_t id);
void clear();
void add_neighbor(int64_t id, int64_t dst_id, float weight);
std::unordered_map<int64_t, int> &get_node_location() {
return node_location;
}
private:
std::unordered_map<int64_t, int> node_location;
std::vector<Node *> bucket;
};
// Outcome of an LRU cache call: ok on success, blocked when the shared
// rwlock could not be acquired without waiting, err for other failures.
enum LRUResponse { ok = 0, blocked = 1, err = 2 };
// Cache key for neighbor sampling: (edge-type index, node id, sample size,
// weighted flag) uniquely identifies one sampling request.
struct SampleKey {
  int idx;
  int64_t node_key;
  size_t sample_size;
  bool is_weighted;
  SampleKey(int _idx,
            int64_t _node_key,
            size_t _sample_size,
            bool _is_weighted)
      : idx(_idx),
        node_key(_node_key),
        sample_size(_sample_size),
        is_weighted(_is_weighted) {}
  // Two keys are equal only when every field matches.
  bool operator==(const SampleKey &s) const {
    if (idx != s.idx || node_key != s.node_key) return false;
    return sample_size == s.sample_size && is_weighted == s.is_weighted;
  }
};
// A sampled-neighbor payload: `actual_size` valid bytes inside `buffer`.
// Ownership is shared so the LRU cache and callers may alias one buffer.
class SampleResult {
 public:
  size_t actual_size;
  std::shared_ptr<char> buffer;
  // Adopt an already-shared buffer.
  SampleResult(size_t _actual_size, std::shared_ptr<char> &_buffer)
      : actual_size(_actual_size), buffer(_buffer) {}
  // Take ownership of a raw `new char[]` allocation; freed with delete[].
  SampleResult(size_t _actual_size, char *_buffer)
      : actual_size(_actual_size),
        buffer(_buffer, std::default_delete<char[]>()) {}
  ~SampleResult() {}
};
// Doubly-linked-list node holding one cached (key, value) pair together
// with its remaining time-to-live (a per-hit countdown, not wall time).
template <typename K, typename V>
class LRUNode {
 public:
  LRUNode(K _key, V _data, size_t _ttl)
      : key(_key), data(_data), ttl(_ttl), pre(nullptr), next(nullptr) {}
  K key;
  V data;
  size_t ttl;  // decremented on each hit; node is evicted when it reaches 0
  LRUNode<K, V> *pre, *next;
};
template <typename K, typename V>
class ScaledLRU;
// One LRU shard, used by a single worker thread of its ScaledLRU `father`.
// Entries form a doubly linked list from node_head (oldest) to node_end
// (newest); key_map provides O(1) lookup. Eviction is lazy: Shrink on the
// parent only raises remove_count, and process_redundant() later deletes
// that many nodes from the head.
template <typename K, typename V>
class RandomSampleLRU {
 public:
  RandomSampleLRU(ScaledLRU<K, V> *_father) {
    father = _father;
    remove_count = 0;
    node_size = 0;
    node_head = node_end = NULL;
    global_ttl = father->ttl;
    total_diff = 0;
  }
  ~RandomSampleLRU() {
    // Free the whole linked list; key_map entries point into it.
    LRUNode<K, V> *p;
    while (node_head != NULL) {
      p = node_head->next;
      delete node_head;
      node_head = p;
    }
  }
  // Looks up keys[0..length); hits are appended to `res`, their ttl is
  // decremented, and a node reaching ttl 0 is removed. Returns `blocked`
  // without doing anything when the parent's rwlock is busy. The read lock
  // only guards against ScaledLRU::Shrink (the writer); different shards
  // never touch each other's lists.
  LRUResponse query(K *keys, size_t length, std::vector<std::pair<K, V>> &res) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        res.emplace_back(keys[i], iter->second->data);
        iter->second->ttl--;
        if (iter->second->ttl == 0) {
          remove(iter->second);
          if (remove_count != 0) remove_count--;
        } else {
          move_to_tail(iter->second);
        }
      }
    }
    // Report net size changes to the parent in batches of >= 500 so the
    // shared counter is not updated on every call.
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  // Inserts or refreshes keys[i] -> data[i]; existing entries get a fresh
  // ttl and move to the tail. Returns `blocked` when the lock is busy.
  LRUResponse insert(K *keys, V *data, size_t length) {
    if (pthread_rwlock_tryrdlock(&father->rwlock) != 0)
      return LRUResponse::blocked;
    // pthread_rwlock_rdlock(&father->rwlock);
    int init_size = node_size - remove_count;
    process_redundant(length * 3);
    for (size_t i = 0; i < length; i++) {
      auto iter = key_map.find(keys[i]);
      if (iter != key_map.end()) {
        move_to_tail(iter->second);
        iter->second->ttl = global_ttl;
        iter->second->data = data[i];
      } else {
        LRUNode<K, V> *temp = new LRUNode<K, V>(keys[i], data[i], global_ttl);
        add_new(temp);
      }
    }
    total_diff += node_size - remove_count - init_size;
    if (total_diff >= 500 || total_diff < -500) {
      father->handle_size_diff(total_diff);
      total_diff = 0;
    }
    pthread_rwlock_unlock(&father->rwlock);
    return LRUResponse::ok;
  }
  // Unlinks and deletes one node.
  void remove(LRUNode<K, V> *node) {
    fetch(node);
    node_size--;
    key_map.erase(node->key);
    delete node;
  }
  // Performs up to process_size of the lazily scheduled evictions
  // (remove_count), always dropping from the oldest end.
  void process_redundant(int process_size) {
    int length = std::min(remove_count, process_size);
    while (length--) {
      remove(node_head);
      remove_count--;
    }
    // std::cerr<<"after remove_count = "<<remove_count<<std::endl;
  }
  void move_to_tail(LRUNode<K, V> *node) {
    fetch(node);
    place_at_tail(node);
  }
  void add_new(LRUNode<K, V> *node) {
    node->ttl = global_ttl;
    place_at_tail(node);
    node_size++;
    key_map[node->key] = node;
  }
  void place_at_tail(LRUNode<K, V> *node) {
    if (node_end == NULL) {
      node_head = node_end = node;
      node->next = node->pre = NULL;
    } else {
      node_end->next = node;
      node->pre = node_end;
      node->next = NULL;
      node_end = node;
    }
  }
  // Unlinks `node` from the list without deleting it.
  void fetch(LRUNode<K, V> *node) {
    if (node->pre) {
      node->pre->next = node->next;
    } else {
      node_head = node->next;
    }
    if (node->next) {
      node->next->pre = node->pre;
    } else {
      node_end = node->pre;
    }
  }
 private:
  std::unordered_map<K, LRUNode<K, V> *> key_map;
  ScaledLRU<K, V> *father;
  size_t global_ttl, size_limit;  // size_limit appears unused in this class
  int node_size, total_diff;  // node_size counts nodes incl. pending removals
  LRUNode<K, V> *node_head, *node_end;
  friend class ScaledLRU<K, V>;
  int remove_count;  // evictions scheduled by Shrink but not yet performed
};
// Sharded, TTL-based LRU cache. Each shard (RandomSampleLRU) is driven by
// one worker thread; this class owns the shared rwlock, the global size
// budget, and a background thread that periodically shrinks all shards
// back under size_limit.
template <typename K, typename V>
class ScaledLRU {
 public:
  ScaledLRU(size_t _shard_num, size_t size_limit, size_t _ttl)
      : size_limit(size_limit), ttl(_ttl) {
    shard_num = _shard_num;
    pthread_rwlock_init(&rwlock, NULL);
    stop = false;
    thread_pool.reset(new ::ThreadPool(1));
    global_count = 0;
    lru_pool = std::vector<RandomSampleLRU<K, V>>(shard_num,
                                                  RandomSampleLRU<K, V>(this));
    // Background loop: wake every 20s (or on destruction) and run Shrink.
    shrink_job = std::thread([this]() -> void {
      while (true) {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          cv_.wait_for(lock, std::chrono::milliseconds(20000));
          if (stop) {
            return;
          }
        }
        auto status =
            thread_pool->enqueue([this]() -> int { return Shrink(); });
        status.wait();
      }
    });
    // NOTE(review): the shrink thread is detached, so it can wake after
    // ~ScaledLRU has run and touch freed members -- verify that instances
    // outlive shutdown paths at call sites.
    shrink_job.detach();
  }
  ~ScaledLRU() {
    std::unique_lock<std::mutex> lock(mutex_);
    stop = true;
    cv_.notify_one();
  }
  LRUResponse query(size_t index,
                    K *keys,
                    size_t length,
                    std::vector<std::pair<K, V>> &res) {
    return lru_pool[index].query(keys, length, res);
  }
  LRUResponse insert(size_t index, K *keys, V *data, size_t length) {
    return lru_pool[index].insert(keys, data, length);
  }
  // Rebalance: when the cache exceeds ~1.1x size_limit, take the write lock
  // (excluding all queries/inserts) and distribute the excess across shards
  // as lazy remove_count quotas, proportional to each shard's live size.
  int Shrink() {
    int node_size = 0;
    for (size_t i = 0; i < lru_pool.size(); i++) {
      node_size += lru_pool[i].node_size - lru_pool[i].remove_count;
    }
    if ((size_t)node_size <= size_t(1.1 * size_limit) + 1) return 0;
    if (pthread_rwlock_wrlock(&rwlock) == 0) {
      global_count = 0;
      for (size_t i = 0; i < lru_pool.size(); i++) {
        global_count += lru_pool[i].node_size - lru_pool[i].remove_count;
      }
      if ((size_t)global_count > size_limit) {
        size_t remove = global_count - size_limit;
        for (size_t i = 0; i < lru_pool.size(); i++) {
          lru_pool[i].total_diff = 0;
          lru_pool[i].remove_count +=
              1.0 * (lru_pool[i].node_size - lru_pool[i].remove_count) /
              global_count * remove;
        }
      }
      pthread_rwlock_unlock(&rwlock);
      return 0;
    }
    return 0;
  }
  // Shards report batched size deltas here; an early Shrink is scheduled
  // when the running estimate exceeds 1.25x the limit.
  void handle_size_diff(int diff) {
    if (diff != 0) {
      __sync_fetch_and_add(&global_count, diff);
      if (global_count > int(1.25 * size_limit)) {
        thread_pool->enqueue([this]() -> int { return Shrink(); });
      }
    }
  }
  size_t get_ttl() { return ttl; }
 private:
  pthread_rwlock_t rwlock;  // readers: shard query/insert; writer: Shrink
  size_t shard_num;
  int global_count;  // approximate total entries across all shards
  size_t size_limit, total, hit;  // total/hit appear unused here
  size_t ttl;
  bool stop;
  std::thread shrink_job;
  std::vector<RandomSampleLRU<K, V>> lru_pool;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::shared_ptr<::ThreadPool> thread_pool;
  friend class RandomSampleLRU<K, V>;
};
/*
#ifdef PADDLE_WITH_HETERPS
enum GraphSamplerStatus { waiting = 0, running = 1, terminating = 2 };
class GraphTable;
class GraphSampler {
public:
GraphSampler() {
status = GraphSamplerStatus::waiting;
thread_pool.reset(new ::ThreadPool(1));
callback = [](std::vector<paddle::framework::GpuPsCommGraph> &res) {
return;
};
}
virtual int loadData(const std::string &path){
return 0;
}
virtual int run_graph_sampling() = 0;
virtual int start_graph_sampling() {
if (status != GraphSamplerStatus::waiting) {
return -1;
}
std::promise<int> prom;
std::future<int> fut = prom.get_future();
graph_sample_task_over = thread_pool->enqueue([&prom, this]() {
prom.set_value(0);
status = GraphSamplerStatus::running;
return run_graph_sampling();
});
return fut.get();
}
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args) = 0;
virtual void set_graph_sample_callback(
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback) {
this->callback = callback;
}
virtual int end_graph_sampling() {
if (status == GraphSamplerStatus::running) {
status = GraphSamplerStatus::terminating;
return graph_sample_task_over.get();
}
return -1;
}
virtual GraphSamplerStatus get_graph_sampler_status() { return status; }
protected:
std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
callback;
std::shared_ptr<::ThreadPool> thread_pool;
GraphSamplerStatus status;
std::future<int> graph_sample_task_over;
std::vector<paddle::framework::GpuPsCommGraph> sample_res;
};
#endif
*/
// Distributed graph storage table: sharded edge/feature node storage with
// per-shard worker threads, an optional neighbor-sample LRU cache, and
// (under PADDLE_WITH_HETERPS) SSD spill via RocksDB.
class GraphTable : public Table {
 public:
  GraphTable() {
    use_cache = false;
    shard_num = 0;
    rw_lock.reset(new pthread_rwlock_t());
#ifdef PADDLE_WITH_HETERPS
    next_partition = 0;
    total_memory_cost = 0;
#endif
  }
  virtual ~GraphTable();
  virtual void *GetShard(size_t shard_idx) { return 0; }
  // Number of shards each server owns (rounded up when not divisible).
  static int32_t sparse_local_shard_num(uint32_t shard_num,
                                        uint32_t server_num) {
    if (shard_num % server_num == 0) {
      return shard_num / server_num;
    }
    size_t local_shard_num = shard_num / server_num + 1;
    return local_shard_num;
  }
  // Server-local shard index for a key.
  static size_t get_sparse_shard(uint32_t shard_num,
                                 uint32_t server_num,
                                 uint64_t key) {
    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
  }
  virtual int32_t pull_graph_list(int type_id,
                                  int idx,
                                  int start,
                                  int size,
                                  std::unique_ptr<char[]> &buffer,
                                  int &actual_size,
                                  bool need_feature,
                                  int step);
  virtual int32_t random_sample_neighbors(
      int idx,
      int64_t *node_ids,
      int sample_size,
      std::vector<std::shared_ptr<char>> &buffers,
      std::vector<int> &actual_sizes,
      bool need_weight);
  int32_t random_sample_nodes(int type_id,
                              int idx,
                              int sample_size,
                              std::unique_ptr<char[]> &buffers,
                              int &actual_sizes);
  virtual int32_t get_nodes_ids_by_ranges(
      int type_id,
      int idx,
      std::vector<std::pair<int, int>> ranges,
      std::vector<int64_t> &res);
  // Initialization / loading.
  virtual int32_t Initialize() { return 0; }
  virtual int32_t Initialize(const TableParameter &config,
                             const FsClientParameter &fs_config);
  virtual int32_t Initialize(const GraphParameter &config);
  int32_t Load(const std::string &path, const std::string &param);
  int32_t load_edges(const std::string &path,
                     bool reverse,
                     const std::string &edge_type);
  std::vector<std::vector<int64_t>> get_all_id(int type,
                                               int idx,
                                               int slice_num);
  int32_t load_nodes(const std::string &path, std::string node_type);
  // Graph mutation.
  int32_t add_graph_node(int idx,
                         std::vector<int64_t> &id_list,
                         std::vector<bool> &is_weight_list);
  int32_t remove_graph_node(int idx, std::vector<int64_t> &id_list);
  int32_t get_server_index_by_id(int64_t id);
  Node *find_node(int type_id, int idx, int64_t id);
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }
  virtual int32_t clear_nodes(int type, int idx);
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Shrink(const std::string &param) { return 0; }
  // Save to a caller-specified path (currently a no-op returning 0).
  virtual int32_t Save(const std::string &path, const std::string &converter) {
    return 0;
  }
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t SetShard(size_t shard_idx, size_t server_num) {
    _shard_idx = shard_idx;
    /*
    _shard_num is not used in graph_table, this following operation is for the
    purpose of
    being compatible with base class table.
    */
    _shard_num = server_num;
    this->server_num = server_num;
    return 0;
  }
  virtual uint32_t get_thread_pool_index_by_shard_index(int64_t shard_index);
  virtual uint32_t get_thread_pool_index(int64_t node_id);
  virtual std::pair<int32_t, std::string> parse_feature(int idx,
                                                        std::string feat_str);
  virtual int32_t get_node_feat(int idx,
                                const std::vector<int64_t> &node_ids,
                                const std::vector<std::string> &feature_names,
                                std::vector<std::vector<std::string>> &res);
  virtual int32_t set_node_feat(
      int idx,
      const std::vector<int64_t> &node_ids,
      const std::vector<std::string> &feature_names,
      const std::vector<std::vector<std::string>> &res);
  size_t get_server_num() { return server_num; }
  void clear_graph(int idx);
  // Lazily constructs the neighbor-sample LRU cache (idempotent).
  virtual int32_t make_neighbor_sample_cache(size_t size_limit, size_t ttl) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (use_cache == false) {
        scaled_lru.reset(new ScaledLRU<SampleKey, SampleResult>(
            task_pool_size_, size_limit, ttl));
        use_cache = true;
      }
    }
    return 0;
  }
  virtual void load_node_weight(int type_id, int idx, std::string path);
#ifdef PADDLE_WITH_HETERPS
  // virtual int32_t start_graph_sampling() {
  //   return this->graph_sampler->start_graph_sampling();
  // }
  // virtual int32_t end_graph_sampling() {
  //   return this->graph_sampler->end_graph_sampling();
  // }
  // virtual int32_t set_graph_sample_callback(
  //     std::function<void(std::vector<paddle::framework::GpuPsCommGraph> &)>
  //         callback) {
  //   graph_sampler->set_graph_sample_callback(callback);
  //   return 0;
  // }
  // SSD / GPU-PS specific API.
  virtual void make_partitions(int idx, int64_t gb_size, int device_len);
  virtual void export_partition_files(int idx, std::string file_path);
  virtual char *random_sample_neighbor_from_ssd(
      int idx,
      int64_t id,
      int sample_size,
      const std::shared_ptr<std::mt19937_64> rng,
      int &actual_size);
  virtual int32_t add_node_to_ssd(
      int type_id, int idx, int64_t src_id, char *data, int len);
  virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph(
      int idx, std::vector<int64_t> ids);
  int32_t Load_to_ssd(const std::string &path, const std::string &param);
  int64_t load_graph_to_memory_from_ssd(int idx, std::vector<int64_t> &ids);
  int32_t make_complementary_graph(int idx, int64_t byte_size);
  int32_t dump_edges_to_ssd(int idx);
  int32_t get_partition_num(int idx) { return partitions[idx].size(); }
  std::vector<int64_t> get_partition(int idx, int index) {
    if (idx >= partitions.size() || index >= partitions[idx].size())
      return std::vector<int64_t>();
    return partitions[idx][index];
  }
  int32_t load_edges_to_ssd(const std::string &path,
                            bool reverse_edge,
                            const std::string &edge_type);
  int32_t load_next_partition(int idx);
  void set_search_level(int search_level) { this->search_level = search_level; }
  int search_level;
  int64_t total_memory_cost;
  std::vector<std::vector<std::vector<int64_t>>> partitions;
  int next_partition;
#endif
  virtual int32_t add_comm_edge(int idx, int64_t src_id, int64_t dst_id);
  virtual int32_t build_sampler(int idx, std::string sample_type = "random");
  // Storage: [edge/feature type][local shard] -> shard object.
  std::vector<std::vector<GraphShard *>> edge_shards, feature_shards;
  size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
  int task_pool_size_ = 24;
  const int random_sample_nodes_ranges = 3;
  std::vector<std::vector<std::unordered_map<int64_t, double>>> node_weight;
  // Per-node-type feature metadata, indexed by feature id.
  std::vector<std::vector<std::string>> feat_name;
  std::vector<std::vector<std::string>> feat_dtype;
  std::vector<std::vector<int32_t>> feat_shape;
  std::vector<std::unordered_map<std::string, int32_t>> feat_id_map;
  std::unordered_map<std::string, int> feature_to_id, edge_to_id;
  std::vector<std::string> id_to_feature, id_to_edge;
  std::string table_name;
  std::string table_type;
  // One single-threaded pool + RNG per shard group serializes shard access.
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
  std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
  std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
  std::unordered_set<int64_t> extra_nodes;
  std::unordered_map<int64_t, size_t> extra_nodes_to_thread_index;
  bool use_cache, use_duplicate_nodes;
  int cache_size_limit;
  int cache_ttl;
  mutable std::mutex mutex_;
  std::shared_ptr<pthread_rwlock_t> rw_lock;
#ifdef PADDLE_WITH_HETERPS
  // paddle::framework::GpuPsGraphTable gpu_graph_table;
  paddle::distributed::RocksDBHandler *_db;
  // std::shared_ptr<::ThreadPool> graph_sample_pool;
  // std::shared_ptr<GraphSampler> graph_sampler;
  // REGISTER_GRAPH_FRIEND_CLASS(2, CompleteGraphSampler, BasicBfsGraphSampler)
#endif
};
/*
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_REGISTERER(GraphSampler);
class CompleteGraphSampler : public GraphSampler {
public:
CompleteGraphSampler() {}
~CompleteGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
// std::vector<GpuPsCommGraph> sample_res;
// std::shared_ptr<std::mt19937_64> random;
int gpu_num;
};
class BasicBfsGraphSampler : public GraphSampler {
public:
BasicBfsGraphSampler() {}
~BasicBfsGraphSampler() {}
// virtual pthread_rwlock_t *export_rw_lock();
virtual int run_graph_sampling();
virtual void init(size_t gpu_num, GraphTable *graph_table,
std::vector<std::string> args_);
protected:
GraphTable *graph_table;
// std::vector<std::vector<GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<paddle::framework::GpuPsGraphNode>> sample_nodes;
std::vector<std::vector<int64_t>> sample_neighbors;
size_t gpu_num;
int init_search_size, node_num_for_each_shard, edge_num_for_each_node;
int rounds, interval;
std::vector<std::unordered_map<int64_t, std::vector<int64_t>>>
sample_neighbors_map;
};
#endif
*/
} // namespace distributed
}; // namespace paddle
namespace std {
template <>
struct hash<paddle::distributed::SampleKey> {
  // Mix all key fields with the boost-style hash_combine pattern instead of
  // plain XOR: XOR is permutation-symmetric and cancels equal fields (e.g.
  // idx ^ sample_size collides whenever the two are equal), which degrades
  // the sampling cache's bucket distribution. Equal keys (operator==) still
  // hash equally, as required.
  size_t operator()(const paddle::distributed::SampleKey &s) const {
    size_t seed = std::hash<int>()(s.idx);
    seed ^= std::hash<int64_t>()(s.node_key) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    seed ^= std::hash<size_t>()(s.sample_size) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    seed ^= std::hash<bool>()(s.is_weighted) + 0x9e3779b9 + (seed << 6) +
            (seed >> 2);
    return seed;
  }
};
}  // namespace std
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <condition_variable> // NOLINT
#include <mutex> // NOLINT
#include <set>
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/table.h"
namespace paddle {
namespace distributed {
// Accumulator for GeoSGD-style sparse aggregation: sums contributed
// vectors component-wise and can scale the sum down to their mean.
template <typename T>
struct ReservoirValue {
  std::vector<T> values;
  uint32_t counter;
  uint32_t dim;
  // Zero-dimensional accumulator by default.
  ReservoirValue() : ReservoirValue(0) {}
  // Accumulator holding `dim` summed components, all starting at zero.
  ReservoirValue(uint32_t dim) {
    this->dim = dim;
    values.resize(dim);
    counter = 0;
  }
  // Element-wise add of `numel` components; records one contribution.
  void add(const T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    ++counter;
  }
  void add(T *value, int numel) {
    GetBlas<T>().VADD(numel, values.data(), value, values.data());
    ++counter;
  }
  // Scale the accumulated sum down to the mean of all contributions.
  // No-op when nothing has been added yet.
  void avg() {
    if (counter == 0) return;
    T scale = static_cast<T>(1) / static_cast<T>(counter);
    GetBlas<T>().SCAL(values.size(), scale, values.data());
  }
  // Discard all accumulated state (keeps the allocated capacity).
  void reset() {
    counter = 0;
    std::fill(values.begin(), values.end(), 0);
  }
};
// A parameter-free pseudo table used purely for trainer synchronization:
// every data-plane operation (pull/push/load/save/...) is a no-op.
// Trainers call Barrier() and block until all expected trainers arrive.
class BarrierTable : public Table {
 public:
  BarrierTable() {}
  virtual ~BarrierTable() {}
  // This table holds no data, so there are no shards to expose.
  // (nullptr instead of the literal 0 for a pointer return.)
  virtual void *GetShard(size_t shard_idx) { return nullptr; }
  virtual int32_t Pull(TableContext &context) { return 0; }
  virtual int32_t Push(TableContext &context) { return 0; }
  int32_t Shrink(const std::string &param) override { return 0; }
  virtual void Clear() {}
  virtual int32_t Flush() { return 0; }
  virtual int32_t Load(const std::string &path, const std::string &param) {
    return 0;
  }
  virtual int32_t Save(const std::string &path, const std::string &param) {
    return 0;
  }
  virtual int32_t InitializeShard() { return 0; }
  virtual int32_t Initialize() override;
  // only for barrier
  // 0: send_barrier 1: recv_barrier 2: complete
  virtual int32_t Barrier(const uint32_t trainer_id,
                          const std::string barrier_type) override;
  // Register the map of real tables; presumably consulted when the barrier
  // releases — confirm against barrier_table.cc.
  virtual int32_t SetTableMap(
      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;

 private:
  std::mutex mutex_;
  std::condition_variable trainer_wait_;
  std::set<uint64_t> trainer_ids_;  // trainers that have reached the barrier
  std::set<uint64_t> trainer_all_;  // full set of expected trainer ids
  std::atomic<int> trigger_;
  std::atomic<bool> exit_;
  // Not owned; lifetime managed by the caller of SetTableMap.
  std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Build the SGD rules for the embed (1-d) and embedx (embedx_dim-d)
// slots from the table config, cache the layout dims, and read the
// CTR policy parameters. Must run before any other accessor call.
int CtrCommonAccessor::Initialize() {
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());
  // The index helpers in CtrCommonFeatureValue depend on these dims.
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
  common_feature_value.embedx_dim = _config.embedx_dim();
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();
  if (_config.ctr_accessor_param().show_scale()) {
    _show_scale = true;
  }
  InitAccessorInfo();
  return 0;
}
// Fill _accessor_info with the dims/byte sizes of the three layouts:
// the stored value (dim/size), the pull response (select_*: show, click,
// embed_w plus embedx weights -> 3 + embedx_dim), and the push request
// (update_*: slot, show, click, embed_g plus embedx grads -> 4 +
// embedx_dim). mf_size covers the lazily-extended embedx part
// (weights + optimizer state).
void CtrCommonAccessor::InitAccessorInfo() {
  _accessor_info.dim = common_feature_value.Dim();
  _accessor_info.size = common_feature_value.Size();
  auto embedx_dim = _config.embedx_dim();
  _accessor_info.select_dim = 3 + embedx_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  _accessor_info.update_dim = 4 + embedx_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  _accessor_info.mf_size =
      (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
// Decay the historical show/click counters, then report whether the
// feature has become stale enough to delete from the table.
bool CtrCommonAccessor::Shrink(float* value) {
  // Time decay always happens, even when the value is kept.
  common_feature_value.Show(value) *= _show_click_decay_rate;
  common_feature_value.Click(value) *= _show_click_decay_rate;
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto ctr_score = ShowClickScore(common_feature_value.Show(value),
                                  common_feature_value.Click(value));
  // Delete when the decayed score is too low or the feature has gone
  // unseen for too many days.
  return ctr_score < delete_threshold ||
         common_feature_value.UnseenDays(value) > delete_after_unseen_days;
}
// A value is a cache candidate only when it passes the base-threshold
// score test and was seen recently; among candidates, keep the hot ones
// whose show count exceeds the global cache threshold.
bool CtrCommonAccessor::SaveCache(float* value,
                                  int param,
                                  double global_cache_threshold) {
  const auto show = common_feature_value.Show(value);
  const auto click = common_feature_value.Click(value);
  if (ShowClickScore(show, click) <
      _config.ctr_accessor_param().base_threshold()) {
    return false;
  }
  if (common_feature_value.UnseenDays(value) >
      _config.ctr_accessor_param().delta_keep_days()) {
    return false;
  }
  return show > global_cache_threshold;
}
// Move a value to SSD once it has gone unseen for longer than the
// configured ssd_unseenday_threshold.
bool CtrCommonAccessor::SaveSSD(float* value) {
  return common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold;
}
// Decide whether a value is dumped during a save pass. `param` selects
// the pass: 0 = save all, 1 = xbox delta, 2 = xbox base (delta_threshold
// forced to 0), 3 = batch model (value already decayed in Shrink),
// 5 = revert batch model; anything else saves unconditionally.
bool CtrCommonAccessor::Save(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
      // Only well-scored, recently-seen values with enough accumulated
      // delta make it into the xbox outputs.
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >= base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
        // do this after save, because it must not be modified when retry
        if (param == 2) {
          common_feature_value.DeltaScore(value) = 0;
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // do this after save, because it must not be modified when retry
      // common_feature_value.UnseenDays(value)++;
      return true;
    }
    // save revert batch_model
    case 5: {
      return true;
    }
    default:
      return true;
  }
}
// Post-save bookkeeping, kept separate from Save() so a failed/retried
// save does not mutate state twice: param 1 clears delta_score of values
// that qualified for the xbox-delta dump; param 3 ages unseen_days by
// one day after a batch-model save. Other params are no-ops.
void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  // NOTE(review): the switch below has no case 2, so this adjustment is
  // currently dead — kept for symmetry with Save().
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
      // Same qualification filter as Save(param == 1).
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >= base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
        common_feature_value.DeltaScore(value) = 0;
      }
    }
      return;
    case 3: {
      common_feature_value.UnseenDays(value)++;
    }
      return;
    default:
      return;
  }
}
// Initialize `num` freshly-created feature values: zero the statistics,
// set the slot to the -1 sentinel, and let the SGD rules seed the embed
// and embedx weights (`false` = no random init for embedx).
// The memory behind each value must already be allocated by the caller.
int32_t CtrCommonAccessor::Create(float** values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[common_feature_value.UnseenDaysIndex()] = 0;
    value[common_feature_value.DeltaScoreIndex()] = 0;
    value[common_feature_value.ShowIndex()] = 0;
    value[common_feature_value.ClickIndex()] = 0;
    value[common_feature_value.SlotIndex()] = -1;
    _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
                               value + common_feature_value.EmbedG2SumIndex());
    _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
                                value + common_feature_value.EmbedxG2SumIndex(),
                                false);
  }
  return 0;
}
// A feature earns its embedx (matrix-factorization) part only after its
// weighted show/click score reaches embedx_threshold.
bool CtrCommonAccessor::NeedExtendMF(float* value) {
  const float show = value[common_feature_value.ShowIndex()];
  const float click = value[common_feature_value.ClickIndex()];
  const auto& param = _config.ctr_accessor_param();
  float score =
      (show - click) * param.nonclk_coeff() + click * param.click_coeff();
  return score >= _config.embedx_threshold();
}
// True when a value of `size` floats carries the embedx (MF) part,
// i.e. it extends past the embedx_g2sum offset in the layout.
bool CtrCommonAccessor::HasMF(int size) {
  return size > common_feature_value.EmbedxG2SumIndex();
}
// from CommonFeatureValue to CtrCommonPullValue
// Project each stored feature value onto the compact pull wire layout:
// show, click, embed_w, then the embedx weights.
int32_t CtrCommonAccessor::Select(float** select_values,
                                  const float** values,
                                  size_t num) {
  const size_t embedx_dim = _config.embedx_dim();
  for (size_t item = 0; item < num; ++item) {
    const float* src = values[item];
    float* dst = select_values[item];
    dst[CtrCommonPullValue::ShowIndex()] =
        src[common_feature_value.ShowIndex()];
    dst[CtrCommonPullValue::ClickIndex()] =
        src[common_feature_value.ClickIndex()];
    dst[CtrCommonPullValue::EmbedWIndex()] =
        src[common_feature_value.EmbedWIndex()];
    memcpy(dst + CtrCommonPullValue::EmbedxWIndex(),
           src + common_feature_value.EmbedxWIndex(),
           embedx_dim * sizeof(float));
  }
  return 0;
}
// from CtrCommonPushValue to CtrCommonPushValue
// first dim: item
// second dim: field num
int32_t CtrCommonAccessor::Merge(float** update_values,
const float** other_update_values,
size_t num) {
auto embedx_dim = _config.embedx_dim();
int total_dim = CtrCommonPushValue::Dim(embedx_dim);
for (size_t value_item = 0; value_item < num; ++value_item) {
float* update_value = update_values[value_item];
const float* other_update_value = other_update_values[value_item];
for (int i = 0; i < total_dim; ++i) {
if (i != CtrCommonPushValue::SlotIndex()) {
update_value[i] += other_update_value[i];
}
}
}
return 0;
}
// from CtrCommonPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
// Apply an aggregated push gradient to each stored value: fold show and
// click into the statistics, refresh slot/delta_score, reset unseen_days,
// then let the SGD rules update the embed and embedx weights.
int32_t CtrCommonAccessor::Update(float** update_values,
                                  const float** push_values,
                                  size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    float push_show = push_value[CtrCommonPushValue::ShowIndex()];
    float push_click = push_value[CtrCommonPushValue::ClickIndex()];
    float slot = push_value[CtrCommonPushValue::SlotIndex()];
    update_value[common_feature_value.ShowIndex()] += push_show;
    update_value[common_feature_value.ClickIndex()] += push_click;
    update_value[common_feature_value.SlotIndex()] = slot;
    update_value[common_feature_value.DeltaScoreIndex()] +=
        (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
    // A push means the feature was just seen.
    update_value[common_feature_value.UnseenDaysIndex()] = 0;
    // TODO(zhaocaibei123): add configure show_scale
    // Without show scaling, gradients are applied with unit weight.
    if (!_show_scale) {
      push_show = 1;
    }
    VLOG(3) << "accessor show scale:" << _show_scale
            << ", push_show:" << push_show;
    _embed_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedWIndex(),
        update_value + common_feature_value.EmbedG2SumIndex(),
        push_value + CtrCommonPushValue::EmbedGIndex(),
        push_show);
    _embedx_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedxWIndex(),
        update_value + common_feature_value.EmbedxG2SumIndex(),
        push_value + CtrCommonPushValue::EmbedxGIndex(),
        push_show);
  }
  return 0;
}
// stage 0 == pull: always materialize the key.
// stage 1 == push: admit the key with probability equal to its show/click
// score (clamped to [0, 1]) so that noise gradients do not create cold
// features. Any other stage admits unconditionally.
bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
  if (stage != 1) {
    return true;
  }
  float* v = const_cast<float*>(value);
  auto score =
      ShowClickScore(CtrCommonPushValue::Show(v), CtrCommonPushValue::Click(v));
  if (score <= 0) {
    return false;
  }
  if (score >= 1) {
    return true;
  }
  return local_uniform_real_distribution<float>()(local_random_engine()) <
         score;
}
// Weighted importance of a feature: each non-clicked impression counts
// nonclk_coeff, each click counts click_coeff.
float CtrCommonAccessor::ShowClickScore(float show, float click) {
  return (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
         click * _config.ctr_accessor_param().click_coeff();
}
// Serialize a value to text: the six fixed fields v[0..5] (slot,
// unseen_days, delta_score, show, click, embed_w — see the layout in
// CtrCommonFeatureValue), then embed_g2sum, and — only when the score
// passes embedx_threshold and `param` asks for enough fields — the
// embedx part.
std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
  // thread_local stream reused across calls; reset state and contents.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
     << v[5];
  for (int i = common_feature_value.EmbedG2SumIndex();
       i < common_feature_value.EmbedxWIndex();
       i++) {
    os << " " << v[i];
  }
  auto show = common_feature_value.Show(const_cast<float*>(v));
  auto click = common_feature_value.Click(const_cast<float*>(v));
  auto score = ShowClickScore(show, click);
  if (score >= _config.embedx_threshold() &&
      param > common_feature_value.EmbedxWIndex()) {
    for (auto i = common_feature_value.EmbedxWIndex();
         i < common_feature_value.Dim();
         ++i) {
      os << " " << v[i];
    }
  }
  return os.str();
}
// Parse a textual dump into `value`. The embedx part is re-initialized
// first so that lines carrying only the six fixed fields still leave a
// valid value. Returns the number of floats parsed.
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
  _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
                              value + common_feature_value.EmbedxG2SumIndex());
  // str_to_float fills `value` in place and returns the parsed count.
  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
// DownpourUnitAccessor: value accessor for the common (float show/click)
// CTR sparse table. Defines the in-table value layout, the pull/push
// wire layouts, and the save/shrink/update policies built on them.
class CtrCommonAccessor : public ValueAccessor {
 public:
  // In-table layout of one feature value; all slots are floats and the
  // index helpers below locate each logical field.
  struct CtrCommonFeatureValue {
    /*
    float slot;
    float unseen_days;
    float delta_score;
    float show;
    float click;
    float embed_w;
    std::vector<float> embed_g2sum;
    std::vector<float> embedx_w;
    std::vector<float> embedx_g2sum;
    */
    int Dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    int Size() { return Dim() * sizeof(float); }
    int SlotIndex() { return 0; }
    int UnseenDaysIndex() { return SlotIndex() + 1; }
    int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
    int ShowIndex() { return DeltaScoreIndex() + 1; }
    int ClickIndex() { return ShowIndex() + 1; }
    int EmbedWIndex() { return ClickIndex() + 1; }
    int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
    int EmbedxWIndex() { return EmbedG2SumIndex() + embed_sgd_dim; }
    int EmbedxG2SumIndex() { return EmbedxWIndex() + embedx_dim; }
    float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
    float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
    float& Show(float* val) { return val[ShowIndex()]; }
    float& Click(float* val) { return val[ClickIndex()]; }
    float& Slot(float* val) { return val[SlotIndex()]; }
    float& EmbedW(float* val) { return val[EmbedWIndex()]; }
    float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
    float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
    float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
    // Set by CtrCommonAccessor::Initialize() from the SGD rules/config.
    int embed_sgd_dim;
    int embedx_dim;
    int embedx_sgd_dim;
  };
  // Wire layout of one push (gradient) record.
  struct CtrCommonPushValue {
    /*
    float slot;
    float show;
    float click;
    float embed_g;
    std::vector<float> embedx_g;
    */
    static int Dim(int embedx_dim) { return 4 + embedx_dim; }
    static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int SlotIndex() { return 0; }
    static int ShowIndex() { return CtrCommonPushValue::SlotIndex() + 1; }
    static int ClickIndex() { return CtrCommonPushValue::ShowIndex() + 1; }
    static int EmbedGIndex() { return CtrCommonPushValue::ClickIndex() + 1; }
    static int EmbedxGIndex() { return CtrCommonPushValue::EmbedGIndex() + 1; }
    static float& Slot(float* val) {
      return val[CtrCommonPushValue::SlotIndex()];
    }
    static float& Show(float* val) {
      return val[CtrCommonPushValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrCommonPushValue::ClickIndex()];
    }
    static float& EmbedG(float* val) {
      return val[CtrCommonPushValue::EmbedGIndex()];
    }
    static float* EmbedxG(float* val) {
      return val + CtrCommonPushValue::EmbedxGIndex();
    }
  };
  // Wire layout of one pull (lookup) response.
  struct CtrCommonPullValue {
    /*
    float show;
    float click;
    float embed_w;
    std::vector<float> embedx_w;
    */
    static int Dim(int embedx_dim) { return 3 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int ShowIndex() { return 0; }
    static int ClickIndex() { return 1; }
    static int EmbedWIndex() { return 2; }
    static int EmbedxWIndex() { return 3; }
    static float& Show(float* val) {
      return val[CtrCommonPullValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrCommonPullValue::ClickIndex()];
    }
    static float& EmbedW(float* val) {
      return val[CtrCommonPullValue::EmbedWIndex()];
    }
    static float* EmbedxW(float* val) {
      return val + CtrCommonPullValue::EmbedxWIndex();
    }
  };
  CtrCommonAccessor() {}
  virtual ~CtrCommonAccessor() {}
  virtual int Initialize();
  // Initialize AccessorInfo (layout dims and byte sizes).
  virtual void InitAccessorInfo();
  // Whether this value should be shrunk (deleted) from the table.
  virtual bool Shrink(float* value);
  // Whether this value should be stored on SSD.
  // virtual bool save_ssd(float* value);
  virtual bool NeedExtendMF(float* value);
  virtual bool HasMF(int size);
  // Whether this value is dumped during a save pass; `param` identifies
  // the pass (e.g. downpour xbox vs. batch_model):
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 2, save xbox base feature
  bool Save(float* value, int param) override;
  bool SaveCache(float* value,
                 int param,
                 double global_cache_threshold) override;
  bool SaveSSD(float* value) override;
  // update delta_score and unseen_days after save
  void UpdateStatAfterSave(float* value, int param) override;
  // Generate initial values for keys that are missing from the table.
  // The memory behind each value must already be allocated by the caller.
  virtual int32_t Create(float** value, size_t num);
  // Copy the pull fields out of values into select_values.
  virtual int32_t Select(float** select_values,
                         const float** values,
                         size_t num);
  // Aggregate update_values together (element-wise merge per key).
  virtual int32_t Merge(float** update_values,
                        const float** other_update_values,
                        size_t num);
  // Aggregate update_values together; it.next would decide when the next
  // key starts.
  // virtual int32_t Merge(float** update_values, iterator it);
  // Apply the aggregated update_values onto values.
  virtual int32_t Update(float** values,
                         const float** update_values,
                         size_t num);
  std::string ParseToString(const float* value, int param) override;
  int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
  // Currently only used to read the show counter.
  float GetField(float* value, const std::string& name) override {
    // CHECK(name == "show");
    if (name == "show") {
      return common_feature_value.Show(value);
    }
    return 0.0;
  }
 private:
  // float ShowClickScore(float show, float click);
  // SparseValueSGDRule* _embed_sgd_rule;
  // SparseValueSGDRule* _embedx_sgd_rule;
  // CtrCommonFeatureValue common_feature_value;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;
  bool _show_scale = false;
 public:  // TODO(zhaocaibei123): it should be private, but we make it public
          // for unit test
  CtrCommonFeatureValue common_feature_value;
  float ShowClickScore(float show, float click);
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_double_accessor.h"

#include <vector>

#include <gflags/gflags.h>

#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Build the embed (1-d) and embedx (embedx_dim-d) SGD rules from the
// table config and read the CTR policy parameters. Must run before any
// other accessor call.
int CtrDoubleAccessor::Initialize() {
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();
  if (_config.ctr_accessor_param().show_scale()) {
    _show_scale = true;
  }
  InitAccessorInfo();
  return 0;
}
// Fill _accessor_info with the dims/byte sizes of the stored value, the
// pull response (3 + embedx_dim floats) and the push request (4 +
// embedx_dim floats). mf_size covers the lazily-extended embedx part
// (embedx_dim weights plus one g2sum slot).
void CtrDoubleAccessor::InitAccessorInfo() {
  auto embedx_dim = _config.embedx_dim();
  _accessor_info.dim = CtrDoubleFeatureValue::Dim(embedx_dim);
  _accessor_info.size = CtrDoubleFeatureValue::Size(embedx_dim);
  _accessor_info.select_dim = 3 + embedx_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  _accessor_info.update_dim = 4 + embedx_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  _accessor_info.mf_size = (embedx_dim + 1) * sizeof(float);
}
// Decay the historical show/click counters, then report whether the
// feature has become stale enough to delete from the table.
bool CtrDoubleAccessor::Shrink(float* value) {
  // Time decay always happens, even when the value is kept.
  CtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
  CtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto ctr_score = ShowClickScore(CtrDoubleFeatureValue::Show(value),
                                  CtrDoubleFeatureValue::Click(value));
  // Delete when the decayed score is too low or the feature has gone
  // unseen for too many days.
  return ctr_score < delete_threshold ||
         CtrDoubleFeatureValue::UnseenDays(value) > delete_after_unseen_days;
}
// Move a value to SSD once it has gone unseen for longer than the
// configured ssd_unseenday_threshold.
bool CtrDoubleAccessor::SaveSSD(float* value) {
  return CtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold;
}
// A value is a cache candidate only when it passes the base-threshold
// score test and was seen recently; among candidates, keep the hot ones
// whose show count exceeds the global cache threshold.
bool CtrDoubleAccessor::SaveCache(float* value,
                                  int param,
                                  double global_cache_threshold) {
  const auto show = CtrDoubleFeatureValue::Show(value);
  const auto click = CtrDoubleFeatureValue::Click(value);
  if (ShowClickScore(show, click) <
      _config.ctr_accessor_param().base_threshold()) {
    return false;
  }
  if (CtrDoubleFeatureValue::UnseenDays(value) >
      _config.ctr_accessor_param().delta_keep_days()) {
    return false;
  }
  return show > global_cache_threshold;
}
// Decide whether a value is dumped during a save pass. `param` selects
// the pass: 0 = save all, 1 = xbox delta, 2 = xbox base (delta_threshold
// forced to 0), 3 = batch model (value already decayed in Shrink);
// anything else saves unconditionally.
bool CtrDoubleAccessor::Save(float* value, int param) {
  // auto base_threshold = _config.ctr_accessor_param().base_threshold();
  // auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  // auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
      // Only well-scored, recently-seen values with enough accumulated
      // delta make it into the xbox outputs.
      if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
                         CtrDoubleFeatureValue::Click(value)) >=
              base_threshold &&
          CtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
          CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
        // do this after save, because it must not be modified when retry
        if (param == 2) {
          CtrDoubleFeatureValue::DeltaScore(value) = 0;
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // CtrDoubleFeatureValue::Show(value) *= _show_click_decay_rate;
      // CtrDoubleFeatureValue::Click(value) *= _show_click_decay_rate;
      // do this after save, because it must not be modified when retry
      // CtrDoubleFeatureValue::UnseenDays(value)++;
      return true;
    }
    default:
      return true;
  }
}
// Post-save bookkeeping, kept separate from Save() so a failed/retried
// save does not mutate state twice: param 1 clears delta_score of values
// that qualified for the xbox-delta dump; param 3 ages unseen_days by
// one day after a batch-model save. Other params are no-ops.
void CtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  // NOTE(review): the switch below has no case 2, so this adjustment is
  // currently dead — kept for symmetry with Save().
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
      // Same qualification filter as Save(param == 1).
      if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
                         CtrDoubleFeatureValue::Click(value)) >=
              base_threshold &&
          CtrDoubleFeatureValue::DeltaScore(value) >= delta_threshold &&
          CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
        CtrDoubleFeatureValue::DeltaScore(value) = 0;
      }
    }
      return;
    case 3: {
      CtrDoubleFeatureValue::UnseenDays(value)++;
    }
      return;
    default:
      return;
  }
}
// Initialize `num` freshly-created feature values: zero the statistics,
// set the slot to the -1 sentinel, and let the SGD rules seed the embed
// and embedx weights (`false` = no random init for embedx).
int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
    value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
    // show/click are stored as doubles spanning two float slots each;
    // use reinterpret_cast consistently (the original mixed a C-style
    // cast with reinterpret_cast on adjacent lines).
    *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
    *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
    value[CtrDoubleFeatureValue::SlotIndex()] = -1;
    _embed_sgd_rule->InitValue(
        value + CtrDoubleFeatureValue::EmbedWIndex(),
        value + CtrDoubleFeatureValue::EmbedG2SumIndex());
    _embedx_sgd_rule->InitValue(
        value + CtrDoubleFeatureValue::EmbedxWIndex(),
        value + CtrDoubleFeatureValue::EmbedxG2SumIndex(),
        false);
  }
  return 0;
}
// A feature earns its embedx (matrix-factorization) part only after its
// weighted show/click score reaches embedx_threshold. show/click are
// read through the double-over-two-float-slots representation.
bool CtrDoubleAccessor::NeedExtendMF(float* value) {
  // reinterpret_cast instead of the original C-style (double*) casts;
  // stale commented-out fragments of the score expression removed.
  auto show =
      *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex());
  auto click =
      *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ClickIndex());
  auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
               click * _config.ctr_accessor_param().click_coeff();
  return score >= _config.embedx_threshold();
}
// from CtrDoubleFeatureValue to CtrDoublePullValue
// Project each stored value onto the pull wire layout. show/click are
// kept as doubles spanning two float slots in the feature value and are
// narrowed to float for the pull response.
// NOTE(review): the (double*) punning relies on compiler-tolerated type
// aliasing; std::memcpy would be the strictly-conforming form.
int32_t CtrDoubleAccessor::Select(float** select_values,
                                  const float** values,
                                  size_t num) {
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* select_value = select_values[value_item];
    float* value = const_cast<float*>(values[value_item]);
    select_value[CtrDoublePullValue::ShowIndex()] =
        (float)*(double*)(value + CtrDoubleFeatureValue::ShowIndex());
    select_value[CtrDoublePullValue::ClickIndex()] =
        (float)*(double*)(value + CtrDoubleFeatureValue::ClickIndex());
    select_value[CtrDoublePullValue::EmbedWIndex()] =
        value[CtrDoubleFeatureValue::EmbedWIndex()];
    memcpy(select_value + CtrDoublePullValue::EmbedxWIndex(),
           value + CtrDoubleFeatureValue::EmbedxWIndex(),
           embedx_dim * sizeof(float));
  }
  return 0;
}
// from CtrDoublePushValue to CtrDoublePushValue
// first dim: item
// second dim: field num
// Accumulate every push field except the slot id, which is an identifier
// rather than a gradient. Push records hold plain floats, so a straight
// element-wise add is sufficient (the commented variant treated
// show/click as packed doubles — kept for reference).
int32_t CtrDoubleAccessor::Merge(float** update_values,
                                 const float** other_update_values,
                                 size_t num) {
  auto embedx_dim = _config.embedx_dim();
  size_t total_dim = CtrDoublePushValue::Dim(embedx_dim);
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* other_update_value = other_update_values[value_item];
    /**(double*)(update_value + CtrDoublePushValue::ShowIndex()) +=
    *(double*)(other_update_value + CtrDoublePushValue::ShowIndex());
    *(double*)(update_value + CtrDoublePushValue::ClickIndex()) +=
    *(double*)(other_update_value + CtrDoublePushValue::ClickIndex());
    for (auto i = 3u; i < total_dim; ++i) {
      update_value[i] += other_update_value[i];
    }*/
    for (size_t i = 0; i < total_dim; ++i) {
      if (static_cast<int>(i) != CtrDoublePushValue::SlotIndex()) {
        update_value[i] += other_update_value[i];
      }
    }
  }
  return 0;
}
// from CtrDoublePushValue to CtrDoubleFeatureValue
// first dim: item
// second dim: field num
// Apply an aggregated push gradient to each stored value: fold show and
// click into the double-packed statistics, refresh slot/delta_score,
// reset unseen_days, then let the SGD rules update the embed and embedx
// weights.
int32_t CtrDoubleAccessor::Update(float** update_values,
                                  const float** push_values,
                                  size_t num) {
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
    float push_show = push_value[CtrDoublePushValue::ShowIndex()];
    float push_click = push_value[CtrDoublePushValue::ClickIndex()];
    float slot = push_value[CtrDoublePushValue::SlotIndex()];
    // show/click accumulate into doubles spanning two float slots.
    *(double*)(update_value + CtrDoubleFeatureValue::ShowIndex()) +=
        (double)push_show;
    *(double*)(update_value + CtrDoubleFeatureValue::ClickIndex()) +=
        (double)push_click;
    update_value[CtrDoubleFeatureValue::SlotIndex()] = slot;
    update_value[CtrDoubleFeatureValue::DeltaScoreIndex()] +=
        (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
    //(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
    // push_click * _config.ctr_accessor_param().click_coeff();
    // A push means the feature was just seen.
    update_value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
    // Without show scaling, gradients are applied with unit weight.
    if (!_show_scale) {
      push_show = 1;
    }
    VLOG(3) << "accessor show scale:" << _show_scale
            << ", push_show:" << push_show;
    _embed_sgd_rule->UpdateValue(
        update_value + CtrDoubleFeatureValue::EmbedWIndex(),
        update_value + CtrDoubleFeatureValue::EmbedG2SumIndex(),
        push_value + CtrDoublePushValue::EmbedGIndex(),
        push_show);
    _embedx_sgd_rule->UpdateValue(
        update_value + CtrDoubleFeatureValue::EmbedxWIndex(),
        update_value + CtrDoubleFeatureValue::EmbedxG2SumIndex(),
        push_value + CtrDoublePushValue::EmbedxGIndex(),
        push_show);
  }
  return 0;
}
// stage 0 == pull: always materialize the key.
// stage 1 == push: admit the key with probability equal to its show/click
// score (clamped to [0, 1]) so that noise gradients do not create cold
// features. Any other stage admits unconditionally.
bool CtrDoubleAccessor::CreateValue(int stage, const float* value) {
  if (stage != 1) {
    return true;
  }
  float* v = const_cast<float*>(value);
  auto score =
      ShowClickScore(CtrDoublePushValue::Show(v), CtrDoublePushValue::Click(v));
  if (score <= 0) {
    return false;
  }
  if (score >= 1) {
    return true;
  }
  return local_uniform_real_distribution<float>()(local_random_engine()) <
         score;
}
// Weighted importance of a feature: each non-clicked impression counts
// nonclk_coeff, each click counts click_coeff.
double CtrDoubleAccessor::ShowClickScore(double show, double click) {
  return (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
         click * _config.ctr_accessor_param().click_coeff();
}
// Serialize a value to text: v[0..1] (presumably unseen_days and
// delta_score — confirm against the layout in ctr_double_accessor.h),
// the double-packed show (v[2..3]) and click (v[4..5]) narrowed to
// float, then v[6..8], and — only when the score passes embedx_threshold
// and param_size asks for more than 9 fields — v[9] and the embedx
// weights.
std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
  // thread_local stream reused across calls; reset state and contents.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << (float)((double*)(v + 2))[0] << " "
     << (float)((double*)(v + 4))[0] << " " << v[6] << " " << v[7] << " "
     << v[8];
  auto show = CtrDoubleFeatureValue::Show(const_cast<float*>(v));
  auto click = CtrDoubleFeatureValue::Click(const_cast<float*>(v));
  auto score = ShowClickScore(show, click);
  if (score >= _config.embedx_threshold() && param_size > 9) {
    os << " " << v[9];
    for (size_t i = 0; i < _config.embedx_dim(); ++i) {
      os << " " << v[10 + i];
    }
  }
  return os.str();
}
// Parses a serialized feature string back into a CtrDoubleFeatureValue
// buffer. The string stores show/click as single floats; they are widened
// into the double-packed in-memory layout. Returns the adjusted field count
// (+1 for the missing slot field when applicable, +2 for the extra float
// slots taken by the double show/click).
int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
  int embedx_dim = _config.embedx_dim();
  // NOTE(review): runtime-sized stack array is a VLA — a gcc/clang
  // extension, not standard C++.
  float data_buff[_accessor_info.dim + 2];
  float* data_buff_ptr = data_buff;
  // Pre-initialize the embedx part in case the string does not carry it.
  _embedx_sgd_rule->InitValue(
      data_buff_ptr + CtrDoubleFeatureValue::EmbedxWIndex(),
      data_buff_ptr + CtrDoubleFeatureValue::EmbedxG2SumIndex());
  auto str_len = paddle::string::str_to_float(str.data(), data_buff_ptr);
  CHECK(str_len >= 6) << "expect more than 6 real:" << str_len;
  int show_index = CtrDoubleFeatureValue::ShowIndex();
  int click_index = CtrDoubleFeatureValue::ClickIndex();
  int embed_w_index = CtrDoubleFeatureValue::EmbedWIndex();
  // no slot, embedx
  int value_dim = _accessor_info.dim;
  int embedx_g2sum_index = CtrDoubleFeatureValue::EmbedxG2SumIndex();
  // slot is never serialized; mark it unknown
  value[CtrDoubleFeatureValue::SlotIndex()] = -1;
  // other case
  if (str_len == (value_dim - 1)) {  // full record (all fields but slot)
    // copy unseen_days..delta_score
    memcpy(value, data_buff_ptr, show_index * sizeof(float));
    // copy show & click (widen float -> packed double)
    *(double*)(value + show_index) = (double)data_buff_ptr[2];
    *(double*)(value + click_index) = (double)data_buff_ptr[3];
    // copy others
    value[CtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
    value[CtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
    memcpy(value + embedx_g2sum_index,
           data_buff_ptr + 6,
           (embedx_dim + 1) * sizeof(float));
  } else {
    // short record: whatever fields are present after show/click are copied
    // contiguously starting at embed_w
    // copy unseen_days..delta_score
    memcpy(value, data_buff_ptr, show_index * sizeof(float));
    // copy show & click
    *(double*)(value + show_index) = (double)data_buff_ptr[2];
    *(double*)(value + click_index) = (double)data_buff_ptr[3];
    // copy embed_w..embedx_w
    memcpy(value + embed_w_index,
           data_buff_ptr + 4,
           (str_len - 4) * sizeof(float));
  }
  // account for the slot field that is absent from the string
  if (str_len == (value_dim - 1) || str_len == 6) {
    str_len += 1;
  }
  return str_len + 2;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
// Accessor for CTR feature values whose show/click statistics are stored as
// doubles packed into pairs of float slots (the "double" accessor).
class CtrDoubleAccessor : public ValueAccessor {
 public:
  // In-memory layout of one stored feature value; indices are float offsets.
  struct CtrDoubleFeatureValue {
    /*
    float unseen_days;
    float delta_score;
    double show;
    double click;
    float embed_w;
    float embed_g2sum;
    float slot;
    float embedx_g2sum;
    std::vector<float> embedx_w;
    */
    // 8 logical fields plus embedx weights; show/click each occupy two float
    // slots, hence the "+ 2" in Size().
    static int Dim(int embedx_dim) { return 8 + embedx_dim; }
    static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) {
      return (Dim(embedx_dim) + 2) * sizeof(float);
    }
    static int UnseenDaysIndex() { return 0; }
    static int DeltaScoreIndex() {
      return CtrDoubleFeatureValue::UnseenDaysIndex() + 1;
    }
    static int ShowIndex() {
      return CtrDoubleFeatureValue::DeltaScoreIndex() + 1;
    }
    // show is double
    static int ClickIndex() { return CtrDoubleFeatureValue::ShowIndex() + 2; }
    // click is double
    static int EmbedWIndex() { return CtrDoubleFeatureValue::ClickIndex() + 2; }
    static int EmbedG2SumIndex() {
      return CtrDoubleFeatureValue::EmbedWIndex() + 1;
    }
    static int SlotIndex() {
      return CtrDoubleFeatureValue::EmbedG2SumIndex() + 1;
    }
    static int EmbedxG2SumIndex() {
      return CtrDoubleFeatureValue::SlotIndex() + 1;
    }
    static int EmbedxWIndex() {
      return CtrDoubleFeatureValue::EmbedxG2SumIndex() + 1;
    }
    static float& UnseenDays(float* val) {
      return val[CtrDoubleFeatureValue::UnseenDaysIndex()];
    }
    static float& DeltaScore(float* val) {
      return val[CtrDoubleFeatureValue::DeltaScoreIndex()];
    }
    // show/click are reinterpreted as doubles spanning two float slots
    static double& Show(float* val) {
      return ((double*)(val + CtrDoubleFeatureValue::ShowIndex()))[0];
    }
    static double& Click(float* val) {
      return ((double*)(val + CtrDoubleFeatureValue::ClickIndex()))[0];
    }
    static float& Slot(float* val) {
      return val[CtrDoubleFeatureValue::SlotIndex()];
    }
    static float& EmbedW(float* val) {
      return val[CtrDoubleFeatureValue::EmbedWIndex()];
    }
    static float& EmbedG2Sum(float* val) {
      return val[CtrDoubleFeatureValue::EmbedG2SumIndex()];
    }
    static float& EmbedxG2Sum(float* val) {
      return val[CtrDoubleFeatureValue::EmbedxG2SumIndex()];
    }
    static float* EmbedxW(float* val) {
      return (val + CtrDoubleFeatureValue::EmbedxWIndex());
    }
  };
  // Layout of one pushed (gradient) value.
  struct CtrDoublePushValue {
    /*
    float slot;
    float show;
    float click;
    float embed_g;
    std::vector<float> embedx_g;
    */
    static int Dim(int embedx_dim) { return 4 + embedx_dim; }
    static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int SlotIndex() { return 0; }
    static int ShowIndex() { return CtrDoublePushValue::SlotIndex() + 1; }
    static int ClickIndex() { return CtrDoublePushValue::ShowIndex() + 1; }
    static int EmbedGIndex() { return CtrDoublePushValue::ClickIndex() + 1; }
    static int EmbedxGIndex() { return CtrDoublePushValue::EmbedGIndex() + 1; }
    static float& Slot(float* val) {
      return val[CtrDoublePushValue::SlotIndex()];
    }
    static float& Show(float* val) {
      return val[CtrDoublePushValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrDoublePushValue::ClickIndex()];
    }
    static float& EmbedG(float* val) {
      return val[CtrDoublePushValue::EmbedGIndex()];
    }
    static float* EmbedxG(float* val) {
      return val + CtrDoublePushValue::EmbedxGIndex();
    }
  };
  // Layout of one pulled value (what workers receive).
  struct CtrDoublePullValue {
    /*
    float show;
    float click;
    float embed_w;
    std::vector<float> embedx_w;
    */
    static int Dim(int embedx_dim) { return 3 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int ShowIndex() { return 0; }
    static int ClickIndex() { return 1; }
    static int EmbedWIndex() { return 2; }
    static int EmbedxWIndex() { return 3; }
    static float& Show(float* val) {
      return val[CtrDoublePullValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrDoublePullValue::ClickIndex()];
    }
    static float& EmbedW(float* val) {
      return val[CtrDoublePullValue::EmbedWIndex()];
    }
    static float* EmbedxW(float* val) {
      return val + CtrDoublePullValue::EmbedxWIndex();
    }
  };
  CtrDoubleAccessor() {}
  virtual ~CtrDoubleAccessor() {}
  virtual int Initialize();
  // Initialize AccessorInfo
  virtual void InitAccessorInfo();
  // Decide whether this value should be shrunk (deleted)
  virtual bool Shrink(float* value);
  virtual bool NeedExtendMF(float* value);
  // Decide whether this value is dumped during the save stage;
  // param identifies the save stage, e.g. downpour xbox vs batch_model:
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 3, save all feature with time decay
  virtual bool Save(float* value, int param) override;
  bool SaveCache(float* value,
                 int param,
                 double global_cache_threshold) override;
  // update delta_score and unseen_days after save
  virtual void UpdateStatAfterSave(float* value, int param) override;
  // Decide whether this value should be saved to SSD
  virtual bool SaveSSD(float* value);
  // virtual bool save_cache(float* value, int param, double
  // global_cache_threshold) override;
  // Generate random values for keys that do not exist yet;
  // the caller must have already allocated the value memory.
  virtual int32_t Create(float** value, size_t num);
  // Select from values into select_values
  virtual int32_t Select(float** select_values,
                         const float** values,
                         size_t num);
  // Aggregate update_values together
  virtual int32_t Merge(float** update_values,
                        const float** other_update_values,
                        size_t num);
  // Aggregate update_values; it.next decides whether to move to the next key
  // virtual int32_t Merge(float** update_values, iterator it);
  // Apply update_values onto values
  virtual int32_t Update(float** values,
                         const float** update_values,
                         size_t num);
  virtual std::string ParseToString(const float* value, int param) override;
  virtual int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
  // This interface is currently only used to fetch show
  virtual float GetField(float* value, const std::string& name) override {
    CHECK(name == "show");
    if (name == "show") {
      return (float)CtrDoubleFeatureValue::Show(value);
    }
    return 0.0;
  }
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, show)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, click)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embed_w)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embedx_w)
 private:
  double ShowClickScore(double show, double click);
 private:
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;
  bool _show_scale = false;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace distributed {
// Build the embed/embedx SGD rules from config, cache the value-layout
// dimensions they imply, and fill in the accessor info. Always returns 0.
int CtrDymfAccessor::Initialize() {
  auto embed_rule_name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embed_rule_name);
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
  auto embedx_rule_name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embedx_rule_name);
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());
  // Cache the per-rule dimensions used by the dynamic value layout.
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
  common_feature_value.embedx_dim = _config.embedx_dim();
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();
  if (_config.ctr_accessor_param().show_scale()) {
    _show_scale = true;
  }
  VLOG(0) << " INTO CtrDymfAccessor::Initialize()";
  InitAccessorInfo();
  return 0;
}
// Fill _accessor_info from the (already initialized) common_feature_value
// layout and the configured embedx dimension.
void CtrDymfAccessor::InitAccessorInfo() {
  _accessor_info.dim = common_feature_value.Dim();
  _accessor_info.size = common_feature_value.Size();
  auto mf_dim = _config.embedx_dim();
  VLOG(0) << "InitAccessorInfo embedx_dim:" << mf_dim;
  _accessor_info.select_dim = 3 + mf_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  _accessor_info.update_dim = 4 + mf_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  // mf part = embedx weights plus their optimizer state.
  _accessor_info.mf_size =
      (mf_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
}
// Apply time decay to show/click, then report whether the value should be
// deleted (score too low, or unseen for too many days).
bool CtrDymfAccessor::Shrink(float* value) {
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();
  // Decay must happen before scoring.
  common_feature_value.Show(value) *= _show_click_decay_rate;
  common_feature_value.Click(value) *= _show_click_decay_rate;
  const auto score = ShowClickScore(common_feature_value.Show(value),
                                    common_feature_value.Click(value));
  return score < delete_threshold ||
         common_feature_value.UnseenDays(value) > delete_after_unseen_days;
}
// A value is cached when it is both "hot" (score above base_threshold and
// recently seen) and its show count exceeds the global cache threshold.
bool CtrDymfAccessor::SaveCache(float* value,
                                int param,
                                double global_cache_threshold) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  const auto show = common_feature_value.Show(value);
  const bool is_hot =
      ShowClickScore(show, common_feature_value.Click(value)) >=
          base_threshold &&
      common_feature_value.UnseenDays(value) <= delta_keep_days;
  return is_hot && show > global_cache_threshold;
}
// A value is pushed to SSD once it has been unseen longer than the
// configured SSD threshold.
bool CtrDymfAccessor::SaveSSD(float* value) {
  return common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold;
}
// Decide whether a value is dumped during save.
// param = 0: save all features.
// param = 1: save xbox delta features.
// param = 2: save xbox base features (delta threshold is ignored).
// param = 3: save all (decay already applied in Shrink).
// param = 5: save revert batch_model.
bool CtrDymfAccessor::Save(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  if (param != 1 && param != 2) {
    // Modes 0, 3, 5 and anything unrecognized dump unconditionally.
    return true;
  }
  // xbox delta/base: dump only sufficiently hot, recently seen values.
  const bool keep =
      ShowClickScore(common_feature_value.Show(value),
                     common_feature_value.Click(value)) >= base_threshold &&
      common_feature_value.DeltaScore(value) >= delta_threshold &&
      common_feature_value.UnseenDays(value) <= delta_keep_days;
  if (!keep) {
    return false;
  }
  // Reset delta_score only for xbox base, and only after the decision, so a
  // retried save is not affected.
  if (param == 2) {
    common_feature_value.DeltaScore(value) = 0;
  }
  return true;
}
// Post-save bookkeeping: reset delta_score after an xbox-delta save of a hot
// value (param == 1); bump unseen_days after a decay save (param == 3).
void CtrDymfAccessor::UpdateStatAfterSave(float* value, int param) {
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  if (param == 1) {
    const bool was_dumped =
        ShowClickScore(common_feature_value.Show(value),
                       common_feature_value.Click(value)) >= base_threshold &&
        common_feature_value.DeltaScore(value) >= delta_threshold &&
        common_feature_value.UnseenDays(value) <= delta_keep_days;
    if (was_dumped) {
      common_feature_value.DeltaScore(value) = 0;
    }
    return;
  }
  if (param == 3) {
    common_feature_value.UnseenDays(value)++;
  }
}
// Initialize `num` freshly allocated feature values: zero the statistics,
// mark slot/mf_dim as unknown (-1), and let the SGD rules seed the
// embedding weights and their optimizer state.
int32_t CtrDymfAccessor::Create(float** values, size_t num) {
  for (size_t i = 0; i < num; ++i) {
    float* val = values[i];
    val[common_feature_value.UnseenDaysIndex()] = 0;
    val[common_feature_value.DeltaScoreIndex()] = 0;
    val[common_feature_value.ShowIndex()] = 0;
    val[common_feature_value.ClickIndex()] = 0;
    val[common_feature_value.SlotIndex()] = -1;
    val[common_feature_value.MfDimIndex()] = -1;
    _embed_sgd_rule->InitValue(val + common_feature_value.EmbedWIndex(),
                               val + common_feature_value.EmbedG2SumIndex());
    _embedx_sgd_rule->InitValue(val + common_feature_value.EmbedxWIndex(),
                                val + common_feature_value.EmbedxG2SumIndex(),
                                false);
  }
  return 0;
}
// The mf (embedx) part is materialized once the value's weighted score
// reaches the configured embedx_threshold.
bool CtrDymfAccessor::NeedExtendMF(float* value) {
  const float show = value[common_feature_value.ShowIndex()];
  const float click = value[common_feature_value.ClickIndex()];
  const auto& ctr_param = _config.ctr_accessor_param();
  const float score = (show - click) * ctr_param.nonclk_coeff() +
                      click * ctr_param.click_coeff();
  return score >= _config.embedx_threshold();
}
// A stored value contains the mf part iff it extends past embedx_g2sum.
bool CtrDymfAccessor::HasMF(int size) {
  const int mf_begin = common_feature_value.EmbedxG2SumIndex();
  return size > mf_begin;
}
// Convert stored CommonFeatureValue records into CtrDymfPullValue records
// (show, click, embed_w, embedx_w), one per item.
int32_t CtrDymfAccessor::Select(float** select_values,
                                const float** values,
                                size_t num) {
  auto embedx_dim = _config.embedx_dim();
  for (size_t i = 0; i < num; ++i) {
    float* dst = select_values[i];
    const float* src = values[i];
    dst[CtrDymfPullValue::ShowIndex()] = src[common_feature_value.ShowIndex()];
    dst[CtrDymfPullValue::ClickIndex()] =
        src[common_feature_value.ClickIndex()];
    dst[CtrDymfPullValue::EmbedWIndex()] =
        src[common_feature_value.EmbedWIndex()];
    memcpy(dst + CtrDymfPullValue::EmbedxWIndex(),
           src + common_feature_value.EmbedxWIndex(),
           embedx_dim * sizeof(float));
  }
  return 0;
}
// from CtrDymfPushValue to CtrDymfPushValue
// first dim: item
// second dim: field num
// Intentionally a no-op stub: gradient merging for the dymf accessor is not
// implemented on CPU (always returns 0, arguments are untouched).
int32_t CtrDymfAccessor::Merge(float** update_values,
                               const float** other_update_values,
                               size_t num) {
  // currently merge in cpu is not supported
  return 0;
}
// from CtrDymfPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
// Intentionally a no-op stub: applying pushed gradients for the dymf
// accessor is not implemented on CPU (always returns 0).
int32_t CtrDymfAccessor::Update(float** update_values,
                                const float** push_values,
                                size_t num) {
  // currently update in cpu is not supported
  return 0;
}
// Decide whether a missing key should be materialized at a given stage.
// stage == 0: pull path — always create.
// stage == 1: push path — create probabilistically, weighted by the
//             show/click score of the pushed gradient value.
bool CtrDymfAccessor::CreateValue(int stage, const float* value) {
  if (stage != 1) {
    // Pull (stage == 0) and any unrecognized stage always create.
    return true;
  }
  float* push_value = const_cast<float*>(value);
  const auto score = ShowClickScore(CtrDymfPushValue::Show(push_value),
                                    CtrDymfPushValue::Click(push_value));
  if (score <= 0) {
    return false;
  }
  if (score >= 1) {
    return true;
  }
  // Keep the value with probability equal to its score.
  return local_uniform_real_distribution<float>()(local_random_engine()) <
         score;
}
// Weighted importance score used by shrink/save/create decisions:
//   score = nonclk_coeff * (show - click) + click_coeff * click
float CtrDymfAccessor::ShowClickScore(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  return (show - click) * ctr_param.nonclk_coeff() +
         click * ctr_param.click_coeff();
}
// Serializes a CommonFeatureValue buffer to a space-separated string.
// The embedx block (g2sum + weights, whose length depends on the value's own
// mf_dim field) is appended only when the show/click score passes
// embedx_threshold and `param` indicates the block is present.
std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
  /*
       float unseen_days;
       float delta_score;
       float show;
       float click;
       float embed_w;
       std::vector<float> embed_g2sum; // float embed_g2sum
       float slot;
       float mf_dim;
       std::<vector>float embedx_g2sum; // float embedx_g2sum
       std::vector<float> embedx_w;
  */
  // thread_local stream is reused across calls to avoid reallocation.
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4];
  //     << v[5] << " " << v[6];
  // embed_g2sum may span several floats depending on the embed SGD rule.
  for (int i = common_feature_value.EmbedG2SumIndex();
       i < common_feature_value.EmbedxG2SumIndex();
       i++) {
    os << " " << v[i];
  }
  // os << " " << common_feature_value.Slot(const_cast<float*>(v)) << " "
  //    << common_feature_value.MfDim(const_cast<float*>(v));
  auto show = common_feature_value.Show(const_cast<float*>(v));
  auto click = common_feature_value.Click(const_cast<float*>(v));
  auto score = ShowClickScore(show, click);
  if (score >= _config.embedx_threshold() &&
      param > common_feature_value.EmbedxG2SumIndex()) {
    // VLOG(1) << "common_feature_value.EmbedxG2SumIndex():"
    //         << common_feature_value.EmbedxG2SumIndex();
    // VLOG(1) << "common_feature_value.EmbedxWIndex():"
    //         << common_feature_value.EmbedxWIndex();
    // VLOG(1) << "common_feature_value.MfDim():"
    //         << common_feature_value.MfDim(const_cast<float*>(v));
    // The embedx block length is read from the value itself (mf_dim field).
    for (auto i = common_feature_value.EmbedxG2SumIndex();
         i < common_feature_value.EmbedxWIndex() +
                 common_feature_value.MfDim(const_cast<float*>(v));
         ++i) {
      os << " " << v[i];
    }
  }
  return os.str();
}
// Parse a space-separated feature string directly into the value buffer and
// return the number of floats read (must be at least 7).
int CtrDymfAccessor::ParseFromString(const std::string& str, float* value) {
  const auto field_count = paddle::string::str_to_float(str.data(), value);
  CHECK(field_count >= 7) << "expect more than 7 real:" << field_count;
  return field_count;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
namespace paddle {
namespace distributed {
// DownpourUnitAccessor
// Accessor for CTR feature values with dynamic mf (embedx) dimension; the
// optimizer-state widths come from the configured SGD rules.
class CtrDymfAccessor : public ValueAccessor {
 public:
  // In-memory layout of one stored feature value. Unlike the static structs
  // below, the offsets depend on the dim members set during Initialize().
  struct CtrDymfFeatureValue {
    /*
    float unseen_days;
    float delta_score;
    float show;
    float click;
    float embed_w;
    // float embed_g2sum;
    std::vector<float> embed_g2sum;
    float slot;
    float mf_dim
    std::<vector>float embedx_g2sum;
    // float embedx_g2sum;
    std::vector<float> embedx_w;
    */
    int Dim() { return 7 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    int Size() { return Dim() * sizeof(float); }
    int UnseenDaysIndex() { return 0; }
    int DeltaScoreIndex() { return UnseenDaysIndex() + 1; }
    int ShowIndex() { return DeltaScoreIndex() + 1; }
    int ClickIndex() { return ShowIndex() + 1; }
    int EmbedWIndex() { return ClickIndex() + 1; }
    int EmbedG2SumIndex() { return EmbedWIndex() + 1; }
    int SlotIndex() { return EmbedG2SumIndex() + 1; }
    int MfDimIndex() { return SlotIndex() + 1; }
    int EmbedxG2SumIndex() { return MfDimIndex() + 1; }
    int EmbedxWIndex() { return EmbedxG2SumIndex() + 1; }
    float& UnseenDays(float* val) { return val[UnseenDaysIndex()]; }
    float& DeltaScore(float* val) { return val[DeltaScoreIndex()]; }
    float& Show(float* val) { return val[ShowIndex()]; }
    float& Click(float* val) { return val[ClickIndex()]; }
    float& Slot(float* val) { return val[SlotIndex()]; }
    float& MfDim(float* val) { return val[MfDimIndex()]; }
    float& EmbedW(float* val) { return val[EmbedWIndex()]; }
    float& EmbedG2Sum(float* val) { return val[EmbedG2SumIndex()]; }
    float& EmbedxG2Sum(float* val) { return val[EmbedxG2SumIndex()]; }
    float& EmbedxW(float* val) { return val[EmbedxWIndex()]; }
    // Filled in by CtrDymfAccessor::Initialize() from the SGD rules/config.
    int embed_sgd_dim;
    int embedx_dim;
    int embedx_sgd_dim;
  };
  // Layout of one pushed (gradient) value.
  struct CtrDymfPushValue {
    /*
    float slot;
    float show;
    float click;
    float mf_dim;
    float embed_g;
    std::vector<float> embedx_g;
    */
    static int Dim(int embedx_dim) { return 5 + embedx_dim; }
    static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int SlotIndex() { return 0; }
    static int ShowIndex() { return CtrDymfPushValue::SlotIndex() + 1; }
    static int ClickIndex() { return CtrDymfPushValue::ShowIndex() + 1; }
    static int MfDimIndex() { return CtrDymfPushValue::ClickIndex() + 1; }
    static int EmbedGIndex() { return CtrDymfPushValue::MfDimIndex() + 1; }
    static int EmbedxGIndex() { return CtrDymfPushValue::EmbedGIndex() + 1; }
    static float& Slot(float* val) {
      return val[CtrDymfPushValue::SlotIndex()];
    }
    static float& Show(float* val) {
      return val[CtrDymfPushValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrDymfPushValue::ClickIndex()];
    }
    static float& MfDim(float* val) {
      return val[CtrDymfPushValue::MfDimIndex()];
    }
    static float& EmbedG(float* val) {
      return val[CtrDymfPushValue::EmbedGIndex()];
    }
    static float* EmbedxG(float* val) {
      return val + CtrDymfPushValue::EmbedxGIndex();
    }
  };
  // Layout of one pulled value (what workers receive).
  struct CtrDymfPullValue {
    /*
    float show;
    float click;
    float mf_dim;
    float embed_w;
    std::vector<float> embedx_w;
    */
    static int Dim(int embedx_dim) { return 4 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int ShowIndex() { return 0; }
    static int ClickIndex() { return 1; }
    static int MfDimIndex() { return 2; }
    static int EmbedWIndex() { return 3; }
    static int EmbedxWIndex() { return 4; }
    static float& Show(float* val) {
      return val[CtrDymfPullValue::ShowIndex()];
    }
    static float& Click(float* val) {
      return val[CtrDymfPullValue::ClickIndex()];
    }
    static float& MfDim(float* val) {
      return val[CtrDymfPullValue::MfDimIndex()];
    }
    static float& EmbedW(float* val) {
      return val[CtrDymfPullValue::EmbedWIndex()];
    }
    static float* EmbedxW(float* val) {
      return val + CtrDymfPullValue::EmbedxWIndex();
    }
  };
  CtrDymfAccessor() {}
  virtual ~CtrDymfAccessor() {}
  virtual int Initialize();
  // Initialize AccessorInfo
  virtual void InitAccessorInfo();
  // Decide whether this value should be shrunk (deleted)
  virtual bool Shrink(float* value);
  // Decide whether this value should be saved to SSD
  // virtual bool save_ssd(float* value);
  virtual bool NeedExtendMF(float* value);
  virtual bool HasMF(int size);
  // Decide whether this value is dumped during the save stage;
  // param identifies the save stage, e.g. downpour xbox vs batch_model:
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 2, save xbox base feature
  bool Save(float* value, int param) override;
  bool SaveCache(float* value,
                 int param,
                 double global_cache_threshold) override;
  bool SaveSSD(float* value) override;
  // update delta_score and unseen_days after save
  void UpdateStatAfterSave(float* value, int param) override;
  // Generate random values for keys that do not exist yet;
  // the caller must have already allocated the value memory.
  virtual int32_t Create(float** value, size_t num);
  // Select from values into select_values
  virtual int32_t Select(float** select_values,
                         const float** values,
                         size_t num);
  // Aggregate update_values together
  virtual int32_t Merge(float** update_values,
                        const float** other_update_values,
                        size_t num);
  // Aggregate update_values; it.next decides whether to move to the next key
  // virtual int32_t Merge(float** update_values, iterator it);
  // Apply update_values onto values
  virtual int32_t Update(float** values,
                         const float** update_values,
                         size_t num);
  std::string ParseToString(const float* value, int param) override;
  int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
  // This interface is currently only used to fetch show
  float GetField(float* value, const std::string& name) override {
    // CHECK(name == "show");
    if (name == "show") {
      return common_feature_value.Show(value);
    }
    return 0.0;
  }
 private:
  // float ShowClickScore(float show, float click);
  // SparseValueSGDRule* _embed_sgd_rule;
  // SparseValueSGDRule* _embedx_sgd_rule;
  // CtrDymfFeatureValue common_feature_value;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;
  bool _show_scale = false;
 public:  // TODO(zhaocaibei123): it should be private, but we make it public
          // for unit test
  CtrDymfFeatureValue common_feature_value;
  float ShowClickScore(float show, float click);
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h"
namespace paddle {
namespace distributed {
// dense optimizer
// TODO(tangwei12) integrate with sparse optimizer later.
// Base class for dense-parameter optimizers. Concrete optimizers (DSUM,
// DSGD, DAdam, ...) bind non-owning pointers into the table's value columns
// in their constructors and apply one update step in Update().
class DenseOptimizer {
 public:
  DenseOptimizer() {}
  explicit DenseOptimizer(const CommonAccessorParameter& accessor,
                          std::vector<std::vector<float>>* values) {}
  // Fix: this class is deleted polymorphically, so the destructor must be
  // virtual — without it, delete-through-base is undefined behavior.
  virtual ~DenseOptimizer() = default;
  // Apply update_values over the parameter range [begin, end).
  virtual void Update(const float* update_values,
                      size_t num,
                      int begin,
                      int end) = 0;
  virtual void SetGlobalLR(float* lr) { global_learning_rate_ = lr; }

 protected:
  // Non-owning; must be set via SetGlobalLR before Update is called.
  // Initialized to nullptr so a missing setup fails fast, not on garbage.
  float* global_learning_rate_ = nullptr;
};
// sum calc for dense tensor
// Sum "optimizer" for dense tensors: param += grad over [begin, end).
class DSUM : public DenseOptimizer {
 public:
  explicit DSUM(const CommonAccessorParameter& accessor,
                std::vector<std::vector<float>>* values) {
    auto& names = accessor.params();
    for (int x = 0; x < static_cast<int>(names.size()); ++x) {
      if (names[x] == "Param") {
        param = (*values)[x].data();
      }
    }
  }

  // Accumulate update_values[begin, end) into param[begin, end).
  void Update(const float* update_values,
              size_t num,
              int begin,
              int end) override {
    auto update_numel = end - begin;
    GetBlas<float>().VADD(
        update_numel, update_values + begin, param + begin, param + begin);
  }

  // Non-owning view into the table's "Param" column; nullptr until the
  // constructor finds it (fix: was left uninitialized if absent).
  float* param = nullptr;
};
// sgd optimizer for dense tensor
// SGD optimizer for dense tensors: param -= (global_lr * lr) * grad.
class DSGD : public DenseOptimizer {
 public:
  explicit DSGD(const CommonAccessorParameter& accessor,
                std::vector<std::vector<float>>* values) {
    auto& names = accessor.params();
    for (int x = 0; x < static_cast<int>(names.size()); ++x) {
      if (names[x] == "LearningRate") {
        learning_rate = (*values)[x].data();
      }
      if (names[x] == "Param") {
        param = (*values)[x].data();
      }
    }
  }

  // One SGD step over param[begin, end) using update_values[begin, end).
  void Update(const float* update_values,
              size_t num,
              int begin,
              int end) override {
    auto update_numel = end - begin;
    std::vector<float> grads;
    grads.resize(update_numel);

    auto blas = GetBlas<float>();
    // effective lr = global lr scaled by the per-table lr
    float lr = *(global_learning_rate_) * (*learning_rate);
    blas.VCOPY(update_numel, update_values + begin, grads.data());
    blas.SCAL(update_numel, lr, grads.data());
    blas.VSUB(update_numel, param + begin, grads.data(), param + begin);
  }

  // Non-owning views into the table's value columns; nullptr until the
  // constructor finds them (fix: were left uninitialized if absent).
  float* learning_rate = nullptr;
  float* param = nullptr;
};
// adam optimizer for dense tensor
// TODO(zhaocaibei123): add CHECK(memory_dense_table.task_pool_size_) == 1
// Adam optimizer for dense tensors (bias-corrected, see Kingma & Ba).
// TODO(zhaocaibei123): add CHECK(memory_dense_table.task_pool_size_) == 1
class DAdam : public DenseOptimizer {
 public:
  explicit DAdam(const CommonAccessorParameter& accessor,
                 std::vector<std::vector<float>>* values) {
    auto& names = accessor.params();
    for (int x = 0; x < static_cast<int>(names.size()); ++x) {
      if (names[x] == "LearningRate") {
        learning_rate = (*values)[x].data();
      }
      if (names[x] == "Param") {
        param = (*values)[x].data();
      }
      if (names[x] == "Moment1") {
        moment1 = (*values)[x].data();
      }
      if (names[x] == "Moment2") {
        moment2 = (*values)[x].data();
      }
      if (names[x] == "Beta1Pow") {
        beta1_pow = (*values)[x].data();
      }
      if (names[x] == "Beta2Pow") {
        beta2_pow = (*values)[x].data();
      }
    }
    // add attr later
    beta1 = 0.9;
    beta2 = 0.999;
    epsilon = 1.0e-8;
  }

  // One Adam step over param[begin, end).
  // make sure memory_dense_table.task_pool_size_ == 1;
  // otherwise, task_pool_size_ times beta1_pow/beta2_pow multiplication
  void Update(const float* update_values,
              size_t num,
              int begin,
              int end) override {
    auto update_numel = end - begin;
    std::vector<float> grad, grad2, tmp;
    grad.resize(update_numel);
    grad2.resize(update_numel);
    tmp.resize(update_numel);

    auto blas = GetBlas<float>();
    blas.VCOPY(update_numel, update_values + begin, grad.data());
    blas.VCOPY(update_numel, update_values + begin, grad2.data());
    blas.SCAL(update_numel, 1 - beta1, grad.data());
    blas.VSQUARE(update_numel, grad2.data(), grad2.data());
    blas.SCAL(update_numel, 1 - beta2, grad2.data());

    // m = beta1 * m + (1 - beta1) * g
    blas.SCAL(update_numel, beta1, moment1 + begin);
    blas.VADD(update_numel, moment1 + begin, grad.data(), moment1 + begin);
    // v = beta2 * v + (1 - beta2) * g^2
    blas.SCAL(update_numel, beta2, moment2 + begin);
    blas.VADD(update_numel, moment2 + begin, grad2.data(), moment2 + begin);

    beta1_pow[0] = beta1_pow[0] * beta1;
    beta2_pow[0] = beta2_pow[0] * beta2;

    // bias-corrected step size
    float lr_ = *(global_learning_rate_)*learning_rate[0];
    lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);

    // param -= lr * m / (sqrt(v) + eps)
    float* tmp_ = tmp.data();
    float eps_ = epsilon * sqrt(1 - beta2_pow[0]);
    SQRT<float>(update_numel, moment2 + begin, tmp_);
    ADD<float>(update_numel, tmp_, eps_, tmp_);
    blas.VDIV(update_numel, moment1 + begin, tmp_, tmp_);
    blas.SCAL(update_numel, lr_, tmp_);
    blas.VSUB(update_numel, param + begin, tmp_, param + begin);
  }

  // Non-owning views into the table's value columns; nullptr until the
  // constructor finds them (fix: were left uninitialized if absent).
  float* learning_rate = nullptr;
  float* param = nullptr;
  float* moment1 = nullptr;
  float* moment2 = nullptr;
  float* beta1_pow = nullptr;
  float* beta2_pow = nullptr;

  float beta1;
  float beta2;
  float epsilon;
};
// adam optimizer for dense tensor
// Adam-with-decayed-sums optimizer for dense tensors: an adagrad-style
// scale computed from decayed g2sum/d2sum, combined with a momentum term.
class DAdamD2Sum : public DenseOptimizer {
 public:
  explicit DAdamD2Sum(const CommonAccessorParameter& accessor,
                      std::vector<std::vector<float>>* values) {
    lr_hardcode = 5e-6;
    auto& names = accessor.params();
    for (int x = 0; x < static_cast<int>(names.size()); ++x) {
      if (names[x] == "LearningRate") {
        learning_rate = (*values)[x].data();
      } else if (names[x] == "Param") {
        param = (*values)[x].data();
      } else if (names[x] == "Moment") {
        mom_velocity = (*values)[x].data();
      } else if (names[x] == "G2Sum") {
        ada_g2sum = (*values)[x].data();
      } else if (names[x] == "D2Sum") {
        ada_d2sum = (*values)[x].data();
      } else if (names[x] == "MomentDecayRate") {
        mom_decay_rate = (*values)[x].data();
      } else if (names[x] == "AdaDecayRate") {
        ada_decay_rate = (*values)[x].data();
      } else if (names[x] == "AdaEpsilon") {
        ada_epsilon = (*values)[x].data();
      }
    }
  }

  // One update step over param[begin, end).
  void Update(const float* update_values,
              size_t num,
              int begin,
              int end) override {
    auto update_numel = end - begin;
    Eigen::Map<Eigen::MatrixXf> mat_ada_g2sum(
        ada_g2sum + begin, 1, update_numel);
    Eigen::Map<Eigen::MatrixXf> mat_ada_d2sum(
        ada_d2sum + begin, 1, update_numel);
    Eigen::Map<Eigen::MatrixXf> mat_mom_velocity(
        mom_velocity + begin, 1, update_numel);
    Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
    Eigen::Map<const Eigen::MatrixXf> mat_grad(
        update_values + begin, 1, update_numel);

    // decayed counters: d2sum = decay * d2sum + 1; g2sum = decay * g2sum + g^2
    mat_ada_d2sum = (mat_ada_d2sum * ada_decay_rate[0]).array() + 1;
    mat_ada_g2sum =
        (mat_ada_g2sum * ada_decay_rate[0]) + mat_grad.cwiseProduct(mat_grad);

    // scale = sqrt((d2sum + eps*d2sum) / (g2sum + eps*d2sum))
    thread_local std::vector<float> scale_vec;
    scale_vec.resize(update_numel);
    Eigen::Map<Eigen::MatrixXf> scale(scale_vec.data(), 1, update_numel);
    memcpy(
        scale_vec.data(), mat_ada_d2sum.data(), sizeof(float) * update_numel);
    scale = scale.array() * ada_epsilon[0];
    scale = (mat_ada_d2sum + scale).cwiseQuotient(mat_ada_g2sum + scale);
    scale = scale.cwiseSqrt();

    // momentum update, then apply the scaled step
    mat_mom_velocity =
        (mat_mom_velocity - mat_grad) * mom_decay_rate[0] + mat_grad;
    mat_w -= learning_rate[0] * mat_mom_velocity.cwiseProduct(scale);
  }

  // Non-owning views into the table's value columns; nullptr until the
  // constructor finds them (fix: were left uninitialized if absent).
  float* learning_rate = nullptr;
  float lr_hardcode;  // set in ctor; not used by Update
  float* param = nullptr;
  float* mom_velocity = nullptr;
  float* ada_g2sum = nullptr;
  float* ada_d2sum = nullptr;
  float* mom_decay_rate = nullptr;
  float* ada_decay_rate = nullptr;
  float* ada_epsilon = nullptr;
};
// for data_norm
class DSummary : public DenseOptimizer {
public:
explicit DSummary(const CommonAccessorParameter& accessor,
std::vector<std::vector<float>>* values) {
auto& names = accessor.params();
for (int x = 0; x < static_cast<int>(names.size()); ++x) {
if (names[x] == "Param") {
param = (*values)[x].data();
} else if (names[x] == "SummaryDecayRate") {
summary_decay_rate = (*values)[x].data();
}
}
}
void Update(const float* update_values,
size_t num,
int begin,
int end) override {
auto update_numel = end - begin;
Eigen::Map<Eigen::MatrixXf> mat_w(param + begin, 1, update_numel);
Eigen::Map<const Eigen::MatrixXf> mat_grad(
update_values + begin, 1, update_numel);
mat_w = mat_w * summary_decay_rate_d + mat_grad;
}
float* summary_decay_rate;
double summary_decay_rate_d = 0.999999;
float* param;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mct/hash-map.hpp>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/chunk_allocator.h"
namespace paddle {
namespace distributed {
// Each SparseTableShard spreads its keys over 2^6 = 64 hash-map buckets;
// the bucket index is taken from the top BITS bits of the key's hash
// (see SparseTableShard::compute_bucket below).
static const int CTR_SPARSE_SHARD_BUCKET_NUM_BITS = 6;
static const size_t CTR_SPARSE_SHARD_BUCKET_NUM =
    static_cast<size_t>(1) << CTR_SPARSE_SHARD_BUCKET_NUM_BITS;
// A variable-length float record stored per sparse-feature key; a thin
// wrapper over std::vector<float> exposing raw-pointer access.
class FixedFeatureValue {
 public:
  FixedFeatureValue() = default;
  ~FixedFeatureValue() = default;

  // Raw pointer to the first element (invalidated by resize()).
  float* data() { return buffer_.data(); }

  // Number of floats currently held.
  size_t size() { return buffer_.size(); }

  // Grows or shrinks the record to `size` floats.
  void resize(size_t size) { buffer_.resize(size); }

  // Releases excess capacity back to the allocator.
  void shrink_to_fit() { buffer_.shrink_to_fit(); }

 private:
  std::vector<float> buffer_;
};
// One shard of a sparse KEY -> VALUE table. Entries are spread over
// CTR_SPARSE_SHARD_BUCKET_NUM closed-hash buckets chosen by the high bits
// of the hash; VALUEs are acquired from a ChunkAllocator and stored in the
// maps as type-erased mct::Pointer handles. alignas(64) aligns each shard
// to a 64-byte boundary (a typical cache line).
template <class KEY, class VALUE>
struct alignas(64) SparseTableShard {
 public:
  typedef typename mct::closed_hash_map<KEY, mct::Pointer, std::hash<KEY>>
      map_type;

  // Whole-shard iterator: a position inside one bucket plus the bucket
  // index; increment hops to the next non-empty bucket when one runs out.
  struct iterator {
    typename map_type::iterator it;
    size_t bucket;
    map_type* buckets;
    friend bool operator==(const iterator& a, const iterator& b) {
      return a.it == b.it;
    }
    friend bool operator!=(const iterator& a, const iterator& b) {
      return a.it != b.it;
    }
    const KEY& key() const { return it->first; }
    // The mapped value is a type-erased pointer; cast back to VALUE.
    VALUE& value() const { return *(VALUE*)(void*)it->second; }  // NOLINT
    VALUE* value_ptr() const { return (VALUE*)(void*)it->second; }  // NOLINT
    iterator& operator++() {
      ++it;
      // Skip exhausted buckets; stop at the last bucket's end() so the
      // result compares equal to SparseTableShard::end().
      while (it == buckets[bucket].end() &&
             bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
        it = buckets[++bucket].begin();
      }
      return *this;
    }
    iterator operator++(int) {
      iterator ret = *this;
      ++*this;
      return ret;
    }
  };

  // Iterator confined to a single bucket (no bucket hopping).
  struct local_iterator {
    typename map_type::iterator it;
    friend bool operator==(const local_iterator& a, const local_iterator& b) {
      return a.it == b.it;
    }
    friend bool operator!=(const local_iterator& a, const local_iterator& b) {
      return a.it != b.it;
    }
    const KEY& key() const { return it->first; }
    VALUE& value() const { return *(VALUE*)(void*)it->second; }  // NOLINT
    local_iterator& operator++() {
      ++it;
      return *this;
    }
    local_iterator operator++(int) { return {it++}; }
  };

  // Releases every VALUE back to the chunk allocator on destruction.
  ~SparseTableShard() { clear(); }

  bool empty() { return _alloc.size() == 0; }
  // Entry count is tracked by the allocator, not summed over the maps.
  size_t size() { return _alloc.size(); }
  void set_max_load_factor(float x) {
    for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) {
      _buckets[bucket].max_load_factor(x);
    }
  }
  size_t bucket_count() { return CTR_SPARSE_SHARD_BUCKET_NUM; }
  size_t bucket_size(size_t bucket) { return _buckets[bucket].size(); }

  // Frees all VALUEs and empties every bucket.
  void clear() {
    for (size_t bucket = 0; bucket < CTR_SPARSE_SHARD_BUCKET_NUM; bucket++) {
      map_type& data = _buckets[bucket];
      for (auto it = data.begin(); it != data.end(); ++it) {
        _alloc.release((VALUE*)(void*)it->second);  // NOLINT
      }
      data.clear();
    }
  }

  // First entry of the shard: the start of the first non-empty bucket.
  iterator begin() {
    auto it = _buckets[0].begin();
    size_t bucket = 0;
    while (it == _buckets[bucket].end() &&
           bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
      it = _buckets[++bucket].begin();
    }
    return {it, bucket, _buckets};
  }
  // End sentinel: the end() of the last bucket.
  iterator end() {
    return {_buckets[CTR_SPARSE_SHARD_BUCKET_NUM - 1].end(),
            CTR_SPARSE_SHARD_BUCKET_NUM - 1,
            _buckets};
  }
  local_iterator begin(size_t bucket) { return {_buckets[bucket].begin()}; }
  local_iterator end(size_t bucket) { return {_buckets[bucket].end()}; }

  // Hashes once and probes only the owning bucket.
  iterator find(const KEY& key) {
    size_t hash = _hasher(key);
    size_t bucket = compute_bucket(hash);
    auto it = _buckets[bucket].find_with_hash(key, hash);
    if (it == _buckets[bucket].end()) {
      return end();
    }
    return {it, bucket, _buckets};
  }

  // Default-constructs the VALUE when the key is absent.
  VALUE& operator[](const KEY& key) { return emplace(key).first.value(); }
  std::pair<iterator, bool> insert(const KEY& key, const VALUE& val) {
    return emplace(key, val);
  }
  std::pair<iterator, bool> insert(const KEY& key, VALUE&& val) {
    return emplace(key, std::move(val));
  }

  // Inserts the key if absent; only then is a VALUE acquired from the
  // allocator and constructed from args. Returns {iterator, inserted}.
  template <class... ARGS>
  std::pair<iterator, bool> emplace(const KEY& key, ARGS&&... args) {
    size_t hash = _hasher(key);
    size_t bucket = compute_bucket(hash);
    auto res = _buckets[bucket].insert_with_hash({key, NULL}, hash);
    if (res.second) {
      res.first->second = _alloc.acquire(std::forward<ARGS>(args)...);
    }
    return {{res.first, bucket, _buckets}, res.second};
  }

  // Erases and returns an iterator to the next entry, possibly in a later
  // bucket (mirrors iterator::operator++).
  iterator erase(iterator it) {
    _alloc.release((VALUE*)(void*)it.it->second);  // NOLINT
    size_t bucket = it.bucket;
    auto it2 = _buckets[bucket].erase(it.it);
    while (it2 == _buckets[bucket].end() &&
           bucket + 1 < CTR_SPARSE_SHARD_BUCKET_NUM) {
      it2 = _buckets[++bucket].begin();
    }
    return {it2, bucket, _buckets};
  }
  // Erases without computing a successor iterator (cheaper).
  void quick_erase(iterator it) {
    _alloc.release((VALUE*)(void*)it.it->second);  // NOLINT
    _buckets[it.bucket].quick_erase(it.it);
  }
  local_iterator erase(size_t bucket, local_iterator it) {
    _alloc.release((VALUE*)(void*)it.it->second);  // NOLINT
    return {_buckets[bucket].erase(it.it)};
  }
  void quick_erase(size_t bucket, local_iterator it) {
    _alloc.release((VALUE*)(void*)it.it->second);  // NOLINT
    _buckets[bucket].quick_erase(it.it);
  }
  // Erase by key; returns the number of entries removed (0 or 1).
  size_t erase(const KEY& key) {
    auto it = find(key);
    if (it == end()) {
      return 0;
    }
    quick_erase(it);
    return 1;
  }

  // Bucket index from the TOP hash bits, so the choice stays independent
  // of the low bits used inside the per-bucket hash map.
  size_t compute_bucket(size_t hash) {
    if (CTR_SPARSE_SHARD_BUCKET_NUM == 1) {
      return 0;
    } else {
      return hash >> (sizeof(size_t) * 8 - CTR_SPARSE_SHARD_BUCKET_NUM_BITS);
    }
  }

 private:
  map_type _buckets[CTR_SPARSE_SHARD_BUCKET_NUM];
  ChunkAllocator<VALUE> _alloc;
  std::hash<KEY> _hasher;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <future> // NOLINT
#include <memory>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace distributed {
// A set of row ids whose mutations are serialized through a single-thread
// pool, so multiple callers can enqueue work without explicit locking.
class ConcurrentSet {
 public:
  ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
  ~ConcurrentSet() {}

  // Asynchronously inserts `rows` into the set. `rows` is captured by
  // copy, so the caller need not keep its vector alive until the task runs.
  std::future<void> Update(const std::vector<uint64_t>& rows) {
    auto task = [this, rows] {
      for (auto row : rows) {
        set_.insert(row);
      }
    };
    return pool_->enqueue(std::move(task));
  }

  // Asynchronously drains the set into *result (unspecified order).
  // Fix: capture the `result` pointer by value. The original captured the
  // parameter by reference ([this, &result]); the reference dangles if the
  // returned future is not waited on before this function returns.
  std::future<void> GetAndClear(std::vector<uint64_t>* result) {
    auto task = [this, result] {
      result->clear();
      for (auto& id : set_) {
        result->push_back(id);
      }
      set_.clear();
    };
    return pool_->enqueue(std::move(task));
  }

 private:
  std::unordered_set<uint64_t> set_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};
};
// Records, per trainer, which sparse rows were updated since that trainer
// last pulled them (geo-SGD communication pattern).
class GeoRecorder {
 public:
  // Creates one pending-row set per trainer.
  explicit GeoRecorder(int trainer_num) : trainer_num_(trainer_num) {
    trainer_rows_.reserve(trainer_num);
    for (int i = 0; i < trainer_num; ++i) {
      trainer_rows_.emplace_back(new ConcurrentSet());
    }
  }

  ~GeoRecorder() = default;

  // Broadcasts the freshly updated rows into every trainer's pending set,
  // blocking until all per-trainer insertions finish.
  void Update(const std::vector<uint64_t>& update_rows) {
    VLOG(3) << " row size: " << update_rows.size();
    std::vector<std::future<void>> pending;
    for (auto& row_set : trainer_rows_) {
      pending.push_back(row_set->Update(update_rows));
    }
    for (auto& fut : pending) {
      fut.wait();
    }
  }

  // Moves trainer `trainer_id`'s pending rows into *result and clears them.
  void GetAndClear(uint32_t trainer_id, std::vector<uint64_t>* result) {
    VLOG(3) << "GetAndClear for trainer: " << trainer_id;
    trainer_rows_.at(trainer_id)->GetAndClear(result).wait();
  }

 private:
  const int trainer_num_;
  std::vector<std::unique_ptr<ConcurrentSet>> trainer_rows_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle {
namespace distributed {
// Abstract base for parameter initializers. Subclasses implement the
// scalar GetValue(); the bulk overloads below draw from it repeatedly.
class Initializer {
 public:
  Initializer() {}

  // attrs carry subclass-specific configuration; the base ignores them.
  explicit Initializer(const std::vector<std::string> &attrs) {}

  // Draws a single value.
  virtual float GetValue() = 0;

  // Appends `numel` freshly drawn values to *values.
  virtual void GetValue(std::vector<float> *values, int numel) {
    int drawn = 0;
    while (drawn < numel) {
      values->push_back(GetValue());
      ++drawn;
    }
  }

  // Fills value[0..numel) with freshly drawn values.
  virtual void GetValue(float *value, int numel) {
    int drawn = 0;
    while (drawn < numel) {
      value[drawn] = GetValue();
      ++drawn;
    }
  }

  virtual ~Initializer() {}

 protected:
  std::string name_;
  unsigned int seed_;
};
class UniformInitializer : public Initializer {
public:
explicit UniformInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
min_ = std::stof(attrs[2]);
max_ = std::stof(attrs[3]);
dist_ = std::uniform_real_distribution<float>(min_, max_);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override { return dist_(*random_engine_); }
void GetValue(float *value, int numel) {
for (int x = 0; x < numel; ++x) {
value[x] = dist_(*random_engine_);
}
}
private:
float min_;
float max_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
class GaussianInitializer : public Initializer {
public:
explicit GaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
random_engine_ = framework::GetCPURandomEngine(seed_);
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override { return dist_(*random_engine_); }
void GetValue(float *value, int numel) {
for (int x = 0; x < numel; ++x) {
value[x] = dist_(*random_engine_);
}
}
private:
float std_;
float mean_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::normal_distribution<float> dist_;
};
class TruncatedGaussianInitializer : public Initializer {
public:
explicit TruncatedGaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
std::uniform_real_distribution<float> dist_(
std::numeric_limits<float>::min(), 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
float value = truncated_normal(dist_(*random_engine_));
return value;
}
void GetValue(float *value, int numel) {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
for (int x = 0; x < numel; ++x) {
value[x] = truncated_normal(dist_(*random_engine_));
}
}
private:
float std_;
float mean_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
class FillConstantInitializer : public Initializer {
public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
value_ = std::stof(attrs[1]);
}
float GetValue() override { return value_; }
void GetValue(float *value, int numel) { std::fill_n(value, numel, value_); }
private:
float value_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <rocksdb/db.h>
#include <rocksdb/filter_policy.h>
#include <rocksdb/options.h>
#include <rocksdb/slice.h>
#include <rocksdb/table.h>
#include <rocksdb/write_batch.h>
#include <iostream>
#include <string>
namespace paddle {
namespace distributed {
// Process-wide RocksDB wrapper used as on-disk storage for sparse tables.
// One column family ("shard_<i>") is created per table shard; the `id`
// argument of the data methods selects the column family.
// NOTE(review): error handling is assert()-only throughout, which is
// compiled out under NDEBUG -- failures would then be silently ignored.
class RocksDBHandler {
 public:
  RocksDBHandler() {}
  // NOTE(review): _db and the column-family handles are never closed or
  // deleted; tolerable for a process-lifetime singleton, but no clean
  // shutdown/flush happens at exit.
  ~RocksDBHandler() {}

  // Meyers-style singleton accessor.
  static RocksDBHandler* GetInstance() {
    static RocksDBHandler handler;
    return &handler;
  }

  // Opens a fresh DB at db_path with `colnum` column families.
  // WARNING: any existing data under db_path is wiped with `rm -rf` first.
  int initialize(const std::string& db_path, const int colnum) {
    VLOG(3) << "db path: " << db_path << " colnum: " << colnum;
    rocksdb::Options options;
    rocksdb::BlockBasedTableOptions bbto;
    bbto.block_size = 4 * 1024;
    bbto.block_cache = rocksdb::NewLRUCache(64 * 1024 * 1024);
    bbto.block_cache_compressed = rocksdb::NewLRUCache(64 * 1024 * 1024);
    bbto.cache_index_and_filter_blocks = false;
    bbto.filter_policy.reset(rocksdb::NewBloomFilterPolicy(20, false));
    bbto.whole_key_filtering = true;
    options.table_factory.reset(rocksdb::NewBlockBasedTableFactory(bbto));
    options.keep_log_file_num = 100;
    options.max_log_file_size = 50 * 1024 * 1024;  // 50MB
    options.create_if_missing = true;
    options.use_direct_reads = true;
    options.max_background_flushes = 5;
    options.max_background_compactions = 5;
    options.base_background_compactions = 10;
    options.write_buffer_size = 256 * 1024 * 1024;  // 256MB
    options.max_write_buffer_number = 8;
    options.max_bytes_for_level_base =
        options.max_write_buffer_number * options.write_buffer_size;
    options.min_write_buffer_number_to_merge = 1;
    options.target_file_size_base = 1024 * 1024 * 1024;  // 1024MB
    options.memtable_prefix_bloom_size_ratio = 0.02;
    options.num_levels = 4;
    options.max_open_files = -1;
    // No compression: trades disk space for CPU on this workload.
    options.compression = rocksdb::kNoCompression;
    options.level0_file_num_compaction_trigger = 8;
    options.level0_slowdown_writes_trigger =
        1.8 * options.level0_file_num_compaction_trigger;
    options.level0_stop_writes_trigger =
        3.6 * options.level0_file_num_compaction_trigger;
    // Remove any previous DB so every run starts from an empty store.
    // NOTE(review): shells out to rm -rf; return value is not checked.
    if (!db_path.empty()) {
      std::string rm_cmd = "rm -rf " + db_path;
      system(rm_cmd.c_str());
    }
    rocksdb::Status s = rocksdb::DB::Open(options, db_path, &_db);
    assert(s.ok());
    _handles.resize(colnum);
    for (int i = 0; i < colnum; i++) {
      s = _db->CreateColumnFamily(
          options, "shard_" + std::to_string(i), &_handles[i]);
      assert(s.ok());
    }
    LOG(INFO) << "DB initialize success, colnum:" << colnum;
    return 0;
  }

  // Writes one key/value pair into column family `id`.
  // WAL is disabled: faster writes, but unflushed data is lost on crash.
  int put(
      int id, const char* key, int key_len, const char* value, int value_len) {
    rocksdb::WriteOptions options;
    options.disableWAL = true;
    rocksdb::Status s = _db->Put(options,
                                 _handles[id],
                                 rocksdb::Slice(key, key_len),
                                 rocksdb::Slice(value, value_len));
    assert(s.ok());
    return 0;
  }

  // Writes the first n (key, value) pairs as one atomic WriteBatch
  // (WAL disabled, as in put()).
  int put_batch(int id,
                std::vector<std::pair<char*, int>>& ssd_keys,
                std::vector<std::pair<char*, int>>& ssd_values,
                int n) {
    rocksdb::WriteOptions options;
    options.disableWAL = true;
    rocksdb::WriteBatch batch(n * 128);  // pre-sized batch buffer
    for (int i = 0; i < n; i++) {
      batch.Put(_handles[id],
                rocksdb::Slice(ssd_keys[i].first, ssd_keys[i].second),
                rocksdb::Slice(ssd_values[i].first, ssd_values[i].second));
    }
    rocksdb::Status s = _db->Write(options, &batch);
    assert(s.ok());
    return 0;
  }

  // Looks up one key in column family `id`.
  // Returns 1 if the key is absent, 0 if found (value copied into `value`).
  int get(int id, const char* key, int key_len, std::string& value) {
    rocksdb::Status s = _db->Get(rocksdb::ReadOptions(),
                                 _handles[id],
                                 rocksdb::Slice(key, key_len),
                                 &value);
    if (s.IsNotFound()) {
      return 1;
    }
    assert(s.ok());
    return 0;
  }

  // Deletes one key from column family `id` (WAL disabled).
  int del_data(int id, const char* key, int key_len) {
    rocksdb::WriteOptions options;
    options.disableWAL = true;
    rocksdb::Status s =
        _db->Delete(options, _handles[id], rocksdb::Slice(key, key_len));
    assert(s.ok());
    return 0;
  }

  // Forces the memtable of column family `id` to disk.
  int flush(int id) {
    rocksdb::Status s = _db->Flush(rocksdb::FlushOptions(), _handles[id]);
    assert(s.ok());
    return 0;
  }

  // Caller owns the returned iterator and must delete it.
  rocksdb::Iterator* get_iterator(int id) {
    return _db->NewIterator(rocksdb::ReadOptions(), _handles[id]);
  }

  // Writes RocksDB's estimated total key count into num_keys.
  int get_estimate_key_num(uint64_t& num_keys) {
    _db->GetAggregatedIntProperty("rocksdb.estimate-num-keys", &num_keys);
    return 0;
  }

 private:
  std::vector<rocksdb::ColumnFamilyHandle*> _handles;
  rocksdb::DB* _db;
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment