"docker/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "94fbe5bb9fbb5f6067821f0be9cd145f0b7d7d94"
Commit 159e9a1e authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for network functions

parent 0a7a4080
...@@ -757,7 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines, ...@@ -757,7 +757,7 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/ */
LIGHTGBM_C_EXPORT int LGBM_NetworkFree(); LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_GetFuncions(void* AllreduceFuncPtr, LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr, void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr, void* AllgatherFuncPtr,
int num_machines, int num_machines,
......
...@@ -191,9 +191,9 @@ public: ...@@ -191,9 +191,9 @@ public:
/*! \brief set variables and function ptrs */ /*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;} static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; } static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduce(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;} static void SetAllReduceFunction(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatter(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; } static void SetReduceScatterFunction(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgather(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; } static void SetAllgatherFunction(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }
private: private:
/*! \brief Number of all machines */ /*! \brief Number of all machines */
......
...@@ -1220,31 +1220,30 @@ int LGBM_NetworkFree() { ...@@ -1220,31 +1220,30 @@ int LGBM_NetworkFree() {
API_END(); API_END();
} }
int LGBM_GetFuncions(void* AllreduceFuncPtr, int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr, void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr, void* AllgatherFuncPtr,
int num_machines, int num_machines,
int rank) { int rank) {
API_BEGIN(); API_BEGIN();
if(num_machines > 1) { if (num_machines > 1) {
auto func1 = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) { auto allreduce_fun = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>(); auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr; auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
return tmp(arg1, arg2, arg3, arg4, ptr); return tmp(arg1, arg2, arg3, arg4, ptr);
}; };
Network::SetAllReduce(func1); Network::SetAllReduceFunction(allreduce_fun);
auto func2 = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) { auto reduce_scatter_fun = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>(); auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr; auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr); return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
}; };
Network::SetReduceScatter(func2); Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgather((void(*)(char*, int, char*))AllgatherFuncPtr); Network::SetAllgatherFunction((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetNumMachines(num_machines); Network::SetNumMachines(num_machines);
Network::SetRank(rank); Network::SetRank(rank);
} }
API_END(); API_END();
} }
// ---- start of some help functions // ---- start of some help functions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment