Commit 162509ae authored by Guolin Ke's avatar Guolin Ke
Browse files

fix network init with extern functions.

parent b65f6e65
...@@ -757,10 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines, ...@@ -757,10 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/ */
LIGHTGBM_C_EXPORT int LGBM_NetworkFree(); LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr, LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
void* allgather_fun_ptr, void* reduce_scatter_ext_fun,
int num_machines, void* allgather_ext_fun);
int rank);
// exception handle and error msg // exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
......
...@@ -73,6 +73,10 @@ public: ...@@ -73,6 +73,10 @@ public:
* \param config Config of network setting * \param config Config of network setting
*/ */
static void Init(NetworkConfig config); static void Init(NetworkConfig config);
/*!
* \brief Initialize
*/
static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
/*! \brief Free this static class */ /*! \brief Free this static class */
static void Dispose(); static void Dispose();
/*! \brief Get rank of this machine */ /*! \brief Get rank of this machine */
...@@ -188,11 +192,6 @@ public: ...@@ -188,11 +192,6 @@ public:
}); });
return global; return global;
} }
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
private: private:
/*! \brief Number of all machines */ /*! \brief Number of all machines */
......
...@@ -1220,16 +1220,12 @@ int LGBM_NetworkFree() { ...@@ -1220,16 +1220,12 @@ int LGBM_NetworkFree() {
API_END(); API_END();
} }
int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr, int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
void* allgather_fun_ptr, void* reduce_scatter_ext_fun,
int num_machines, void* allgather_ext_fun) {
int rank) {
API_BEGIN(); API_BEGIN();
if (num_machines > 1) { if (num_machines > 1) {
Network::SetReduceScatterFunction((ReduceScatterFunction)reduce_scatter_fun_ptr); Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
Network::SetAllgatherFunction((AllgatherFunction)allgather_fun_ptr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
} }
API_END(); API_END();
} }
......
...@@ -17,10 +17,10 @@ THREAD_LOCAL BruckMap Network::bruck_map_; ...@@ -17,10 +17,10 @@ THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_; THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
THREAD_LOCAL std::vector<comm_size_t> Network::block_start_; THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
THREAD_LOCAL std::vector<comm_size_t> Network::block_len_; THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
THREAD_LOCAL comm_size_t Network::buffer_size_; THREAD_LOCAL comm_size_t Network::buffer_size_ = 0;
THREAD_LOCAL std::vector<char> Network::buffer_; THREAD_LOCAL std::vector<char> Network::buffer_;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL; THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL; THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr;
void Network::Init(NetworkConfig config) { void Network::Init(NetworkConfig config) {
...@@ -38,10 +38,27 @@ void Network::Init(NetworkConfig config) { ...@@ -38,10 +38,27 @@ void Network::Init(NetworkConfig config) {
} }
} }
void Network::Init(int num_machines, int rank,
ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun) {
if (num_machines > 1) {
rank_ = rank;
num_machines_ = num_machines;
block_start_ = std::vector<comm_size_t>(num_machines_);
block_len_ = std::vector<comm_size_t>(num_machines_);
buffer_size_ = 1024 * 1024;
buffer_.resize(buffer_size_);
reduce_scatter_ext_fun_ = reduce_scatter_ext_fun;
allgather_ext_fun_ = allgather_ext_fun;
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
}
}
void Network::Dispose() { void Network::Dispose() {
num_machines_ = 1; num_machines_ = 1;
rank_ = 0; rank_ = 0;
linkers_.reset(new Linkers()); linkers_.reset(new Linkers());
reduce_scatter_ext_fun_ = nullptr;
allgather_ext_fun_ = nullptr;
} }
void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) { void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
...@@ -117,7 +134,7 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_ ...@@ -117,7 +134,7 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (allgather_ext_fun_ != NULL) { if (allgather_ext_fun_ != nullptr) {
return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size); return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
} }
comm_size_t write_pos = 0; comm_size_t write_pos = 0;
...@@ -155,7 +172,7 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, ...@@ -155,7 +172,7 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size,
if (num_machines_ <= 1) { if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first"); Log::Fatal("Please initilize the network interface first");
} }
if (reduce_scatter_ext_fun_ != NULL) { if (reduce_scatter_ext_fun_ != nullptr) {
return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer); return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
} }
if (recursive_halving_map_.need_pairwise) { if (recursive_halving_map_.need_pairwise) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment