Commit 162509ae authored by Guolin Ke's avatar Guolin Ke
Browse files

fix network init with extern functions.

parent b65f6e65
......@@ -757,10 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank);
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
void* reduce_scatter_ext_fun,
void* allgather_ext_fun);
// exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
......
......@@ -73,6 +73,10 @@ public:
* \param config Config of network setting
*/
static void Init(NetworkConfig config);
/*!
* \brief Initialize
*/
static void Init(int num_machines, int rank, ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun);
/*! \brief Free this static class */
static void Dispose();
/*! \brief Get rank of this machine */
......@@ -188,11 +192,6 @@ public:
});
return global;
}
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }
private:
/*! \brief Number of all machines */
......
......@@ -1220,16 +1220,12 @@ int LGBM_NetworkFree() {
API_END();
}
int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank) {
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
void* reduce_scatter_ext_fun,
void* allgather_ext_fun) {
API_BEGIN();
if (num_machines > 1) {
Network::SetReduceScatterFunction((ReduceScatterFunction)reduce_scatter_fun_ptr);
Network::SetAllgatherFunction((AllgatherFunction)allgather_fun_ptr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
}
API_END();
}
......
......@@ -17,10 +17,10 @@ THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
THREAD_LOCAL std::vector<comm_size_t> Network::block_start_;
THREAD_LOCAL std::vector<comm_size_t> Network::block_len_;
THREAD_LOCAL comm_size_t Network::buffer_size_;
THREAD_LOCAL comm_size_t Network::buffer_size_ = 0;
THREAD_LOCAL std::vector<char> Network::buffer_;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = nullptr;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = nullptr;
void Network::Init(NetworkConfig config) {
......@@ -38,10 +38,27 @@ void Network::Init(NetworkConfig config) {
}
}
void Network::Init(int num_machines, int rank,
ReduceScatterFunction reduce_scatter_ext_fun, AllgatherFunction allgather_ext_fun) {
if (num_machines > 1) {
rank_ = rank;
num_machines_ = num_machines;
block_start_ = std::vector<comm_size_t>(num_machines_);
block_len_ = std::vector<comm_size_t>(num_machines_);
buffer_size_ = 1024 * 1024;
buffer_.resize(buffer_size_);
reduce_scatter_ext_fun_ = reduce_scatter_ext_fun;
allgather_ext_fun_ = allgather_ext_fun;
Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
}
}
void Network::Dispose() {
num_machines_ = 1;
rank_ = 0;
linkers_.reset(new Linkers());
reduce_scatter_ext_fun_ = nullptr;
allgather_ext_fun_ = nullptr;
}
void Network::Allreduce(char* input, comm_size_t input_size, int type_size, char* output, const ReduceFunction& reducer) {
......@@ -117,7 +134,7 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (allgather_ext_fun_ != NULL) {
if (allgather_ext_fun_ != nullptr) {
return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
}
comm_size_t write_pos = 0;
......@@ -155,7 +172,7 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size,
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (reduce_scatter_ext_fun_ != NULL) {
if (reduce_scatter_ext_fun_ != nullptr) {
return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
}
if (recursive_halving_map_.need_pairwise) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment