Commit 72b54956 authored by Guolin Ke
Browse files

fix bug in LGBM_NetworkInitWithFunctions

parent 159e9a1e
...@@ -757,9 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines, ...@@ -757,9 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/ */
LIGHTGBM_C_EXPORT int LGBM_NetworkFree(); LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr, LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
void* ReduceScatterFuncPtr, void* reduce_scatter_fun_ptr,
void* AllgatherFuncPtr, void* allgather_fun_ptr,
int num_machines, int num_machines,
int rank); int rank);
......
...@@ -23,8 +23,6 @@ const double kZeroThreshold = 1e-35f; ...@@ -23,8 +23,6 @@ const double kZeroThreshold = 1e-35f;
using ReduceFunction = std::function<void(const char*, char*, int)>; using ReduceFunction = std::function<void(const char*, char*, int)>;
typedef void(*ReduceFunctionInC)(const char*, char*, int);
using PredictFunction = using PredictFunction =
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>; std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;
...@@ -32,7 +30,7 @@ using AllreduceFunction = std::function<void(char*, int, int, char*, const Reduc ...@@ -32,7 +30,7 @@ using AllreduceFunction = std::function<void(char*, int, int, char*, const Reduc
using ReduceScatterFunction = std::function<void(char*, int, const int*, const int*, char*, const ReduceFunction&)>; using ReduceScatterFunction = std::function<void(char*, int, const int*, const int*, char*, const ReduceFunction&)>;
using AllgatherFunction = std::function<void(char*, int, char*)>; using AllgatherFunction = std::function<void(char*, int, const int*, const int*, char*)>;
#define NO_SPECIFIC (-1) #define NO_SPECIFIC (-1)
......
...@@ -191,9 +191,9 @@ public: ...@@ -191,9 +191,9 @@ public:
/*! \brief set variables and function ptrs */ /*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;} static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; } static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduceFunction(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;} static void SetAllReduceFunction(AllreduceFunction allreduce_ext_fun) { allreduce_ext_fun_ = allreduce_ext_fun;}
static void SetReduceScatterFunction(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; } static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; } static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
private: private:
/*! \brief Number of all machines */ /*! \brief Number of all machines */
...@@ -215,9 +215,9 @@ private: ...@@ -215,9 +215,9 @@ private:
/*! \brief Size of buffer_ */ /*! \brief Size of buffer_ */
static THREAD_LOCAL int buffer_size_; static THREAD_LOCAL int buffer_size_;
/*! \brief Funcs*/ /*! \brief Funcs*/
static THREAD_LOCAL AllreduceFunction AllreduceFuncPtr_; static THREAD_LOCAL AllreduceFunction allreduce_ext_fun_;
static THREAD_LOCAL ReduceScatterFunction ReduceScatterFuncPtr_; static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
static THREAD_LOCAL AllgatherFunction AllgatherFuncPtr_; static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
}; };
inline int Network::rank() { inline int Network::rank() {
......
...@@ -1220,26 +1220,27 @@ int LGBM_NetworkFree() { ...@@ -1220,26 +1220,27 @@ int LGBM_NetworkFree() {
API_END(); API_END();
} }
int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr, int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
void* ReduceScatterFuncPtr, void* reduce_scatter_fun_ptr,
void* AllgatherFuncPtr, void* allgather_fun_ptr,
int num_machines, int num_machines,
int rank) { int rank) {
API_BEGIN(); API_BEGIN();
typedef void(*ReduceFunctionPtr)(const char* input, char* output, int array_size);
if (num_machines > 1) { if (num_machines > 1) {
auto allreduce_fun = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) { auto allreduce_fun = [allreduce_fun_ptr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& reduce_fun) {
auto ptr = *func.target<ReduceFunctionInC>(); auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr; auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionPtr&))allreduce_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, ptr); return tmp(arg1, arg2, arg3, arg4, reduce_fun_ptr);
}; };
Network::SetAllReduceFunction(allreduce_fun); Network::SetAllReduceFunction(allreduce_fun);
auto reduce_scatter_fun = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) { auto reduce_scatter_fun = [reduce_scatter_fun_ptr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& reduce_fun) {
auto ptr = *func.target<ReduceFunctionInC>(); auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr; auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionPtr&))reduce_scatter_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr); return tmp(arg1, arg2, arg3, arg4, arg5, reduce_fun_ptr);
}; };
Network::SetReduceScatterFunction(reduce_scatter_fun); Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgatherFunction((void(*)(char*, int, char*))AllgatherFuncPtr); Network::SetAllgatherFunction((void(*)(char*, int, const int*, const int*, char*))allgather_fun_ptr);
Network::SetNumMachines(num_machines); Network::SetNumMachines(num_machines);
Network::SetRank(rank); Network::SetRank(rank);
} }
......
...@@ -19,9 +19,9 @@ THREAD_LOCAL std::vector<int> Network::block_start_; ...@@ -19,9 +19,9 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_; THREAD_LOCAL std::vector<int> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_; THREAD_LOCAL int Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_; THREAD_LOCAL std::vector<char> Network::buffer_;
THREAD_LOCAL AllreduceFunction Network::AllreduceFuncPtr_ = NULL; THREAD_LOCAL AllreduceFunction Network::allreduce_ext_fun_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::ReduceScatterFuncPtr_ = NULL; THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL;
THREAD_LOCAL AllgatherFunction Network::AllgatherFuncPtr_ = NULL; THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL;
void Network::Init(NetworkConfig config) { void Network::Init(NetworkConfig config) {
...@@ -49,8 +49,8 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output ...@@ -49,8 +49,8 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (AllreduceFuncPtr_ != NULL) { if (allreduce_ext_fun_ != NULL) {
return AllreduceFuncPtr_(input, input_size, type_size, output, reducer); return allreduce_ext_fun_(input, input_size, type_size, output, reducer);
} }
int count = input_size / type_size; int count = input_size / type_size;
// if small package or small count , do it by all gather.(reduce the communication times.) // if small package or small count , do it by all gather.(reduce the communication times.)
...@@ -106,9 +106,6 @@ void Network::Allgather(char* input, int send_size, char* output) { ...@@ -106,9 +106,6 @@ void Network::Allgather(char* input, int send_size, char* output) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (num_machines_ <= 1) { return; } if (num_machines_ <= 1) { return; }
if (AllgatherFuncPtr_ != NULL) {
return AllgatherFuncPtr_(input, send_size, output);
}
// assign blocks // assign blocks
block_start_[0] = 0; block_start_[0] = 0;
block_len_[0] = send_size; block_len_[0] = send_size;
...@@ -124,6 +121,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const ...@@ -124,6 +121,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (allgather_ext_fun_ != NULL) {
return allgather_ext_fun_(input, all_size, block_start, block_len, output);
}
int write_pos = 0; int write_pos = 0;
// use output as receive buffer // use output as receive buffer
std::memcpy(output, input, block_len[rank_]); std::memcpy(output, input, block_len[rank_]);
...@@ -159,8 +159,8 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start, ...@@ -159,8 +159,8 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (ReduceScatterFuncPtr_ != NULL) { if (reduce_scatter_ext_fun_ != NULL) {
return ReduceScatterFuncPtr_(input, input_size, block_start, block_len, output, reducer); return reduce_scatter_ext_fun_(input, input_size, block_start, block_len, output, reducer);
} }
if (recursive_halving_map_.need_pairwise) { if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) { for (int i = 1; i < num_machines_; ++i) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment