Commit 159e9a1e authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for network functions

parent 0a7a4080
......@@ -757,11 +757,11 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank);
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank);
// exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
......
......@@ -191,9 +191,9 @@ public:
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduce(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatter(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgather(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }
static void SetAllReduceFunction(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatterFunction(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgatherFunction(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }
private:
/*! \brief Number of all machines */
......
......@@ -1220,31 +1220,30 @@ int LGBM_NetworkFree() {
API_END();
}
int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank) {
int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank) {
API_BEGIN();
if(num_machines > 1) {
auto func1 = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
if (num_machines > 1) {
auto allreduce_fun = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
return tmp(arg1, arg2, arg3, arg4, ptr);
};
Network::SetAllReduce(func1);
auto func2 = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
Network::SetAllReduceFunction(allreduce_fun);
auto reduce_scatter_fun = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
};
Network::SetReduceScatter(func2);
Network::SetAllgather((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgatherFunction((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
}
API_END();
}
// ---- start of some help functions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment