Commit 491dd019 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

thread_local for network interface (#982)

* thread local for the network interface.

* fix bug.
parent 5c0afab2
......@@ -735,12 +735,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int importance_type,
double* out_results);
#if defined(_MSC_VER)
// exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
#else
static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; }
#endif
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
#pragma warning(disable : 4996)
inline void LGBM_SetLastError(const char* msg) {
......
......@@ -191,23 +191,24 @@ public:
private:
/*! \brief Number of all machines */
static int num_machines_;
static THREAD_LOCAL int num_machines_;
/*! \brief Rank of local machine */
static int rank_;
static THREAD_LOCAL int rank_;
/*! \brief The network interface, provide send/recv functions */
static std::unique_ptr<Linkers> linkers_;
static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
/*! \brief Bruck map for all gather algorithm*/
static BruckMap bruck_map_;
static THREAD_LOCAL BruckMap bruck_map_;
/*! \brief Recursive halving map for reduce scatter */
static RecursiveHalvingMap recursive_halving_map_;
static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
/*! \brief Buffer to store block start index */
static std::vector<int> block_start_;
static THREAD_LOCAL std::vector<int> block_start_;
/*! \brief Buffer to store block size */
static std::vector<int> block_len_;
static THREAD_LOCAL std::vector<int> block_len_;
/*! \brief Buffer */
static std::vector<char> buffer_;
static THREAD_LOCAL std::vector<char> buffer_;
/*! \brief Size of buffer_ */
static int buffer_size_;
static THREAD_LOCAL int buffer_size_;
};
inline int Network::rank() {
......
......@@ -11,6 +11,11 @@
namespace LightGBM {
#if defined(_MSC_VER)
#define THREAD_LOCAL __declspec(thread)
#else
#define THREAD_LOCAL thread_local
#endif
#ifndef CHECK
#define CHECK(condition) \
......@@ -92,11 +97,7 @@ private:
// a trick to use static variable in header file.
// May be not good, but avoid to use an additional cpp file
#if defined(_MSC_VER)
static LogLevel& GetLevel() { static __declspec(thread) LogLevel level = LogLevel::Info; return level; }
#else
static LogLevel& GetLevel() { static thread_local LogLevel level = LogLevel::Info; return level; }
#endif
static LogLevel& GetLevel() { static THREAD_LOCAL LogLevel level = LogLevel::Info; return level; }
};
......
......@@ -10,15 +10,15 @@
namespace LightGBM {
// static member definition
int Network::num_machines_;
int Network::rank_;
std::unique_ptr<Linkers> Network::linkers_;
BruckMap Network::bruck_map_;
RecursiveHalvingMap Network::recursive_halving_map_;
std::vector<int> Network::block_start_;
std::vector<int> Network::block_len_;
int Network::buffer_size_;
std::vector<char> Network::buffer_;
THREAD_LOCAL int Network::num_machines_;
THREAD_LOCAL int Network::rank_;
THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_;
void Network::Init(NetworkConfig config) {
linkers_.reset(new Linkers(config));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment