"vscode:/vscode.git/clone" did not exist on "499dfb3d22bb53a518eea71be5c34333b3720ecf"
Commit 51287f07 authored by Guolin Ke's avatar Guolin Ke
Browse files

refine network interface

parent 72b54956
...@@ -35,9 +35,8 @@ public: ...@@ -35,9 +35,8 @@ public:
/*! /*!
* \brief Sum up (reducers) functions for histogram bin * \brief Sum up (reducers) functions for histogram bin
*/ */
inline static void SumReducer(const char *src, char *dst, int len) { inline static void SumReducer(const char *src, char *dst, int type_size, comm_size_t len) {
const int type_size = sizeof(HistogramBinEntry); comm_size_t used_size = 0;
int used_size = 0;
const HistogramBinEntry* p1; const HistogramBinEntry* p1;
HistogramBinEntry* p2; HistogramBinEntry* p2;
while (used_size < len) { while (used_size < len) {
......
...@@ -757,8 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines, ...@@ -757,8 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/ */
LIGHTGBM_C_EXPORT int LGBM_NetworkFree(); LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr, LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr, void* allgather_fun_ptr,
int num_machines, int num_machines,
int rank); int rank);
......
...@@ -21,16 +21,22 @@ const score_t kEpsilon = 1e-15f; ...@@ -21,16 +21,22 @@ const score_t kEpsilon = 1e-15f;
const double kZeroThreshold = 1e-35f; const double kZeroThreshold = 1e-35f;
using ReduceFunction = std::function<void(const char*, char*, int)>;
typedef int32_t comm_size_t;
using PredictFunction = using PredictFunction =
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>; std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;
using AllreduceFunction = std::function<void(char*, int, int, char*, const ReduceFunction&)>; typedef void(*ReduceFunction)(const char* input, char* output, int type_size, comm_size_t array_size);
typedef void(*ReduceScatterFunction)(char* input, comm_size_t input_size, int type_size,
const comm_size_t* block_start, const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size,
const ReduceFunction& reducer);
using ReduceScatterFunction = std::function<void(char*, int, const int*, const int*, char*, const ReduceFunction&)>; typedef void(*AllgatherFunction)(char* input, comm_size_t input_size, const comm_size_t* block_start,
const comm_size_t* block_len, int num_block, char* output, comm_size_t output_size);
using AllgatherFunction = std::function<void(char*, int, const int*, const int*, char*)>;
#define NO_SPECIFIC (-1) #define NO_SPECIFIC (-1)
......
...@@ -89,18 +89,19 @@ public: ...@@ -89,18 +89,19 @@ public:
* \param output Output result * \param output Output result
* \param reducer Reduce function * \param reducer Reduce function
*/ */
static void Allreduce(char* input, int input_size, int type_size, static void Allreduce(char* input, comm_size_t input_size, int type_size,
char* output, const ReduceFunction& reducer); char* output, const ReduceFunction& reducer);
/*! /*!
* \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small * \brief Perform all_reduce by using all_gather. it can be use to reduce communication time when data is small
* \param input Input data * \param input Input data
* \param input_size The size of input data * \param input_size The size of input data
* \param type_size The size of one object in the reduce function
* \param output Output result * \param output Output result
* \param reducer Reduce function * \param reducer Reduce function
*/ */
static void AllreduceByAllGather(char* input, int input_size, char* output, static void AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output,
const ReduceFunction& reducer); const ReduceFunction& reducer);
/*! /*!
* \brief Performing all_gather by using bruck algorithm. * \brief Performing all_gather by using bruck algorithm.
...@@ -110,34 +111,35 @@ public: ...@@ -110,34 +111,35 @@ public:
* \param send_size The size of input data * \param send_size The size of input data
* \param output Output result * \param output Output result
*/ */
static void Allgather(char* input, int send_size, char* output); static void Allgather(char* input, comm_size_t send_size, char* output);
/*! /*!
* \brief Performing all_gather by using bruck algorithm. * \brief Performing all_gather by using bruck algorithm.
Communication times is O(log(n)), and communication cost is O(all_size) Communication times is O(log(n)), and communication cost is O(all_size)
* It can be used when nodes have different input size. * It can be used when nodes have different input size.
* \param input Input data * \param input Input data
* \param all_size The size of input data
* \param block_start The block start for different machines * \param block_start The block start for different machines
* \param block_len The block size for different machines * \param block_len The block size for different machines
* \param output Output result * \param output Output result
* \param all_size The size of output data
*/ */
static void Allgather(char* input, int all_size, const int* block_start, static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
const int* block_len, char* output);
/*! /*!
* \brief Perform reduce scatter by using recursive halving algorithm. * \brief Perform reduce scatter by using recursive halving algorithm.
Communication times is O(log(n)), and communication cost is O(input_size) Communication times is O(log(n)), and communication cost is O(input_size)
* \param input Input data * \param input Input data
* \param input_size The size of input data * \param input_size The size of input data
* \param type_size The size of one object in the reduce function
* \param block_start The block start for different machines * \param block_start The block start for different machines
* \param block_len The block size for different machines * \param block_len The block size for different machines
* \param output Output result * \param output Output result
* \param output_size size of output data
* \param reducer Reduce function * \param reducer Reduce function
*/ */
static void ReduceScatter(char* input, int input_size, static void ReduceScatter(char* input, comm_size_t input_size, int type_size,
const int* block_start, const int* block_len, char* output, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size,
const ReduceFunction& reducer); const ReduceFunction& reducer);
template<class T> template<class T>
static T GlobalSyncUpByMin(T& local) { static T GlobalSyncUpByMin(T& local) {
...@@ -145,9 +147,8 @@ public: ...@@ -145,9 +147,8 @@ public:
Allreduce(reinterpret_cast<char*>(&local), Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local), sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global), reinterpret_cast<char*>(&global),
[] (const char* src, char* dst, int len) { [] (const char* src, char* dst, int type_size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
const int type_size = sizeof(T);
const T *p1; const T *p1;
T *p2; T *p2;
while (used_size < len) { while (used_size < len) {
...@@ -170,9 +171,8 @@ public: ...@@ -170,9 +171,8 @@ public:
Allreduce(reinterpret_cast<char*>(&local), Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local), sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global), reinterpret_cast<char*>(&global),
[] (const char* src, char* dst, int len) { [] (const char* src, char* dst, int type_size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
const int type_size = sizeof(T);
const T *p1; const T *p1;
T *p2; T *p2;
while (used_size < len) { while (used_size < len) {
...@@ -191,7 +191,6 @@ public: ...@@ -191,7 +191,6 @@ public:
/*! \brief set variables and function ptrs */ /*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;} static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; } static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduceFunction(AllreduceFunction allreduce_ext_fun) { allreduce_ext_fun_ = allreduce_ext_fun;}
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; } static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; } static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
...@@ -207,15 +206,14 @@ private: ...@@ -207,15 +206,14 @@ private:
/*! \brief Recursive halving map for reduce scatter */ /*! \brief Recursive halving map for reduce scatter */
static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_; static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
/*! \brief Buffer to store block start index */ /*! \brief Buffer to store block start index */
static THREAD_LOCAL std::vector<int> block_start_; static THREAD_LOCAL std::vector<comm_size_t> block_start_;
/*! \brief Buffer to store block size */ /*! \brief Buffer to store block size */
static THREAD_LOCAL std::vector<int> block_len_; static THREAD_LOCAL std::vector<comm_size_t> block_len_;
/*! \brief Buffer */ /*! \brief Buffer */
static THREAD_LOCAL std::vector<char> buffer_; static THREAD_LOCAL std::vector<char> buffer_;
/*! \brief Size of buffer_ */ /*! \brief Size of buffer_ */
static THREAD_LOCAL int buffer_size_; static THREAD_LOCAL comm_size_t buffer_size_;
/*! \brief Funcs*/ /*! \brief Funcs*/
static THREAD_LOCAL AllreduceFunction allreduce_ext_fun_;
static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_; static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
static THREAD_LOCAL AllgatherFunction allgather_ext_fun_; static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
}; };
......
...@@ -314,9 +314,8 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const float* l ...@@ -314,9 +314,8 @@ double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const float* l
Network::Allreduce(reinterpret_cast<char*>(&init_score), Network::Allreduce(reinterpret_cast<char*>(&init_score),
sizeof(init_score), sizeof(init_score), sizeof(init_score), sizeof(init_score),
reinterpret_cast<char*>(&global_init_score), reinterpret_cast<char*>(&global_init_score),
[](const char* src, char* dst, int len) { [](const char* src, char* dst, int type_size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
const int type_size = sizeof(double);
const double *p1; const double *p1;
double *p2; double *p2;
while (used_size < len) { while (used_size < len) {
......
...@@ -1220,32 +1220,20 @@ int LGBM_NetworkFree() { ...@@ -1220,32 +1220,20 @@ int LGBM_NetworkFree() {
API_END(); API_END();
} }
int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr, int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr, void* allgather_fun_ptr,
int num_machines, int num_machines,
int rank) { int rank) {
API_BEGIN(); API_BEGIN();
typedef void(*ReduceFunctionPtr)(const char* input, char* output, int array_size);
if (num_machines > 1) { if (num_machines > 1) {
auto allreduce_fun = [allreduce_fun_ptr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& reduce_fun) { Network::SetReduceScatterFunction((ReduceScatterFunction)reduce_scatter_fun_ptr);
auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>(); Network::SetAllgatherFunction((AllgatherFunction)allgather_fun_ptr);
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionPtr&))allreduce_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, reduce_fun_ptr);
};
Network::SetAllReduceFunction(allreduce_fun);
auto reduce_scatter_fun = [reduce_scatter_fun_ptr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& reduce_fun) {
auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionPtr&))reduce_scatter_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, arg5, reduce_fun_ptr);
};
Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgatherFunction((void(*)(char*, int, const int*, const int*, char*))allgather_fun_ptr);
Network::SetNumMachines(num_machines); Network::SetNumMachines(num_machines);
Network::SetRank(rank); Network::SetRank(rank);
} }
API_END(); API_END();
} }
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
...@@ -561,7 +561,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -561,7 +561,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
// get size of bin mapper with max_bin size // get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin); int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since sizes of different feature may not be same, we expand all bin mapper to type_size // since sizes of different feature may not be same, we expand all bin mapper to type_size
int buffer_size = type_size * total_num_feature; comm_size_t buffer_size = type_size * total_num_feature;
auto input_buffer = std::vector<char>(buffer_size); auto input_buffer = std::vector<char>(buffer_size);
auto output_buffer = std::vector<char>(buffer_size); auto output_buffer = std::vector<char>(buffer_size);
...@@ -578,13 +578,15 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -578,13 +578,15 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
std::vector<comm_size_t> size_start(num_machines);
std::vector<comm_size_t> size_len(num_machines);
// convert to binary size // convert to binary size
for (int i = 0; i < num_machines; ++i) { for (int i = 0; i < num_machines; ++i) {
start[i] *= type_size; size_start[i] = start[i] * static_cast<comm_size_t>(type_size);
len[i] *= type_size; size_len[i] = len[i] * static_cast<comm_size_t>(type_size);
} }
// gather global feature bin mappers // gather global feature bin mappers
Network::Allgather(input_buffer.data(), buffer_size, start.data(), len.data(), output_buffer.data()); Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), buffer_size);
// restore features bins from buffer // restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) { for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
...@@ -863,7 +865,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -863,7 +865,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// get size of bin mapper with max_bin size // get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin); int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since sizes of different feature may not be same, we expand all bin mapper to type_size // since sizes of different feature may not be same, we expand all bin mapper to type_size
int buffer_size = type_size * total_num_feature; comm_size_t buffer_size = type_size * total_num_feature;
auto input_buffer = std::vector<char>(buffer_size); auto input_buffer = std::vector<char>(buffer_size);
auto output_buffer = std::vector<char>(buffer_size); auto output_buffer = std::vector<char>(buffer_size);
...@@ -880,13 +882,15 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -880,13 +882,15 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
} }
OMP_THROW_EX(); OMP_THROW_EX();
std::vector<comm_size_t> size_start(num_machines);
std::vector<comm_size_t> size_len(num_machines);
// convert to binary size // convert to binary size
for (int i = 0; i < num_machines; ++i) { for (int i = 0; i < num_machines; ++i) {
start[i] *= type_size; size_start[i] = start[i] * static_cast<comm_size_t>(type_size);
len[i] *= type_size; size_len[i] = len[i] * static_cast<comm_size_t>(type_size);
} }
// gather global feature bin mappers // gather global feature bin mappers
Network::Allgather(input_buffer.data(), buffer_size, start.data(), len.data(), output_buffer.data()); Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), buffer_size);
// restore features bins from buffer // restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) { for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
......
...@@ -9,13 +9,14 @@ ...@@ -9,13 +9,14 @@
#include <algorithm> #include <algorithm>
#include <chrono> #include <chrono>
#include <ctime> #include <ctime>
#ifdef USE_SOCKET
#include "socket_wrapper.hpp"
#include <LightGBM/utils/common.h>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#ifdef USE_SOCKET
#include "socket_wrapper.hpp"
#include <LightGBM/utils/common.h>
#endif #endif
#ifdef USE_MPI #ifdef USE_MPI
...@@ -51,6 +52,9 @@ public: ...@@ -51,6 +52,9 @@ public:
* \prama len Recv size, will block until recive len size of data * \prama len Recv size, will block until recive len size of data
*/ */
inline void Recv(int rank, char* data, int len) const; inline void Recv(int rank, char* data, int len) const;
inline void Recv(int rank, char* data, int64_t len) const;
/*! /*!
* \brief Send data, blocking * \brief Send data, blocking
* \param rank Which rank local machine will send to * \param rank Which rank local machine will send to
...@@ -58,6 +62,8 @@ public: ...@@ -58,6 +62,8 @@ public:
* \prama len Send size * \prama len Send size
*/ */
inline void Send(int rank, char* data, int len) const; inline void Send(int rank, char* data, int len) const;
inline void Send(int rank, char* data, int64_t len) const;
/*! /*!
* \brief Send and Recv at same time, blocking * \brief Send and Recv at same time, blocking
* \param send_rank * \param send_rank
...@@ -68,7 +74,10 @@ public: ...@@ -68,7 +74,10 @@ public:
* \prama recv_len * \prama recv_len
*/ */
inline void SendRecv(int send_rank, char* send_data, int send_len, inline void SendRecv(int send_rank, char* send_data, int send_len,
int recv_rank, char* recv_data, int recv_len); int recv_rank, char* recv_data, int recv_len);
inline void SendRecv(int send_rank, char* send_data, int64_t send_len,
int recv_rank, char* recv_data, int64_t recv_len);
/*! /*!
* \brief Get rank of local machine * \brief Get rank of local machine
*/ */
...@@ -174,6 +183,39 @@ inline const RecursiveHalvingMap& Linkers::recursive_halving_map() { ...@@ -174,6 +183,39 @@ inline const RecursiveHalvingMap& Linkers::recursive_halving_map() {
return recursive_halving_map_; return recursive_halving_map_;
} }
inline void Linkers::Recv(int rank, char* data, int64_t len) const {
int64_t used = 0;
do {
int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
Recv(rank, data + used, cur_size);
used += cur_size;
} while (used < len);
}
inline void Linkers::Send(int rank, char* data, int64_t len) const {
int64_t used = 0;
do {
int cur_size = static_cast<int>(std::min<int64_t>(len - used, INT32_MAX));
Send(rank, data + used, cur_size);
used += cur_size;
} while (used < len);
}
inline void Linkers::SendRecv(int send_rank, char* send_data, int64_t send_len,
int recv_rank, char* recv_data, int64_t recv_len) {
auto start_time = std::chrono::high_resolution_clock::now();
std::thread send_worker(
[this, send_rank, send_data, send_len]() {
Send(send_rank, send_data, send_len);
});
Recv(recv_rank, recv_data, recv_len);
send_worker.join();
// wait for send complete
auto end_time = std::chrono::high_resolution_clock::now();
// output used time on each iteration
network_time_ += std::chrono::duration<double, std::milli>(end_time - start_time);
}
#ifdef USE_SOCKET #ifdef USE_SOCKET
inline void Linkers::Recv(int rank, char* data, int len) const { inline void Linkers::Recv(int rank, char* data, int len) const {
...@@ -197,7 +239,7 @@ inline void Linkers::Send(int rank, char* data, int len) const { ...@@ -197,7 +239,7 @@ inline void Linkers::Send(int rank, char* data, int len) const {
} }
inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len, inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len,
int recv_rank, char* recv_data, int recv_len) { int recv_rank, char* recv_data, int recv_len) {
auto start_time = std::chrono::high_resolution_clock::now(); auto start_time = std::chrono::high_resolution_clock::now();
if (send_len < SocketConfig::kSocketBufferSize) { if (send_len < SocketConfig::kSocketBufferSize) {
// if buffer is enough, send will non-blocking // if buffer is enough, send will non-blocking
...@@ -244,7 +286,7 @@ inline void Linkers::Send(int rank, char* data, int len) const { ...@@ -244,7 +286,7 @@ inline void Linkers::Send(int rank, char* data, int len) const {
} }
inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len, inline void Linkers::SendRecv(int send_rank, char* send_data, int send_len,
int recv_rank, char* recv_data, int recv_len) { int recv_rank, char* recv_data, int recv_len) {
MPI_Request send_request; MPI_Request send_request;
// send first, non-blocking // send first, non-blocking
MPI_SAFE_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, 0, MPI_COMM_WORLD, &send_request)); MPI_SAFE_CALL(MPI_Isend(send_data, send_len, MPI_BYTE, send_rank, 0, MPI_COMM_WORLD, &send_request));
......
...@@ -15,11 +15,10 @@ THREAD_LOCAL int Network::rank_ = 0; ...@@ -15,11 +15,10 @@ THREAD_LOCAL int Network::rank_ = 0;
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_; THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_; THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_; THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
THREAD_LOCAL std::vector<int> Network::block_start_; THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_; THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_; THREAD_LOCAL comm_size_t Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_; THREAD_LOCAL std::vector<char> Network::buffer_;
THREAD_LOCAL AllreduceFunction Network::allreduce_ext_fun_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL; THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL; THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL;
...@@ -31,8 +30,8 @@ void Network::Init(NetworkConfig config) { ...@@ -31,8 +30,8 @@ void Network::Init(NetworkConfig config) {
num_machines_ = linkers_->num_machines(); num_machines_ = linkers_->num_machines();
bruck_map_ = linkers_->bruck_map(); bruck_map_ = linkers_->bruck_map();
recursive_halving_map_ = linkers_->recursive_halving_map(); recursive_halving_map_ = linkers_->recursive_halving_map();
block_start_ = std::vector<int>(num_machines_); block_start_ = std::vector<comm_size_t>(num_machines_);
block_len_ = std::vector<int>(num_machines_); block_len_ = std::vector<comm_size_t>(num_machines_);
buffer_size_ = 1024 * 1024; buffer_size_ = 1024 * 1024;
buffer_.resize(buffer_size_); buffer_.resize(buffer_size_);
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_); Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
...@@ -45,42 +44,39 @@ void Network::Dispose() { ...@@ -45,42 +44,39 @@ void Network::Dispose() {
linkers_.reset(new Linkers()); linkers_.reset(new Linkers());
} }
void Network::Allreduce(char* input, int input_size, int type_size, char* output, const ReduceFunction& reducer) { void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (allreduce_ext_fun_ != NULL) { comm_size_t count = input_size / type_size;
return allreduce_ext_fun_(input, input_size, type_size, output, reducer);
}
int count = input_size / type_size;
// if small package or small count , do it by all gather.(reduce the communication times.) // if small package or small count , do it by all gather.(reduce the communication times.)
if (count < num_machines_ || input_size < 4096) { if (count < num_machines_ || input_size < 4096) {
AllreduceByAllGather(input, input_size, output, reducer); AllreduceByAllGather(input, input_size, type_size, output, reducer);
return; return;
} }
// assign the blocks to every rank. // assign the blocks to every rank.
int step = (count + num_machines_ - 1) / num_machines_; comm_size_t step = (count + num_machines_ - 1) / num_machines_;
if (step < 1) { if (step < 1) {
step = 1; step = 1;
} }
block_start_[0] = 0; block_start_[0] = 0;
for (int i = 0; i < num_machines_ - 1; ++i) { for (int i = 0; i < num_machines_ - 1; ++i) {
block_len_[i] = std::min(step * type_size, input_size - block_start_[i]); block_len_[i] = std::min<comm_size_t>(step * type_size, input_size - block_start_[i]);
block_start_[i + 1] = block_start_[i] + block_len_[i]; block_start_[i + 1] = block_start_[i] + block_len_[i];
} }
block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1]; block_len_[num_machines_ - 1] = input_size - block_start_[num_machines_ - 1];
// do reduce scatter // do reduce scatter
ReduceScatter(input, input_size, block_start_.data(), block_len_.data(), output, reducer); ReduceScatter(input, input_size, type_size, block_start_.data(), block_len_.data(), output, input_size, reducer);
// do all gather // do all gather
Allgather(output, input_size, block_start_.data(), block_len_.data(), output); Allgather(output, block_start_.data(), block_len_.data(), output, input_size);
} }
void Network::AllreduceByAllGather(char* input, int input_size, char* output, const ReduceFunction& reducer) { void Network::AllreduceByAllGather(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
// assign blocks // assign blocks
int all_size = input_size * num_machines_; comm_size_t all_size = input_size * num_machines_;
block_start_[0] = 0; block_start_[0] = 0;
block_len_[0] = input_size; block_len_[0] = input_size;
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
...@@ -93,15 +89,15 @@ void Network::AllreduceByAllGather(char* input, int input_size, char* output, co ...@@ -93,15 +89,15 @@ void Network::AllreduceByAllGather(char* input, int input_size, char* output, co
buffer_.resize(buffer_size_); buffer_.resize(buffer_size_);
} }
Allgather(input, all_size, block_start_.data(), block_len_.data(), buffer_.data()); Allgather(input, block_start_.data(), block_len_.data(), buffer_.data(), all_size);
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], input_size); reducer(buffer_.data() + block_start_[i], buffer_.data() + block_start_[0], type_size, input_size);
} }
// copy back // copy back
std::memcpy(output, buffer_.data(), input_size); std::memcpy(output, buffer_.data(), input_size);
} }
void Network::Allgather(char* input, int send_size, char* output) { void Network::Allgather(char* input, comm_size_t send_size, char* output) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
...@@ -114,17 +110,17 @@ void Network::Allgather(char* input, int send_size, char* output) { ...@@ -114,17 +110,17 @@ void Network::Allgather(char* input, int send_size, char* output) {
block_len_[i] = send_size; block_len_[i] = send_size;
} }
// start all gather // start all gather
Allgather(input, send_size * num_machines_, block_start_.data(), block_len_.data(), output); Allgather(input, block_start_.data(), block_len_.data(), output, send_size * num_machines_);
} }
void Network::Allgather(char* input, int all_size, const int* block_start, const int* block_len, char* output) { void Network::Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (allgather_ext_fun_ != NULL) { if (allgather_ext_fun_ != NULL) {
return allgather_ext_fun_(input, all_size, block_start, block_len, output); return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
} }
int write_pos = 0; comm_size_t write_pos = 0;
// use output as receive buffer // use output as receive buffer
std::memcpy(output, input, block_len[rank_]); std::memcpy(output, input, block_len[rank_]);
write_pos += block_len[rank_]; write_pos += block_len[rank_];
...@@ -137,9 +133,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -137,9 +133,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
// get in rank // get in rank
int in_rank = bruck_map_.in_ranks[i]; int in_rank = bruck_map_.in_ranks[i];
// get send information // get send information
int need_send_len = 0; comm_size_t need_send_len = 0;
// get recv information // get recv information
int need_recv_len = 0; comm_size_t need_recv_len = 0;
for (int j = 0; j < cur_block_size; ++j) { for (int j = 0; j < cur_block_size; ++j) {
need_send_len += block_len[(rank_ + j) % num_machines_]; need_send_len += block_len[(rank_ + j) % num_machines_];
need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_]; need_recv_len += block_len[(rank_ + accumulated_block + j) % num_machines_];
...@@ -155,40 +151,40 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -155,40 +151,40 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
std::reverse<char*>(output + block_start[rank_], output + all_size); std::reverse<char*>(output + block_start[rank_], output + all_size);
} }
void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) { void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (reduce_scatter_ext_fun_ != NULL) { if (reduce_scatter_ext_fun_ != NULL) {
return reduce_scatter_ext_fun_(input, input_size, block_start, block_len, output, reducer); return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
} }
if (recursive_halving_map_.need_pairwise) { if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_; int out_rank = (rank_ + i) % num_machines_;
int in_rank = (rank_ - i + num_machines_) % num_machines_; int in_rank = (rank_ - i + num_machines_) % num_machines_;
linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]); linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
reducer(output, input + block_start[rank_], block_len[rank_]); reducer(output, input + block_start[rank_], type_size, block_len[rank_]);
} }
} else { } else {
for (int i = 0; i < recursive_halving_map_.k; ++i) { for (int i = 0; i < recursive_halving_map_.k; ++i) {
// get target // get target
int target = recursive_halving_map_.ranks[i]; int target = recursive_halving_map_.ranks[i];
int send_block_start = recursive_halving_map_.send_block_start[i]; comm_size_t send_block_start = recursive_halving_map_.send_block_start[i];
int recv_block_start = recursive_halving_map_.recv_block_start[i]; comm_size_t recv_block_start = recursive_halving_map_.recv_block_start[i];
// get send information // get send information
int send_size = 0; comm_size_t send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) { for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += block_len[send_block_start + j]; send_size += block_len[send_block_start + j];
} }
// get recv information // get recv information
int need_recv_cnt = 0; comm_size_t need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) { for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += block_len[recv_block_start + j]; need_recv_cnt += block_len[recv_block_start + j];
} }
// send and recv at same time // send and recv at same time
linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt); linkers_->SendRecv(target, input + block_start[send_block_start], send_size, target, output, need_recv_cnt);
// reduce // reduce
reducer(output, input + block_start[recv_block_start], need_recv_cnt); reducer(output, input + block_start[recv_block_start], type_size, need_recv_cnt);
} }
} }
// copy result // copy result
......
...@@ -120,9 +120,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -120,9 +120,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int size = sizeof(data); int size = sizeof(data);
std::memcpy(input_buffer_.data(), &data, size); std::memcpy(input_buffer_.data(), &data, size);
// global sumup reduce // global sumup reduce
Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(), [](const char *src, char *dst, int len) { Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
int type_size = sizeof(std::tuple<data_size_t, double, double>);
const std::tuple<data_size_t, double, double> *p1; const std::tuple<data_size_t, double, double> *p1;
std::tuple<data_size_t, double, double> *p2; std::tuple<data_size_t, double, double> *p2;
while (used_size < len) { while (used_size < len) {
...@@ -157,8 +156,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -157,8 +156,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_histogram_array_[feature_index].SizeOfHistgram()); this->smaller_leaf_histogram_array_[feature_index].SizeOfHistgram());
} }
// Reduce scatter for histogram // Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(HistogramBinEntry), block_start_.data(),
block_len_.data(), output_buffer_.data(), &HistogramBinEntry::SumReducer); block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(this->is_feature_used_, true); this->FindBestSplitsFromHistograms(this->is_feature_used_, true);
} }
......
...@@ -79,15 +79,15 @@ private: ...@@ -79,15 +79,15 @@ private:
use this to mark local aggregate features*/ use this to mark local aggregate features*/
std::vector<bool> is_feature_aggregated_; std::vector<bool> is_feature_aggregated_;
/*! \brief Block start index for reduce scatter */ /*! \brief Block start index for reduce scatter */
std::vector<int> block_start_; std::vector<comm_size_t> block_start_;
/*! \brief Block size for reduce scatter */ /*! \brief Block size for reduce scatter */
std::vector<int> block_len_; std::vector<comm_size_t> block_len_;
/*! \brief Write positions for feature histograms */ /*! \brief Write positions for feature histograms */
std::vector<int> buffer_write_start_pos_; std::vector<comm_size_t> buffer_write_start_pos_;
/*! \brief Read positions for local feature histograms */ /*! \brief Read positions for local feature histograms */
std::vector<int> buffer_read_start_pos_; std::vector<comm_size_t> buffer_read_start_pos_;
/*! \brief Size for reduce scatter */ /*! \brief Size for reduce scatter */
int reduce_scatter_size_; comm_size_t reduce_scatter_size_;
/*! \brief Store global number of data in leaves */ /*! \brief Store global number of data in leaves */
std::vector<data_size_t> global_data_count_in_leaf_; std::vector<data_size_t> global_data_count_in_leaf_;
}; };
...@@ -155,15 +155,15 @@ private: ...@@ -155,15 +155,15 @@ private:
use this to mark local aggregate features*/ use this to mark local aggregate features*/
std::vector<bool> larger_is_feature_aggregated_; std::vector<bool> larger_is_feature_aggregated_;
/*! \brief Block start index for reduce scatter */ /*! \brief Block start index for reduce scatter */
std::vector<int> block_start_; std::vector<comm_size_t> block_start_;
/*! \brief Block size for reduce scatter */ /*! \brief Block size for reduce scatter */
std::vector<int> block_len_; std::vector<comm_size_t> block_len_;
/*! \brief Read positions for feature histgrams at smaller leaf */ /*! \brief Read positions for feature histgrams at smaller leaf */
std::vector<int> smaller_buffer_read_start_pos_; std::vector<comm_size_t> smaller_buffer_read_start_pos_;
/*! \brief Read positions for feature histgrams at larger leaf */ /*! \brief Read positions for feature histgrams at larger leaf */
std::vector<int> larger_buffer_read_start_pos_; std::vector<comm_size_t> larger_buffer_read_start_pos_;
/*! \brief Size for reduce scatter */ /*! \brief Size for reduce scatter */
int reduce_scatter_size_; comm_size_t reduce_scatter_size_;
/*! \brief Store global number of data in leaves */ /*! \brief Store global number of data in leaves */
std::vector<data_size_t> global_data_count_in_leaf_; std::vector<data_size_t> global_data_count_in_leaf_;
/*! \brief Store global split information for smaller leaf */ /*! \brief Store global split information for smaller leaf */
...@@ -187,8 +187,8 @@ inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, Spl ...@@ -187,8 +187,8 @@ inline void SyncUpGlobalBestSplit(char* input_buffer_, char* output_buffer_, Spl
smaller_best_split->CopyTo(input_buffer_); smaller_best_split->CopyTo(input_buffer_);
larger_best_split->CopyTo(input_buffer_ + size); larger_best_split->CopyTo(input_buffer_ + size);
Network::Allreduce(input_buffer_, size * 2, size, output_buffer_, Network::Allreduce(input_buffer_, size * 2, size, output_buffer_,
[&size] (const char* src, char* dst, int len) { [] (const char* src, char* dst, int size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
LightSplitInfo p1, p2; LightSplitInfo p1, p2;
while (used_size < len) { while (used_size < len) {
p1.CopyFrom(src); p1.CopyFrom(src);
......
...@@ -113,9 +113,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() { ...@@ -113,9 +113,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
int size = sizeof(std::tuple<data_size_t, double, double>); int size = sizeof(std::tuple<data_size_t, double, double>);
std::memcpy(input_buffer_.data(), &data, size); std::memcpy(input_buffer_.data(), &data, size);
Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(), [](const char *src, char *dst, int len) { Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
int used_size = 0; comm_size_t used_size = 0;
int type_size = sizeof(std::tuple<data_size_t, double, double>);
const std::tuple<data_size_t, double, double> *p1; const std::tuple<data_size_t, double, double> *p1;
std::tuple<data_size_t, double, double> *p2; std::tuple<data_size_t, double, double> *p2;
while (used_size < len) { while (used_size < len) {
...@@ -357,8 +356,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -357,8 +356,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
CopyLocalHistogram(smaller_top_features, larger_top_features); CopyLocalHistogram(smaller_top_features, larger_top_features);
// Reduce scatter for histogram // Reduce scatter for histogram
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, block_start_.data(), block_len_.data(), Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(HistogramBinEntry), block_start_.data(), block_len_.data(),
output_buffer_.data(), &HistogramBinEntry::SumReducer); output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramBinEntry::SumReducer);
this->FindBestSplitsFromHistograms(is_feature_used, false); this->FindBestSplitsFromHistograms(is_feature_used, false);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment