Commit 491dd019 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

thread_local for network interface (#982)

* thread local for the network interface.

* fix bug.
parent 5c0afab2
...@@ -735,12 +735,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle, ...@@ -735,12 +735,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFeatureImportance(BoosterHandle handle,
int importance_type, int importance_type,
double* out_results); double* out_results);
#if defined(_MSC_VER)
// exception handle and error msg // exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; } static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
#else
static char* LastErrorMsg() { static thread_local char err_msg[512] = "Everything is fine"; return err_msg; }
#endif
#pragma warning(disable : 4996) #pragma warning(disable : 4996)
inline void LGBM_SetLastError(const char* msg) { inline void LGBM_SetLastError(const char* msg) {
......
...@@ -191,23 +191,24 @@ public: ...@@ -191,23 +191,24 @@ public:
private: private:
/*! \brief Number of all machines */ /*! \brief Number of all machines */
static int num_machines_; static THREAD_LOCAL int num_machines_;
/*! \brief Rank of local machine */ /*! \brief Rank of local machine */
static int rank_; static THREAD_LOCAL int rank_;
/*! \brief The network interface, provide send/recv functions */ /*! \brief The network interface, provide send/recv functions */
static std::unique_ptr<Linkers> linkers_; static THREAD_LOCAL std::unique_ptr<Linkers> linkers_;
/*! \brief Bruck map for all gather algorithm*/ /*! \brief Bruck map for all gather algorithm*/
static BruckMap bruck_map_; static THREAD_LOCAL BruckMap bruck_map_;
/*! \brief Recursive halving map for reduce scatter */ /*! \brief Recursive halving map for reduce scatter */
static RecursiveHalvingMap recursive_halving_map_; static THREAD_LOCAL RecursiveHalvingMap recursive_halving_map_;
/*! \brief Buffer to store block start index */ /*! \brief Buffer to store block start index */
static std::vector<int> block_start_; static THREAD_LOCAL std::vector<int> block_start_;
/*! \brief Buffer to store block size */ /*! \brief Buffer to store block size */
static std::vector<int> block_len_; static THREAD_LOCAL std::vector<int> block_len_;
/*! \brief Buffer */ /*! \brief Buffer */
static std::vector<char> buffer_; static THREAD_LOCAL std::vector<char> buffer_;
/*! \brief Size of buffer_ */ /*! \brief Size of buffer_ */
static int buffer_size_; static THREAD_LOCAL int buffer_size_;
}; };
inline int Network::rank() { inline int Network::rank() {
......
...@@ -11,6 +11,11 @@ ...@@ -11,6 +11,11 @@
namespace LightGBM { namespace LightGBM {
#if defined(_MSC_VER)
#define THREAD_LOCAL __declspec(thread)
#else
#define THREAD_LOCAL thread_local
#endif
#ifndef CHECK #ifndef CHECK
#define CHECK(condition) \ #define CHECK(condition) \
...@@ -92,11 +97,7 @@ private: ...@@ -92,11 +97,7 @@ private:
// a trick to use static variable in header file. // a trick to use static variable in header file.
// May be not good, but avoid to use an additional cpp file // May be not good, but avoid to use an additional cpp file
#if defined(_MSC_VER) static LogLevel& GetLevel() { static THREAD_LOCAL LogLevel level = LogLevel::Info; return level; }
static LogLevel& GetLevel() { static __declspec(thread) LogLevel level = LogLevel::Info; return level; }
#else
static LogLevel& GetLevel() { static thread_local LogLevel level = LogLevel::Info; return level; }
#endif
}; };
......
...@@ -10,15 +10,15 @@ ...@@ -10,15 +10,15 @@
namespace LightGBM { namespace LightGBM {
// static member definition // static member definition
int Network::num_machines_; THREAD_LOCAL int Network::num_machines_;
int Network::rank_; THREAD_LOCAL int Network::rank_;
std::unique_ptr<Linkers> Network::linkers_; THREAD_LOCAL std::unique_ptr<Linkers> Network::linkers_;
BruckMap Network::bruck_map_; THREAD_LOCAL BruckMap Network::bruck_map_;
RecursiveHalvingMap Network::recursive_halving_map_; THREAD_LOCAL RecursiveHalvingMap Network::recursive_halving_map_;
std::vector<int> Network::block_start_; THREAD_LOCAL std::vector<int> Network::block_start_;
std::vector<int> Network::block_len_; THREAD_LOCAL std::vector<int> Network::block_len_;
int Network::buffer_size_; THREAD_LOCAL int Network::buffer_size_;
std::vector<char> Network::buffer_; THREAD_LOCAL std::vector<char> Network::buffer_;
void Network::Init(NetworkConfig config) { void Network::Init(NetworkConfig config) {
linkers_.reset(new Linkers(config)); linkers_.reset(new Linkers(config));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment