Commit 51287f07 authored by Guolin Ke

refine network interface

parent 72b54956
@@ -35,9 +35,8 @@ public:
/*!
* \brief Sum up (reducers) functions for histogram bin
*/
- inline static void SumReducer(const char *src, char *dst, int len) {
-   const int type_size = sizeof(HistogramBinEntry);
-   int used_size = 0;
+ inline static void SumReducer(const char *src, char *dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const HistogramBinEntry* p1;
HistogramBinEntry* p2;
while (used_size < len) {
......
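The hunk above changes the reducer contract: the element size now arrives as the type_size argument instead of being hard-coded inside each reducer. Below is a minimal, self-contained sketch of a reducer matching the new four-argument signature; the BinEntry struct and main() are illustrative stand-ins, not part of the commit.

#include <cstdint>
#include <cstdio>

typedef int32_t comm_size_t;

// Illustrative stand-in for LightGBM's HistogramBinEntry.
struct BinEntry {
  double sum_gradients;
  double sum_hessians;
  int32_t cnt;
};

// Matches the refined signature: type_size is an argument, so one reducer
// shape fits every element type.
static void SumReducer(const char* src, char* dst, int type_size, comm_size_t len) {
  comm_size_t used_size = 0;
  while (used_size < len) {
    const BinEntry* p1 = reinterpret_cast<const BinEntry*>(src);
    BinEntry* p2 = reinterpret_cast<BinEntry*>(dst);
    p2->sum_gradients += p1->sum_gradients;
    p2->sum_hessians += p1->sum_hessians;
    p2->cnt += p1->cnt;
    src += type_size;
    dst += type_size;
    used_size += type_size;
  }
}

int main() {
  BinEntry a[2] = {{1.0, 2.0, 3}, {4.0, 5.0, 6}};
  BinEntry b[2] = {{10.0, 20.0, 30}, {40.0, 50.0, 60}};
  SumReducer(reinterpret_cast<const char*>(a), reinterpret_cast<char*>(b),
             static_cast<int>(sizeof(BinEntry)), static_cast<comm_size_t>(sizeof(a)));
  std::printf("%.1f %.1f %d\n", b[0].sum_gradients, b[0].sum_hessians,
              static_cast<int>(b[0].cnt));  // 11.0 22.0 33
  return 0;
}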
@@ -757,8 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
- LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
-                                                     void* reduce_scatter_fun_ptr,
+ LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank);
......
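With allreduce_fun_ptr gone, external frameworks register only a reduce-scatter and an all-gather implementation; allreduce is composed from them internally. A hedged sketch of such a registration, assuming the process links against LightGBM's C API; the stub bodies and the MyReduceScatter/MyAllgather names are placeholders, not a working transport.

#include <cstdint>

typedef int32_t comm_size_t;
typedef void (*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size);

// Shapes mirror the new ReduceScatterFunction / AllgatherFunction typedefs.
static void MyReduceScatter(char* input, comm_size_t input_size, int type_size,
                            const comm_size_t* block_start, const comm_size_t* block_len,
                            int num_block, char* output, comm_size_t output_size,
                            const ReduceFunction& reducer) {
  // placeholder: forward to the host framework's reduce-scatter here
}

static void MyAllgather(char* input, comm_size_t input_size, const comm_size_t* block_start,
                        const comm_size_t* block_len, int num_block,
                        char* output, comm_size_t output_size) {
  // placeholder: forward to the host framework's all-gather here
}

// Declaration as exported by LightGBM's C API after this commit.
extern "C" int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr, void* allgather_fun_ptr,
                                             int num_machines, int rank);

int InitDistributedTraining(int num_machines, int rank) {
  // Passing function pointers through void* is what the C surface expects here.
  return LGBM_NetworkInitWithFunctions(reinterpret_cast<void*>(&MyReduceScatter),
                                       reinterpret_cast<void*>(&MyAllgather),
                                       num_machines, rank);
}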
@@ -21,16 +21,22 @@ const score_t kEpsilon = 1e-15f;
const double kZeroThreshold = 1e-35f;
- using ReduceFunction = std::function<void(const char*, char*, int)>;
+ typedef int32_t comm_size_t;
using PredictFunction =
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;
- using AllreduceFunction = std::function<void(char*, int, int, char*, const ReduceFunction&)>;
+ typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size);
+ typedef void(*ReduceScatterFunction)(char* input, comm_size_t input_size, int type_size,
+   const comm_size_t* block_start, const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size,
+   const ReduceFunction& reducer);
- using ReduceScatterFunction = std::function<void(char*, int, const int*, const int*, char*, const ReduceFunction&)>;
+ typedef void(*AllgatherFunction)(char* input, comm_size_t input_size, const comm_size_t* block_start,
+   const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size);
- using AllgatherFunction = std::function<void(char*, int, const int*, const int*, char*)>;
#define NO_SPECIFIC (-1)
......
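The std::function aliases become raw function-pointer typedefs here, which is what lets the C API cast a void* from the caller straight to the target type. A small illustration of that round-trip; the Sum reducer is a toy, not the library's.

#include <cstdint>
#include <cstdio>

typedef int32_t comm_size_t;
typedef void (*ReduceFunction)(const char*, char*, int, comm_size_t);

// Toy reducer: elementwise sum of doubles.
static void Sum(const char* src, char* dst, int type_size, comm_size_t len) {
  for (comm_size_t used = 0; used < len; used += type_size)
    *reinterpret_cast<double*>(dst + used) += *reinterpret_cast<const double*>(src + used);
}

int main() {
  // A raw typedef round-trips through void* with one cast, which is exactly
  // what a C caller hands to LGBM_NetworkInitWithFunctions; a std::function
  // cannot be reconstituted from a void* like this.
  void* erased = reinterpret_cast<void*>(&Sum);
  ReduceFunction f = reinterpret_cast<ReduceFunction>(erased);
  double a = 1.5, b = 2.5;
  f(reinterpret_cast<const char*>(&a), reinterpret_cast<char*>(&b),
    static_cast<int>(sizeof(double)), static_cast<comm_size_t>(sizeof(double)));
  std::printf("%.1f\n", b);  // 4.0
  return 0;
}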
@@ -89,17 +89,18 @@ public:
* \param output Output result
* \param reducer Reduce function
*/
- static void Allreduce(char* input, int input_size, int type_size,
+ static void Allreduce(char* input, comm_size_t input_size, int type_size,
char* output, const ReduceFunction& reducer);
/*!
* \brief Perform all_reduce by using all_gather. It can be used to reduce communication time when data is small.
* \param input Input data
* \param input_size The size of input data
+ * \param type_size The size of one object in the reduce function
* \param output Output result
* \param reducer Reduce function
*/
- static void AllreduceByAllGather(char* input, int input_size, char* output,
+ static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
const ReduceFunction& reducer);
/*!
@@ -110,33 +111,34 @@ public:
* \param send_size The size of input data
* \param output Output result
*/
- static void Allgather(char* input, int send_size, char* output);
+ static void Allgather(char* input, comm_size_t send_size, char* output);
/*!
* \brief Perform all_gather using the Bruck algorithm.
The number of communication rounds is O(log(n)) and the communication cost is O(all_size).
* It can be used when nodes have different input sizes.
* \param input Input data
- * \param all_size The size of input data
* \param block_start The block start for different machines
* \param block_len The block size for different machines
* \param output Output result
+ * \param all_size The size of output data
*/
- static void Allgather(char* input, int all_size, const int* block_start,
-   const int* block_len, char* output);
+ static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
/*!
* \brief Perform reduce scatter using the recursive halving algorithm.
The number of communication rounds is O(log(n)) and the communication cost is O(input_size).
* \param input Input data
* \param input_size The size of input data
+ * \param type_size The size of one object in the reduce function
* \param block_start The block start for different machines
* \param block_len The block size for different machines
* \param output Output result
+ * \param output_size The size of output data
* \param reducer Reduce function
*/
- static void ReduceScatter(char* input, int input_size,
-   const int* block_start, const int* block_len, char* output,
+ static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
+   const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
const ReduceFunction& reducer);
template<class T>
@@ -145,9 +147,8 @@ public:
Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global),
- [] (const char* src, char* dst, int len) {
-   int used_size = 0;
-   const int type_size = sizeof(T);
+ [] (const char* src, char* dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const T *p1;
T *p2;
while (used_size < len) {
@@ -170,9 +171,8 @@ public:
Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global),
- [] (const char* src, char* dst, int len) {
-   int used_size = 0;
-   const int type_size = sizeof(T);
+ [] (const char* src, char* dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const T *p1;
T *p2;
while (used_size < len) {
@@ -191,7 +191,6 @@ public:
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
- static void SetAllReduceFunction(AllreduceFunction allreduce_ext_fun) { allreduce_ext_fun_ = allreduce_ext_fun;}
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
@@ -207,15 +206,14 @@ private:
/*! \brief Recursive halving map for reduce scatter */
static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
/*! \brief Buffer to store block start index */
- static THREAD_LOCAL std::vector<int> block_start_;
+ static THREAD_LOCAL std::vector<comm_size_t> block_start_;
/*! \brief Buffer to store block size */
- static THREAD_LOCAL std::vector<int> block_len_;
+ static THREAD_LOCAL std::vector<comm_size_t> block_len_;
/*! \brief Buffer */
static THREAD_LOCAL std::vector<char> buffer_;
/*! \brief Size of buffer_ */
- static THREAD_LOCAL int buffer_size_;
+ static THREAD_LOCAL comm_size_t buffer_size_;
/*! \brief Funcs */
- static THREAD_LOCAL AllreduceFunction allreduce_ext_fun_;
static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
};
......
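For orientation, Allreduce is implemented as a reduce-scatter over per-rank blocks followed by an all-gather. The following single-process sketch mimics that decomposition in element counts rather than bytes; the block assignment mirrors the loop in network.cpp further down, everything else is illustrative.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t comm_size_t;

int main() {
  const int num_machines = 4;
  const comm_size_t count = 10;  // elements per machine
  std::vector<std::vector<double>> inputs(num_machines, std::vector<double>(count, 1.0));

  // Block assignment, mirroring Network::Allreduce (in elements, not bytes).
  comm_size_t step = (count + num_machines - 1) / num_machines;
  std::vector<comm_size_t> block_start(num_machines), block_len(num_machines);
  block_start[0] = 0;
  for (int i = 0; i < num_machines - 1; ++i) {
    block_len[i] = std::min<comm_size_t>(step, count - block_start[i]);
    block_start[i + 1] = block_start[i] + block_len[i];
  }
  block_len[num_machines - 1] = count - block_start[num_machines - 1];

  // "Reduce-scatter": rank r ends up owning the reduction of block r.
  std::vector<double> reduced(count, 0.0);
  for (int r = 0; r < num_machines; ++r)
    for (int m = 0; m < num_machines; ++m)
      for (comm_size_t j = 0; j < block_len[r]; ++j)
        reduced[block_start[r] + j] += inputs[m][block_start[r] + j];

  // "All-gather" would then circulate the owned blocks so every rank holds
  // the full reduced vector; in one process it is already assembled.
  std::printf("reduced[0] = %.1f\n", reduced[0]);  // 4.0
  return 0;
}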
@@ -314,9 +314,8 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const float* l
Network::Allreduce(reinterpret_cast<char*>(&init_score),
sizeof(init_score), sizeof(init_score),
reinterpret_cast<char*>(&global_init_score),
- [](const char* src, char* dst, int len) {
-   int used_size = 0;
-   const int type_size = sizeof(double);
+ [](const char* src, char* dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const double *p1;
double *p2;
while (used_size < len) {
......
@@ -1220,32 +1220,20 @@ int LGBM_NetworkFree() {
API_END();
}
- int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
-                                   void* reduce_scatter_fun_ptr,
+ int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank) {
API_BEGIN();
- typedef void(*ReduceFunctionPtr)(const char* input, char* output, int array_size);
if (num_machines > 1) {
- auto allreduce_fun = [allreduce_fun_ptr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& reduce_fun) {
-   auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
-   auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionPtr&))allreduce_fun_ptr;
-   return tmp(arg1, arg2, arg3, arg4, reduce_fun_ptr);
- };
- Network::SetAllReduceFunction(allreduce_fun);
- auto reduce_scatter_fun = [reduce_scatter_fun_ptr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& reduce_fun) {
-   auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
-   auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionPtr&))reduce_scatter_fun_ptr;
-   return tmp(arg1, arg2, arg3, arg4, arg5, reduce_fun_ptr);
- };
- Network::SetReduceScatterFunction(reduce_scatter_fun);
- Network::SetAllgatherFunction((void(*)(char*, int, const int*, const int*, char*))allgather_fun_ptr);
+ Network::SetReduceScatterFunction((ReduceScatterFunction)reduce_scatter_fun_ptr);
+ Network::SetAllgatherFunction((AllgatherFunction)allgather_fun_ptr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
}
API_END();
}
// ---- start of some help functions
std::function<std::vector<double>(int row_idx)>
......
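The deleted trampolines depended on std::function::target<T>(), which only recovers the callee when the stored callable is exactly of type T; any intermediate lambda wrapper makes it return a null pointer. A small demonstration of that pitfall, which the switch to plain function-pointer typedefs avoids.

#include <cstdio>
#include <functional>

typedef void (*FnPtr)(int);
static void Impl(int) {}

int main() {
  std::function<void(int)> holds_ptr = &Impl;
  std::function<void(int)> holds_lambda = [](int x) { Impl(x); };

  // target<T>() recovers the callee only when the stored callable is exactly T.
  std::printf("from pointer: %p\n", static_cast<void*>(holds_ptr.target<FnPtr>()));    // non-null
  std::printf("from lambda:  %p\n", static_cast<void*>(holds_lambda.target<FnPtr>())); // null
  return 0;
}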
@@ -561,7 +561,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
// get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since the sizes of different features may not be the same, we expand all bin mappers to type_size
- int buffer_size = type_size * total_num_feature;
+ comm_size_t buffer_size = type_size * total_num_feature;
auto input_buffer = std::vector<char>(buffer_size);
auto output_buffer = std::vector<char>(buffer_size);
@@ -578,13 +578,15 @@
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
+ std::vector<comm_size_t> size_start(num_machines);
+ std::vector<comm_size_t> size_len(num_machines);
// convert to binary size
for (int i = 0; i < num_machines; ++i) {
- start[i] *= type_size;
- len[i] *= type_size;
+ size_start[i] = start[i] * static_cast<comm_size_t>(type_size);
+ size_len[i] = len[i] * static_cast<comm_size_t>(type_size);
}
// gather global feature bin mappers
- Network::Allgather(input_buffer.data(), buffer_size, start.data(), len.data(), output_buffer.data());
+ Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), buffer_size);
// restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) {
@@ -863,7 +865,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since the sizes of different features may not be the same, we expand all bin mappers to type_size
- int buffer_size = type_size * total_num_feature;
+ comm_size_t buffer_size = type_size * total_num_feature;
auto input_buffer = std::vector<char>(buffer_size);
auto output_buffer = std::vector<char>(buffer_size);
@@ -880,13 +882,15 @@
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
+ std::vector<comm_size_t> size_start(num_machines);
+ std::vector<comm_size_t> size_len(num_machines);
// convert to binary size
for (int i = 0; i < num_machines; ++i) {
- start[i] *= type_size;
- len[i] *= type_size;
+ size_start[i] = start[i] * static_cast<comm_size_t>(type_size);
+ size_len[i] = len[i] * static_cast<comm_size_t>(type_size);
}
// gather global feature bin mappers
- Network::Allgather(input_buffer.data(), buffer_size, start.data(), len.data(), output_buffer.data());
+ Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), buffer_size);
// restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) {
......
@@ -9,13 +9,14 @@
#include <algorithm>
#include <chrono>
#include <ctime>
- #ifdef USE_SOCKET
- #include "socket_wrapper.hpp"
- #include <LightGBM/utils/common.h>
#include <thread>
#include <vector>
#include <string>
#include <memory>
+ #ifdef USE_SOCKET
+ #include "socket_wrapper.hpp"
+ #include <LightGBM/utils/common.h>
#endif
#ifdef USE_MPI
@@ -51,6 +52,9 @@ public:
* \param len Recv size; blocks until len bytes of data have been received
*/
inline void Recv(int rank, char* data, int len) const;
+ inline void Recv(int rank, char* data, int64_t len) const;
/*!
* \brief Send data, blocking
* \param rank Which rank local machine will send to
@@ -58,6 +62,8 @@ public:
* \param len Send size
*/
inline void Send(int rank, char* data, int len) const;
+ inline void Send(int rank, char* data, int64_t len) const;
/*!
* \brief Send and Recv at same time, blocking
* \param send_rank
@@ -69,6 +75,9 @@ public:
*/
inline void SendRecv(int send_rank, char* send_data, int send_len,
int recv_rank, char* recv_data, int recv_len);
+ inline void SendRecv(int send_rank, char* send_data, int64_t send_len,
+   int recv_rank, char* recv_data, int64_t recv_len);
/*!
* \brief Get rank of local machine
*/
@@ -174,6 +183,39 @@ inline const RecursiveHalvingMap& Linkers::recursive_halving_map() {
return recursive_halving_map_;
}
+ inline void Linkers::Recv(int rank, char* data, int64_t len) const {
+   int64_t used = 0;
+   do {
+     int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
+     Recv(rank, data + used, cur_size);
+     used += cur_size;
+   } while (used < len);
+ }
+ inline void Linkers::Send(int rank, char* data, int64_t len) const {
+   int64_t used = 0;
+   do {
+     int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
+     Send(rank, data + used, cur_size);
+     used += cur_size;
+   } while (used < len);
+ }
+ inline void Linkers::SendRecv(int send_rank, char* send_data, int64_t send_len,
+                               int recv_rank, char* recv_data, int64_t recv_len) {
+   auto start_time = std::chrono::high_resolution_clock::now();
+   std::thread send_worker(
+     [this, send_rank, send_data, send_len]() {
+     Send(send_rank, send_data, send_len);
+   });
+   Recv(recv_rank, recv_data, recv_len);
+   // wait for send to complete
+   send_worker.join();
+   auto end_time = std::chrono::high_resolution_clock::now();
+   // accumulate time spent on network communication
+   network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
+ }
#ifdef USE_SOCKET
inline void Linkers::Recv(int rank, char* data, int len) const {
......
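The new int64_t overloads let callers move more than 2 GB through the 32-bit socket path by peeling off INT32_MAX-byte chunks. The same loop in isolation, with a mock 32-bit send that only reports offsets; the Send32/Send64 names are illustrative.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mock of the 32-bit path; a real implementation writes to the socket.
static void Send32(int64_t offset, int len) {
  std::printf("chunk at offset %lld: %d bytes\n", static_cast<long long>(offset), len);
}

// Same pattern as the new Linkers::Send(int, char*, int64_t): peel off at
// most INT32_MAX bytes per call until the whole buffer has gone out.
static void Send64(int64_t len) {
  int64_t used = 0;
  do {
    int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
    Send32(used, cur_size);
    used += cur_size;
  } while (used < len);
}

int main() {
  Send64(5000000000LL);  // ~5 GB -> three chunks
  return 0;
}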
@@ -15,11 +15,10 @@ THREAD_LOCAL int Network::rank_ = 0;
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
- THREAD_LOCAL std::vector<int> Network::block_start_;
- THREAD_LOCAL std::vector<int> Network::block_len_;
- THREAD_LOCAL int Network::buffer_size_;
+ THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
+ THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
+ THREAD_LOCAL comm_size_t Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_;
- THREAD_LOCAL AllreduceFunction Network::allreduce_ext_fun_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL;
@@ -31,8 +30,8 @@ void Network::Init(NetworkConfig config) {
num_machines_ = linkers_->num_machines();
bruck_map_ = linkers_->bruck_map();
recursive_halving_map_ = linkers_->recursive_halving_map();
- block_start_ = std::vector<int>(num_machines_);
- block_len_ = std::vector<int>(num_machines_);
+ block_start_ = std::vector<comm_size_t>(num_machines_);
+ block_len_ = std::vector<comm_size_t>(num_machines_);
buffer_size_ = 1024 * 1024;
buffer_.resize(buffer_size_);
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
@@ -45,42 +44,39 @@ void Network::Dispose() {
linkers_.reset(new Linkers());
}
- void Network::Allreduce(char* input, int input_size, int type_size, char* output, const ReduceFunction& reducer) {
+ void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initialize the network interface first");
}
- if (allreduce_ext_fun_ != NULL) {
-   return allreduce_ext_fun_(input, input_size, type_size, output, reducer);
- }
- int count = input_size / type_size;
+ comm_size_t count = input_size / type_size;
// if the payload or the count is small, do it by all gather (to reduce the number of communication rounds)
if (count < num_machines_ || input_size < 4096) {
- AllreduceByAllGather(input, input_size, output, reducer);
+ AllreduceByAllGather(input, input_size, type_size, output, reducer);
return;
}
// assign the blocks to every rank.
- int step = (count + num_machines_ - 1) / num_machines_;
+ comm_size_t step = (count + num_machines_ - 1) / num_machines_;
if (step < 1) {
step = 1;
}
block_start_[0] = 0;
for (int i = 0; i < num_machines_ - 1; ++i) {
- block_len_[i] = std::min(step * type_size, input_size - block_start_[i]);
+ block_len_[i] = std::min<comm_size_t>(step * type_size, input_size - block_start_[i]);
block_start_[i + 1] = block_start_[i] + block_len_[i];
}
block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1];
// do reduce scatter
- ReduceScatter(input, input_size, block_start_.data(), block_len_.data(), output, reducer);
+ ReduceScatter(input, input_size, type_size, block_start_.data(), block_len_.data(), output, input_size, reducer);
// do all gather
- Allgather(output, input_size, block_start_.data(), block_len_.data(), output);
+ Allgather(output, block_start_.data(), block_len_.data(), output, input_size);
}
- void Network::AllreduceByAllGather(char* input, int input_size, char* output, const ReduceFunction& reducer) {
+ void Network::AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initialize the network interface first");
}
// assign blocks
- int all_size = input_size * num_machines_;
+ comm_size_t all_size = input_size * num_machines_;
block_start_[0] = 0;
block_len_[0] = input_size;
for (int i = 1; i < num_machines_; ++i) {
@@ -93,15 +89,15 @@ void Network::AllreduceByAllGather(char* input, int input_size, char* output, co
buffer_.resize(buffer_size_);
}
- Allgather(input, all_size, block_start_.data(), block_len_.data(), buffer_.data());
+ Allgather(input, block_start_.data(), block_len_.data(), buffer_.data(), all_size);
for (int i = 1; i < num_machines_; ++i) {
- reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], input_size);
+ reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], type_size, input_size);
}
// copy back
std::memcpy(output, buffer_.data(), input_size);
}
- void Network::Allgather(char* input, int send_size, char* output) {
+ void Network::Allgather(char* input, comm_size_t send_size, char* output) {
if (num_machines_ <= 1) {
Log::Fatal("Please initialize the network interface first");
}
@@ -114,17 +110,17 @@ void Network::Allgather(char* input, int send_size, char* output) {
block_len_[i] = send_size;
}
// start all gather
- Allgather(input, send_size * num_machines_, block_start_.data(), block_len_.data(), output);
+ Allgather(input, block_start_.data(), block_len_.data(), output, send_size * num_machines_);
}
- void Network::Allgather(char* input, int all_size, const int* block_start, const int* block_len, char* output) {
+ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
if (num_machines_ <= 1) {
Log::Fatal("Please initialize the network interface first");
}
if (allgather_ext_fun_ != NULL) {
- return allgather_ext_fun_(input, all_size, block_start, block_len, output);
+ return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
}
- int write_pos = 0;
+ comm_size_t write_pos = 0;
// use output as receive buffer
std::memcpy(output, input, block_len[rank_]);
write_pos += block_len[rank_];
@@ -137,9 +133,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
// get in rank
int in_rank = bruck_map_.in_ranks[i];
// get send information
- int need_send_len = 0;
+ comm_size_t need_send_len = 0;
// get recv information
- int need_recv_len = 0;
+ comm_size_t need_recv_len = 0;
for (int j = 0; j < cur_block_size; ++j) {
need_send_len += block_len[(rank_ + j) % num_machines_];
need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
@@ -155,40 +151,40 @@
std::reverse<char*>(output + block_start[rank_], output + all_size);
}
- void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
+ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initialize the network interface first");
}
if (reduce_scatter_ext_fun_ != NULL) {
- return reduce_scatter_ext_fun_(input, input_size, block_start, block_len, output, reducer);
+ return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
}
if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_;
int in_rank = (rank_ - i + num_machines_) % num_machines_;
linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
- reducer(output, input + block_start[rank_], block_len[rank_]);
+ reducer(output, input + block_start[rank_], type_size, block_len[rank_]);
}
} else {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
// get target
int target = recursive_halving_map_.ranks[i];
- int send_block_start = recursive_halving_map_.send_block_start[i];
- int recv_block_start = recursive_halving_map_.recv_block_start[i];
+ comm_size_t send_block_start = recursive_halving_map_.send_block_start[i];
+ comm_size_t recv_block_start = recursive_halving_map_.recv_block_start[i];
// get send information
- int send_size = 0;
+ comm_size_t send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += block_len[send_block_start + j];
}
// get recv information
- int need_recv_cnt = 0;
+ comm_size_t need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += block_len[recv_block_start + j];
}
// send and recv at same time
linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
// reduce
- reducer(output, input + block_start[recv_block_start], need_recv_cnt);
+ reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt);
}
}
// copy result
......
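AllreduceByAllGather now hands type_size through to the reducer when folding the gathered blocks together. A single-process sketch of that fold; SumDoubles and the three-machine setup are illustrative.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

typedef int32_t comm_size_t;
typedef void (*ReduceFunction)(const char*, char*, int, comm_size_t);

// Toy reducer: elementwise sum of doubles.
static void SumDoubles(const char* src, char* dst, int type_size, comm_size_t len) {
  for (comm_size_t used = 0; used < len; used += type_size)
    *reinterpret_cast<double*>(dst + used) += *reinterpret_cast<const double*>(src + used);
}

int main() {
  const int num_machines = 3;
  const comm_size_t input_size = sizeof(double);
  // buffer_ as it looks after the all-gather: one block per machine.
  std::vector<char> buffer(num_machines * input_size);
  for (int i = 0; i < num_machines; ++i) {
    double v = i + 1.0;
    std::memcpy(buffer.data() + i * input_size, &v, sizeof(double));
  }
  // Fold every block into block 0, as AllreduceByAllGather does, now passing
  // type_size explicitly.
  ReduceFunction reducer = &SumDoubles;
  for (int i = 1; i < num_machines; ++i)
    reducer(buffer.data() + i * input_size, buffer.data(),
            static_cast<int>(sizeof(double)), input_size);
  std::printf("%.1f\n", *reinterpret_cast<double*>(buffer.data()));  // 6.0
  return 0;
}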
@@ -120,9 +120,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int size = sizeof(data);
std::memcpy(input_buffer_.data(), &data, size);
// global sumup reduce
- Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(), [](const char *src, char *dst, int len) {
-   int used_size = 0;
-   int type_size = sizeof(std::tuple<data_size_t, double, double>);
+ Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const std::tuple<data_size_t, double, double> *p1;
std::tuple<data_size_t, double, double> *p2;
while (used_size < len) {
@@ -157,8 +156,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_histogram_array_[feature_index].SizeOfHistgram());
}
// Reduce scatter for histogram
- Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(),
-   block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer);
+ Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(HistogramBinEntry), block_start_.data(),
+   block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true);
}
......
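reduce_scatter_size_ and the per-feature write positions are byte offsets accumulated over all features, which is why they switch to the shared comm_size_t typedef along with the block arrays. A toy sketch of that accumulation; the struct layout and bin counts are made up for illustration.

#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t comm_size_t;

// Illustrative layout; the real entry lives in LightGBM's bin.h.
struct HistogramBinEntry { double sum_gradients; double sum_hessians; int32_t cnt; };

int main() {
  // Made-up bin counts per feature; real values come from the BinMapper.
  std::vector<int> num_bins = {255, 63, 255, 15};
  std::vector<comm_size_t> buffer_write_start_pos(num_bins.size());
  comm_size_t reduce_scatter_size = 0;
  for (size_t i = 0; i < num_bins.size(); ++i) {
    buffer_write_start_pos[i] = reduce_scatter_size;  // byte offset of feature i's histogram
    reduce_scatter_size += static_cast<comm_size_t>(num_bins[i] * sizeof(HistogramBinEntry));
  }
  std::printf("total reduce-scatter payload: %d bytes\n", static_cast<int>(reduce_scatter_size));
  return 0;
}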
@@ -79,15 +79,15 @@ private:
use this to mark local aggregate features*/
std::vector<bool> is_feature_aggregated_;
/*! \brief Block start index for reduce scatter */
- std::vector<int> block_start_;
+ std::vector<comm_size_t> block_start_;
/*! \brief Block size for reduce scatter */
- std::vector<int> block_len_;
+ std::vector<comm_size_t> block_len_;
/*! \brief Write positions for feature histograms */
- std::vector<int> buffer_write_start_pos_;
+ std::vector<comm_size_t> buffer_write_start_pos_;
/*! \brief Read positions for local feature histograms */
- std::vector<int> buffer_read_start_pos_;
+ std::vector<comm_size_t> buffer_read_start_pos_;
/*! \brief Size for reduce scatter */
- int reduce_scatter_size_;
+ comm_size_t reduce_scatter_size_;
/*! \brief Store global number of data in leaves */
std::vector<data_size_t> global_data_count_in_leaf_;
};
@@ -155,15 +155,15 @@ private:
use this to mark local aggregate features*/
std::vector<bool> larger_is_feature_aggregated_;
/*! \brief Block start index for reduce scatter */
- std::vector<int> block_start_;
+ std::vector<comm_size_t> block_start_;
/*! \brief Block size for reduce scatter */
- std::vector<int> block_len_;
+ std::vector<comm_size_t> block_len_;
/*! \brief Read positions for feature histograms at smaller leaf */
- std::vector<int> smaller_buffer_read_start_pos_;
+ std::vector<comm_size_t> smaller_buffer_read_start_pos_;
/*! \brief Read positions for feature histograms at larger leaf */
- std::vector<int> larger_buffer_read_start_pos_;
+ std::vector<comm_size_t> larger_buffer_read_start_pos_;
/*! \brief Size for reduce scatter */
- int reduce_scatter_size_;
+ comm_size_t reduce_scatter_size_;
/*! \brief Store global number of data in leaves */
std::vector<data_size_t> global_data_count_in_leaf_;
/*! \brief Store global split information for smaller leaf */
@@ -187,8 +187,8 @@ inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, Spl
smaller_best_split->CopyTo(input_buffer_);
larger_best_split->CopyTo(input_buffer_ + size);
Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
- [&size] (const char* src, char* dst, int len) {
-   int used_size = 0;
+ [] (const char* src, char* dst, int size, comm_size_t len) {
+   comm_size_t used_size = 0;
LightSplitInfo p1, p2;
while (used_size < len) {
p1.CopyFrom(src);
......
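Note the lambda above loses its [&size] capture and reads the element size from the new type_size argument instead. That is forced by the typedef change: reducers must now be convertible to a plain function pointer, and only captureless lambdas convert. A two-line demonstration:

#include <cstdint>

typedef int32_t comm_size_t;
typedef void (*ReduceFunction)(const char*, char*, int, comm_size_t);

int main() {
  // A captureless lambda converts to the plain function pointer...
  ReduceFunction ok = [](const char*, char*, int, comm_size_t) {};
  (void)ok;
  // ...a capturing one does not; the next two lines would fail to compile:
  // int size = 8;
  // ReduceFunction bad = [&size](const char*, char*, int, comm_size_t) {};
  return 0;
}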
@@ -113,9 +113,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int size = sizeof(std::tuple<data_size_t, double, double>);
std::memcpy(input_buffer_.data(), &data, size);
- Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(), [](const char *src, char *dst, int len) {
-   int used_size = 0;
-   int type_size = sizeof(std::tuple<data_size_t, double, double>);
+ Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
+   comm_size_t used_size = 0;
const std::tuple<data_size_t, double, double> *p1;
std::tuple<data_size_t, double, double> *p2;
while (used_size < len) {
@@ -357,8 +356,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
CopyLocalHistogram(smaller_top_features, larger_top_features);
// Reduce scatter for histogram
- Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(),
-   output_buffer_.data(), &HistogramBinEntry::SumReducer);
+ Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(HistogramBinEntry), block_start_.data(), block_len_.data(),
+   output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(is_feature_used, false);
}
......