"src/vscode:/vscode.git/clone" did not exist on "b7de71df5315dee2e10bffc2977c629a22aee08f"
Commit e951a3d7 authored by ww's avatar ww Committed by Guolin Ke
Browse files

Network interface with c_api (#1067)

parent 38b65e5f
#ifndef LIGHTGBM_C_API_H_
#define LIGHTGBM_C_API_H_
#include <LightGBM/meta.h>
#include <cstdint>
#include <exception>
#include <stdexcept>
......@@ -754,6 +757,11 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
/*!
* \brief Hand externally implemented collective-communication routines to the library.
*        In the implementation nothing is changed unless num_machines > 1; in that case
*        the three pointers are cast to the signatures listed below and stored through
*        the Network setters, followed by num_machines and rank.
* \param AllreduceFuncPtr Function pointer, cast to
*        void(*)(char*, int, int, char*, const ReduceFunctionInC&)
* \param ReduceScatterFuncPtr Function pointer, cast to
*        void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&)
* \param AllgatherFuncPtr Function pointer, cast to void(*)(char*, int, char*)
* \param num_machines Number of machines, forwarded to Network::SetNumMachines
* \param rank Rank of the calling machine, forwarded to Network::SetRank
* \return Status code produced by the API_BEGIN/API_END wrappers of the implementation
*         (presumably 0 on success, as for the other C API calls - confirm)
*/
LIGHTGBM_C_EXPORT int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank);
// exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
......
......@@ -23,9 +23,17 @@ const double kZeroAsMissingValueRange = 1e-20f;
/*! \brief Reducer callable taking (const char*, char*, int) */
using ReduceFunction = std::function<void(const char*, char*, int)>;

/*! \brief Plain function-pointer form of a reducer, same parameter list as ReduceFunction */
using ReduceFunctionInC = void(*)(const char*, char*, int);

/*! \brief Prediction callable: receives (int, double) pairs and a double* to write into */
using PredictFunction =
  std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;

/*! \brief Replacement for Network::Allreduce (input, input_size, type_size, output, reducer) */
using AllreduceFunction =
  std::function<void(char* input, int input_size, int type_size, char* output,
                     const ReduceFunction& reducer)>;

/*! \brief Replacement for Network::ReduceScatter (input, input_size, block_start, block_len, output, reducer) */
using ReduceScatterFunction =
  std::function<void(char* input, int input_size, const int* block_start,
                     const int* block_len, char* output, const ReduceFunction& reducer)>;

/*! \brief Replacement for Network::Allgather (input, send_size, output) */
using AllgatherFunction = std::function<void(char* input, int send_size, char* output)>;
#define NO_SPECIFIC (-1)
#if (_MSC_VER <= 1800)
......
......@@ -188,6 +188,12 @@ public:
});
return global;
}
/*! \brief Store the rank of this machine */
static void SetRank(int rank) {
  rank_ = rank;
}

/*! \brief Store the total number of machines */
static void SetNumMachines(int num_machines) {
  num_machines_ = num_machines;
}

/*! \brief Store the callable consulted by Allreduce */
static void SetAllReduce(AllreduceFunction AllreduceFuncPtr) {
  AllreduceFuncPtr_ = AllreduceFuncPtr;
}

/*! \brief Store the callable consulted by ReduceScatter */
static void SetReduceScatter(ReduceScatterFunction ReduceScatterFuncPtr) {
  ReduceScatterFuncPtr_ = ReduceScatterFuncPtr;
}

/*! \brief Store the callable consulted by Allgather */
static void SetAllgather(AllgatherFunction AllgatherFuncPtr) {
  AllgatherFuncPtr_ = AllgatherFuncPtr;
}
private:
/*! \brief Number of all machines */
......@@ -208,7 +214,10 @@ private:
static THREAD_LOCAL std::vector<char> buffer_;
/*! \brief Size of buffer_ */
static THREAD_LOCAL int buffer_size_;
/*!
* \brief Optional replacement collectives, assigned through SetAllReduce,
*        SetReduceScatter and SetAllgather. Allreduce, ReduceScatter and Allgather
*        each test their member against NULL and, when it is set, forward the call
*        to it instead of running the built-in implementation.
*/
static THREAD_LOCAL AllreduceFunction AllreduceFuncPtr_;
static THREAD_LOCAL ReduceScatterFunction ReduceScatterFuncPtr_;
static THREAD_LOCAL AllgatherFunction AllgatherFuncPtr_;
};
inline int Network::rank() {
......
......@@ -1220,6 +1220,32 @@ int LGBM_NetworkFree() {
API_END();
}
int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank) {
API_BEGIN();
if(num_machines > 1) {
auto func1 = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
return tmp(arg1, arg2, arg3, arg4, ptr);
};
Network::SetAllReduce(func1);
auto func2 = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
};
Network::SetReduceScatter(func2);
Network::SetAllgather((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
}
API_END();
}
// ---- start of some help functions
std::function<std::vector<double>(int row_idx)>
......
......@@ -19,6 +19,10 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_;
// External collective callbacks start out empty; Allreduce, Allgather and
// ReduceScatter only forward to them once they have been set.
THREAD_LOCAL AllreduceFunction Network::AllreduceFuncPtr_ = nullptr;
THREAD_LOCAL ReduceScatterFunction Network::ReduceScatterFuncPtr_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::AllgatherFuncPtr_ = nullptr;
void Network::Init(NetworkConfig config) {
if (config.num_machines > 1) {
......@@ -45,6 +49,9 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (AllreduceFuncPtr_ != NULL) {
return AllreduceFuncPtr_(input, input_size, type_size, output, reducer);
}
int count = input_size / type_size;
// if small package or small count , do it by all gather.(reduce the communication times.)
if (count < num_machines_ || input_size < 4096) {
......@@ -99,6 +106,9 @@ void Network::Allgather(char* input, int send_size, char* output) {
Log::Fatal("Please initilize the network interface first");
}
if (num_machines_ <= 1) { return; }
if (AllgatherFuncPtr_ != NULL) {
return AllgatherFuncPtr_(input, send_size, output);
}
// assign blocks
block_start_[0] = 0;
block_len_[0] = send_size;
......@@ -145,10 +155,13 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
std::reverse<char*>(output + block_start[rank_], output + all_size);
}
void Network::ReduceScatter(char* input, int, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
void Network::ReduceScatter(char* input, int input_size, const int* block_start, const int* block_len, char* output, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (ReduceScatterFuncPtr_ != NULL) {
return ReduceScatterFuncPtr_(input, input_size, block_start, block_len, output, reducer);
}
if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment