Commit 12257feb authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

enable network interface into c_api (#986)

* add network apis.

* support parallel loading dataset in c api.

* fix bug

* fix bug
parent 491dd019
...@@ -735,6 +735,26 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle, ...@@ -735,6 +735,26 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int importance_type, int importance_type,
double* out_results); double* out_results);
/*!
* \brief Initialize the network
* \param machines represent the nodes, format: ip1:port1,ip2:port2
* \param local_listen_port value stored into NetworkConfig::local_listen_port
* \param listen_time_out value stored into NetworkConfig::time_out (that field is commented as being in minutes)
* \param num_machines value stored into NetworkConfig::num_machines; the network is only set up when it is greater than 1
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
int local_listen_port,
int listen_time_out,
int num_machines);
/*!
* \brief Finalize the network
* \note The implementation forwards to Network::Dispose()
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
// exception handle and error msg // exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
......
...@@ -271,6 +271,7 @@ public: ...@@ -271,6 +271,7 @@ public:
int local_listen_port = 12400; int local_listen_port = 12400;
int time_out = 120; // in minutes int time_out = 120; // in minutes
std::string machine_list_filename = ""; std::string machine_list_filename = "";
std::string machines = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
...@@ -438,7 +439,9 @@ struct ParameterAlias { ...@@ -438,7 +439,9 @@ struct ParameterAlias {
{ "num_classes", "num_class" }, { "num_classes", "num_class" },
{ "unbalanced_sets", "is_unbalance" }, { "unbalanced_sets", "is_unbalance" },
{ "bagging_fraction_seed", "bagging_seed" }, { "bagging_fraction_seed", "bagging_seed" },
{ "num_boost_round", "num_iterations" } { "num_boost_round", "num_iterations" },
{ "workers", "machines" },
{ "nodes", "machines" },
}); });
const std::unordered_set<std::string> parameter_set({ const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device", "config", "config_file", "task", "device",
...@@ -468,7 +471,7 @@ struct ParameterAlias { ...@@ -468,7 +471,7 @@ struct ParameterAlias {
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames", "feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file", "snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta", "max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib", "zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "max_cat_group", "cat_smooth_ratio", "min_cat_smooth", "max_cat_smooth", "min_data_per_group" "max_cat_threshold", "max_cat_group", "cat_smooth_ratio", "min_cat_smooth", "max_cat_smooth", "min_data_per_group"
}); });
......
...@@ -27,7 +27,7 @@ inline char tolower(char in) { ...@@ -27,7 +27,7 @@ inline char tolower(char in) {
return in; return in;
} }
inline static std::string& Trim(std::string& str) { inline static std::string Trim(std::string str) {
if (str.empty()) { if (str.empty()) {
return str; return str;
} }
...@@ -36,7 +36,7 @@ inline static std::string& Trim(std::string& str) { ...@@ -36,7 +36,7 @@ inline static std::string& Trim(std::string& str) {
return str; return str;
} }
inline static std::string& RemoveQuotationSymbol(std::string& str) { inline static std::string RemoveQuotationSymbol(std::string str) {
if (str.empty()) { if (str.empty()) {
return str; return str;
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <LightGBM/metric.h> #include <LightGBM/metric.h>
#include <LightGBM/config.h> #include <LightGBM/config.h>
#include <LightGBM/prediction_early_stop.h> #include <LightGBM/prediction_early_stop.h>
#include <LightGBM/network.h>
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
...@@ -54,6 +55,13 @@ public: ...@@ -54,6 +55,13 @@ public:
train_data_ = train_data; train_data_ = train_data;
CreateObjectiveAndMetrics(); CreateObjectiveAndMetrics();
// initialize the boosting // initialize the boosting
if (config_.boosting_config.tree_learner_type == std::string("feature")) {
Log::Fatal("Do not support feature parallel in c api.");
}
if (Network::num_machines() == 1) {
Log::Warning("Only find one worker, will switch to serial tree learner.");
config_.boosting_config.tree_learner_type = "serial";
}
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(), boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
...@@ -361,7 +369,11 @@ int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -361,7 +369,11 @@ int LGBM_DatasetCreateFromFile(const char* filename,
} }
DatasetLoader loader(config.io_config,nullptr, 1, filename); DatasetLoader loader(config.io_config,nullptr, 1, filename);
if (reference == nullptr) { if (reference == nullptr) {
*out = loader.LoadFromFile(filename, ""); if (Network::num_machines() == 1) {
*out = loader.LoadFromFile(filename, "");
} else {
*out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
}
} else { } else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename, "", *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
...@@ -1194,6 +1206,28 @@ int LGBM_BoosterFeatureImportance(BoosterHandle handle, ...@@ -1194,6 +1206,28 @@ int LGBM_BoosterFeatureImportance(BoosterHandle handle,
API_END(); API_END();
} }
/*!
* \brief C API entry point that builds a NetworkConfig from the arguments and starts the network
* \param machines comma separated node list, format: ip1:port1,ip2:port2; a null pointer is treated as an empty list
* \param local_listen_port copied into NetworkConfig::local_listen_port
* \param listen_time_out copied into NetworkConfig::time_out
* \param num_machines copied into NetworkConfig::num_machines; Network::Init is only called when it is greater than 1
* \return 0 when succeed, -1 when failure happens
*/
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
  NetworkConfig config;
  // Constructing std::string from a null pointer is undefined behavior,
  // so a null `machines` argument is mapped to the empty string instead.
  const char* machine_list = (machines != nullptr) ? machines : "";
  config.machines = Common::RemoveQuotationSymbol(std::string(machine_list));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  // A single machine needs no network; leave Network in its default state.
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}
// C API entry point that finalizes the network by forwarding to Network::Dispose().
// The call is wrapped in API_BEGIN()/API_END(), like the other C API functions here.
int LGBM_NetworkFree() {
API_BEGIN();
Network::Dispose();
API_END();
}
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
...@@ -435,6 +435,7 @@ void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -435,6 +435,7 @@ void NetworkConfig::Set(const std::unordered_map<std::string, std::string>& para
GetInt(params, "time_out", &time_out); GetInt(params, "time_out", &time_out);
CHECK(time_out > 0); CHECK(time_out > 0);
GetString(params, "machine_list_file", &machine_list_filename); GetString(params, "machine_list_file", &machine_list_filename);
GetString(params, "machines", &machines);
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -494,24 +494,107 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -494,24 +494,107 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
const data_size_t filter_cnt = static_cast<data_size_t>( const data_size_t filter_cnt = static_cast<data_size_t>(
static_cast<double>(io_config_.min_data_in_leaf * total_sample_size) / num_data); static_cast<double>(io_config_.min_data_in_leaf * total_sample_size) / num_data);
OMP_INIT_EX(); if (Network::num_machines() == 1) {
#pragma omp parallel for schedule(guided) // if only one machine, find bin locally
for (int i = 0; i < num_col; ++i) { OMP_INIT_EX();
OMP_LOOP_EX_BEGIN(); #pragma omp parallel for schedule(guided)
if (ignore_features_.count(i) > 0) { for (int i = 0; i < num_col; ++i) {
bin_mappers[i] = nullptr; OMP_LOOP_EX_BEGIN();
continue; if (ignore_features_.count(i) > 0) {
} bin_mappers[i] = nullptr;
BinType bin_type = BinType::NumericalBin; continue;
if (categorical_features_.count(i)) { }
bin_type = BinType::CategoricalBin; BinType bin_type = BinType::NumericalBin;
} if (categorical_features_.count(i)) {
bin_mappers[i].reset(new BinMapper()); bin_type = BinType::CategoricalBin;
bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, }
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing); bin_mappers[i].reset(new BinMapper());
OMP_LOOP_EX_END(); bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
} io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing);
OMP_THROW_EX(); OMP_LOOP_EX_END();
}
OMP_THROW_EX();
} else {
// if have multi-machines, need to find bin distributed
// different machines will find bin for different features
int num_machines = Network::num_machines();
int rank = Network::rank();
int total_num_feature = num_col;
total_num_feature = Network::GlobalSyncUpByMin(total_num_feature);
// start and len will store the process feature indices for different machines
// machine i will find bins for features in [ start[i], start[i] + len[i] )
std::vector<int> start(num_machines);
std::vector<int> len(num_machines);
int step = (total_num_feature + num_machines - 1) / num_machines;
if (step < 1) { step = 1; }
start[0] = 0;
for (int i = 0; i < num_machines - 1; ++i) {
len[i] = std::min(step, total_num_feature - start[i]);
start[i + 1] = start[i] + len[i];
}
len[num_machines - 1] = total_num_feature - start[num_machines - 1];
OMP_INIT_EX();
#pragma omp parallel for schedule(guided)
for (int i = 0; i < len[rank]; ++i) {
OMP_LOOP_EX_BEGIN();
if (ignore_features_.count(start[rank] + i) > 0) {
continue;
}
BinType bin_type = BinType::NumericalBin;
if (categorical_features_.count(start[rank] + i)) {
bin_type = BinType::CategoricalBin;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size,
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type, io_config_.use_missing, io_config_.zero_as_missing);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
int max_bin = 0;
for (int i = 0; i < len[rank]; ++i) {
if (bin_mappers[i] != nullptr) {
max_bin = std::max(max_bin, bin_mappers[i]->num_bin());
}
}
max_bin = Network::GlobalSyncUpByMax(max_bin);
// get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since sizes of different feature may not be same, we expand all bin mapper to type_size
int buffer_size = type_size * total_num_feature;
auto input_buffer = std::vector<char>(buffer_size);
auto output_buffer = std::vector<char>(buffer_size);
// find local feature bins and copy to buffer
#pragma omp parallel for schedule(guided)
for (int i = 0; i < len[rank]; ++i) {
OMP_LOOP_EX_BEGIN();
if (ignore_features_.count(start[rank] + i) > 0) {
continue;
}
bin_mappers[i]->CopyTo(input_buffer.data() + i * type_size);
// free
bin_mappers[i].reset(nullptr);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
// convert to binary size
for (int i = 0; i < num_machines; ++i) {
start[i] *= type_size;
len[i] *= type_size;
}
// gather global feature bin mappers
Network::Allgather(input_buffer.data(), buffer_size, start.data(), len.data(), output_buffer.data());
// restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) {
bin_mappers[i] = nullptr;
continue;
}
bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->CopyFrom(output_buffer.data() + i * type_size);
}
}
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_); dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_);
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
...@@ -715,8 +798,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -715,8 +798,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
OMP_INIT_EX();
// if only one machine, find bin locally // if only one machine, find bin locally
OMP_INIT_EX();
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
......
...@@ -32,6 +32,9 @@ namespace LightGBM { ...@@ -32,6 +32,9 @@ namespace LightGBM {
*/ */
class Linkers { class Linkers {
public: public:
Linkers() {
is_init_ = false;
}
/*! /*!
* \brief Constructor * \brief Constructor
* \param config Config of network settings * \param config Config of network settings
...@@ -106,9 +109,10 @@ public: ...@@ -106,9 +109,10 @@ public:
void Construct(); void Construct();
/*! /*!
* \brief Parser machines information from file * \brief Parser machines information from file
* \param machines
* \param filename * \param filename
*/ */
void ParseMachineList(const char * filename); void ParseMachineList(const std::string& machines, const std::string& filename);
/*! /*!
* \brief Check one linker is connected or not * \brief Check one linker is connected or not
* \param rank * \param rank
...@@ -135,6 +139,8 @@ private: ...@@ -135,6 +139,8 @@ private:
std::chrono::duration<double, std::milli> network_time_; std::chrono::duration<double, std::milli> network_time_;
bool is_init_;
#ifdef USE_SOCKET #ifdef USE_SOCKET
/*! \brief use to store client ips */ /*! \brief use to store client ips */
std::vector<std::string> client_ips_; std::vector<std::string> client_ips_;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
namespace LightGBM { namespace LightGBM {
Linkers::Linkers(NetworkConfig) { Linkers::Linkers(NetworkConfig) {
is_init_ = false;
int argc = 0; int argc = 0;
char**argv = nullptr; char**argv = nullptr;
int flag = 0; int flag = 0;
...@@ -17,10 +18,13 @@ Linkers::Linkers(NetworkConfig) { ...@@ -17,10 +18,13 @@ Linkers::Linkers(NetworkConfig) {
MPI_SAFE_CALL(MPI_Barrier(MPI_COMM_WORLD)); MPI_SAFE_CALL(MPI_Barrier(MPI_COMM_WORLD));
bruck_map_ = BruckMap::Construct(rank_, num_machines_); bruck_map_ = BruckMap::Construct(rank_, num_machines_);
recursive_halving_map_ = RecursiveHalvingMap::Construct(rank_, num_machines_); recursive_halving_map_ = RecursiveHalvingMap::Construct(rank_, num_machines_);
is_init_ = true;
} }
Linkers::~Linkers() { Linkers::~Linkers() {
MPI_SAFE_CALL(MPI_Finalize()); if (is_init_) {
MPI_SAFE_CALL(MPI_Finalize());
}
} }
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
namespace LightGBM { namespace LightGBM {
Linkers::Linkers(NetworkConfig config) { Linkers::Linkers(NetworkConfig config) {
is_init_ = false;
// start up socket // start up socket
TcpSocket::Startup(); TcpSocket::Startup();
network_time_ = std::chrono::duration<double, std::milli>(0); network_time_ = std::chrono::duration<double, std::milli>(0);
...@@ -26,7 +27,7 @@ Linkers::Linkers(NetworkConfig config) { ...@@ -26,7 +27,7 @@ Linkers::Linkers(NetworkConfig config) {
socket_timeout_ = config.time_out; socket_timeout_ = config.time_out;
rank_ = -1; rank_ = -1;
// parse clients from file // parse clients from file
ParseMachineList(config.machine_list_filename.c_str()); ParseMachineList(config.machines, config.machine_list_filename);
if (rank_ == -1) { if (rank_ == -1) {
// get ip list of local machine // get ip list of local machine
...@@ -58,35 +59,46 @@ Linkers::Linkers(NetworkConfig config) { ...@@ -58,35 +59,46 @@ Linkers::Linkers(NetworkConfig config) {
Construct(); Construct();
// free listener // free listener
listener_->Close(); listener_->Close();
is_init_ = true;
} }
Linkers::~Linkers() { Linkers::~Linkers() {
for (size_t i = 0; i < linkers_.size(); ++i) { if (is_init_) {
if (linkers_[i] != nullptr) { for (size_t i = 0; i < linkers_.size(); ++i) {
linkers_[i]->Close(); if (linkers_[i] != nullptr) {
linkers_[i]->Close();
}
} }
TcpSocket::Finalize();
Log::Info("Finished linking network in %f seconds", network_time_ * 1e-3);
} }
TcpSocket::Finalize();
Log::Info("Finished linking network in %f seconds", network_time_ * 1e-3);
} }
void Linkers::ParseMachineList(const char * filename) { void Linkers::ParseMachineList(const std::string& machines, const std::string& filename) {
TextReader<size_t> machine_list_reader(filename, false); std::vector<std::string> lines;
machine_list_reader.ReadAllLines(); if (machines.empty()) {
if (machine_list_reader.Lines().empty()) { TextReader<size_t> machine_list_reader(filename.c_str(), false);
Log::Fatal("Machine list file %s doesn't exist", filename); machine_list_reader.ReadAllLines();
if (machine_list_reader.Lines().empty()) {
Log::Fatal("Machine list file %s doesn't exist", filename.c_str());
}
lines = machine_list_reader.Lines();
} else {
lines = Common::Split(machines.c_str(), ',');
} }
for (auto& line : lines) {
for (auto& line : machine_list_reader.Lines()) {
line = Common::Trim(line); line = Common::Trim(line);
if (line.find_first_of("rank=") != std::string::npos) { if (line.find_first_of("rank=") != std::string::npos) {
std::vector<std::string> str_after_split = Common::Split(line.c_str(), '='); std::vector<std::string> str_after_split = Common::Split(line.c_str(), '=');
Common::Atoi(str_after_split[1].c_str(), &rank_); Common::Atoi(str_after_split[1].c_str(), &rank_);
continue; continue;
} }
std::vector<std::string> str_after_split = Common::Split(line.c_str() , ' '); std::vector<std::string> str_after_split = Common::Split(line.c_str(), ' ');
if (str_after_split.size() != 2) { if (str_after_split.size() != 2) {
continue; str_after_split = Common::Split(line.c_str(), ':');
if (str_after_split.size() != 2) {
continue;
}
} }
if (client_ips_.size() >= static_cast<size_t>(num_machines_)) { if (client_ips_.size() >= static_cast<size_t>(num_machines_)) {
Log::Warning("machine_list size is larger than the parameter num_machines, ignoring redundant entries"); Log::Warning("machine_list size is larger than the parameter num_machines, ignoring redundant entries");
...@@ -98,8 +110,8 @@ void Linkers::ParseMachineList(const char * filename) { ...@@ -98,8 +110,8 @@ void Linkers::ParseMachineList(const char * filename) {
client_ports_.push_back(atoi(str_after_split[1].c_str())); client_ports_.push_back(atoi(str_after_split[1].c_str()));
} }
if (client_ips_.empty()) { if (client_ips_.empty()) {
Log::Fatal("Machine list file doesn't contain any ip and port. \ Log::Fatal("Cannot find any ip and port. \
Please check it again"); Please check machine_list_filename or machines parameter.");
} }
if (client_ips_.size() != static_cast<size_t>(num_machines_)) { if (client_ips_.size() != static_cast<size_t>(num_machines_)) {
Log::Warning("World size is larger than the machine_list size, change world size to %d", client_ips_.size()); Log::Warning("World size is larger than the machine_list size, change world size to %d", client_ips_.size());
......
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
namespace LightGBM { namespace LightGBM {
// static member definition // static member definition
THREAD_LOCAL int Network::num_machines_; THREAD_LOCAL int Network::num_machines_ = 1;
THREAD_LOCAL int Network::rank_; THREAD_LOCAL int Network::rank_ = 0;
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_; THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_; THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_; THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
...@@ -21,23 +21,30 @@ THREAD_LOCAL int Network::buffer_size_; ...@@ -21,23 +21,30 @@ THREAD_LOCAL int Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_; THREAD_LOCAL std::vector<char> Network::buffer_;
void Network::Init(NetworkConfig config) { void Network::Init(NetworkConfig config) {
linkers_.reset(new Linkers(config)); if (config.num_machines > 1) {
rank_ = linkers_->rank(); linkers_.reset(new Linkers(config));
num_machines_ = linkers_->num_machines(); rank_ = linkers_->rank();
bruck_map_ = linkers_->bruck_map(); num_machines_ = linkers_->num_machines();
recursive_halving_map_ = linkers_->recursive_halving_map(); bruck_map_ = linkers_->bruck_map();
block_start_ = std::vector<int>(num_machines_); recursive_halving_map_ = linkers_->recursive_halving_map();
block_len_ = std::vector<int>(num_machines_); block_start_ = std::vector<int>(num_machines_);
buffer_size_ = 1024 * 1024; block_len_ = std::vector<int>(num_machines_);
buffer_.resize(buffer_size_); buffer_size_ = 1024 * 1024;
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_); buffer_.resize(buffer_size_);
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
}
} }
void Network::Dispose() { void Network::Dispose() {
num_machines_ = 1;
rank_ = 0;
linkers_.reset(new Linkers());
} }
void Network::Allreduce(char* input, int input_size, int type_size, char* output, const ReduceFunction& reducer) { void Network::Allreduce(char* input, int input_size, int type_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
int count = input_size / type_size; int count = input_size / type_size;
// if small package or small count , do it by all gather.(reduce the communication times.) // if small package or small count , do it by all gather.(reduce the communication times.)
if (count < num_machines_ || input_size < 4096) { if (count < num_machines_ || input_size < 4096) {
...@@ -62,6 +69,9 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output ...@@ -62,6 +69,9 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
} }
void Network::AllreduceByAllGather(char* input, int input_size, char* output, const ReduceFunction& reducer) { void Network::AllreduceByAllGather(char* input, int input_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
// assign blocks // assign blocks
int all_size = input_size * num_machines_; int all_size = input_size * num_machines_;
block_start_[0] = 0; block_start_[0] = 0;
...@@ -85,6 +95,10 @@ void Network::AllreduceByAllGather(char* input, int input_size, char* output, co ...@@ -85,6 +95,10 @@ void Network::AllreduceByAllGather(char* input, int input_size, char* output, co
} }
void Network::Allgather(char* input, int send_size, char* output) { void Network::Allgather(char* input, int send_size, char* output) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (num_machines_ <= 1) { return; }
// assign blocks // assign blocks
block_start_[0] = 0; block_start_[0] = 0;
block_len_[0] = send_size; block_len_[0] = send_size;
...@@ -97,6 +111,9 @@ void Network::Allgather(char* input, int send_size, char* output) { ...@@ -97,6 +111,9 @@ void Network::Allgather(char* input, int send_size, char* output) {
} }
void Network::Allgather(char* input, int all_size, const int* block_start, const int* block_len, char* output) { void Network::Allgather(char* input, int all_size, const int* block_start, const int* block_len, char* output) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
int write_pos = 0; int write_pos = 0;
// use output as receive buffer // use output as receive buffer
std::memcpy(output, input, block_len[rank_]); std::memcpy(output, input, block_len[rank_]);
...@@ -129,6 +146,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -129,6 +146,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
} }
void Network::ReduceScatter(char* input, int, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) { void Network::ReduceScatter(char* input, int, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (recursive_halving_map_.need_pairwise) { if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_; int out_rank = (rank_ + i) % num_machines_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment