Unverified commit 4c5d0fbb, authored by Scott Votaw, committed by GitHub
Browse files

Fix OpenMP thread allocation in Linux (#5551)

parent 51efd901
...@@ -261,8 +261,9 @@ class Bin { ...@@ -261,8 +261,9 @@ class Bin {
/*! /*!
* \brief Initialize for pushing. By default, no action needed. * \brief Initialize for pushing. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs * \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads (per external thread) to allocate push buffers for
*/ */
virtual void InitStreaming(uint32_t /*num_thread*/) { } virtual void InitStreaming(uint32_t /*num_thread*/, int32_t /*omp_max_threads*/) { }
/*! /*!
* \brief Push one record * \brief Push one record
* \param tid Thread id * \param tid Thread id
......
...@@ -153,6 +153,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc ...@@ -153,6 +153,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle referenc
* \param has_queries Whether the dataset has Metadata queries/groups * \param has_queries Whether the dataset has Metadata queries/groups
* \param nclasses Number of initial score classes * \param nclasses Number of initial score classes
* \param nthreads Number of external threads that will use the PushRows APIs * \param nthreads Number of external threads that will use the PushRows APIs
* \param omp_max_threads Maximum number of OpenMP threads (-1 for default)
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset, LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
...@@ -160,7 +161,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset, ...@@ -160,7 +161,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_init_scores, int32_t has_init_scores,
int32_t has_queries, int32_t has_queries,
int32_t nclasses, int32_t nclasses,
int32_t nthreads); int32_t nthreads,
int32_t omp_max_threads);
/*! /*!
* \brief Push data to existing dataset, if ``nrow + start_row == num_total_row``, will call ``dataset->FinishLoad``. * \brief Push data to existing dataset, if ``nrow + start_row == num_total_row``, will call ``dataset->FinishLoad``.
......
...@@ -458,10 +458,18 @@ class Dataset { ...@@ -458,10 +458,18 @@ class Dataset {
int32_t has_init_scores, int32_t has_init_scores,
int32_t has_queries, int32_t has_queries,
int32_t nclasses, int32_t nclasses,
int32_t nthreads) { int32_t nthreads,
int32_t omp_max_threads) {
// Initialize optional max thread count with either parameter or OMP setting
if (omp_max_threads > 0) {
omp_max_threads_ = omp_max_threads;
} else if (omp_max_threads_ <= 0) {
omp_max_threads_ = OMP_NUM_THREADS();
}
metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses); metadata_.Init(num_data, has_weights, has_init_scores, has_queries, nclasses);
for (int i = 0; i < num_groups_; ++i) { for (int i = 0; i < num_groups_; ++i) {
feature_groups_[i]->InitStreaming(nthreads); feature_groups_[i]->InitStreaming(nthreads, omp_max_threads_);
} }
} }
...@@ -846,6 +854,9 @@ class Dataset { ...@@ -846,6 +854,9 @@ class Dataset {
/*! \brief Get whether FinishLoad is automatically called when pushing last row. */ /*! \brief Get whether FinishLoad is automatically called when pushing last row. */
inline bool wait_for_manual_finish() const { return wait_for_manual_finish_; } inline bool wait_for_manual_finish() const { return wait_for_manual_finish_; }
/*! \brief Get the maximum number of OpenMP threads (per external thread) that streaming push buffers are allocated for. */
inline int omp_max_threads() const { return omp_max_threads_; }
/*! \brief Set whether the Dataset is finished automatically when last row is pushed or with a manual /*! \brief Set whether the Dataset is finished automatically when last row is pushed or with a manual
* MarkFinished API call. Set to true for thread-safe streaming and/or if will be coalesced later. * MarkFinished API call. Set to true for thread-safe streaming and/or if will be coalesced later.
* FinishLoad should not be called on any Dataset that will be coalesced. * FinishLoad should not be called on any Dataset that will be coalesced.
...@@ -947,6 +958,7 @@ class Dataset { ...@@ -947,6 +958,7 @@ class Dataset {
std::vector<int> feature_need_push_zeros_; std::vector<int> feature_need_push_zeros_;
std::vector<std::vector<float>> raw_data_; std::vector<std::vector<float>> raw_data_;
bool wait_for_manual_finish_; bool wait_for_manual_finish_;
int omp_max_threads_ = -1;
bool has_raw_; bool has_raw_;
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */ /*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_; std::vector<int> numeric_feature_map_;
......
...@@ -192,14 +192,15 @@ class FeatureGroup { ...@@ -192,14 +192,15 @@ class FeatureGroup {
/*! /*!
* \brief Initialize for pushing in a streaming fashion. By default, no action needed. * \brief Initialize for pushing in a streaming fashion. By default, no action needed.
* \param num_thread The number of external threads that will be calling the push APIs * \param num_thread The number of external threads that will be calling the push APIs
* \param omp_max_threads The maximum number of OpenMP threads (per external thread) to allocate push buffers for
*/ */
void InitStreaming(int32_t num_thread) { void InitStreaming(int32_t num_thread, int32_t omp_max_threads) {
if (is_multi_val_) { if (is_multi_val_) {
for (int i = 0; i < num_feature_; ++i) { for (int i = 0; i < num_feature_; ++i) {
multi_bin_data_[i]->InitStreaming(num_thread); multi_bin_data_[i]->InitStreaming(num_thread, omp_max_threads);
} }
} else { } else {
bin_data_->InitStreaming(num_thread); bin_data_->InitStreaming(num_thread, omp_max_threads);
} }
} }
......
...@@ -1018,11 +1018,12 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset, ...@@ -1018,11 +1018,12 @@ int LGBM_DatasetInitStreaming(DatasetHandle dataset,
int32_t has_init_scores, int32_t has_init_scores,
int32_t has_queries, int32_t has_queries,
int32_t nclasses, int32_t nclasses,
int32_t nthreads) { int32_t nthreads,
int32_t omp_max_threads) {
API_BEGIN(); API_BEGIN();
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto num_data = p_dataset->num_data(); auto num_data = p_dataset->num_data();
p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads); p_dataset->InitStreaming(num_data, has_weights, has_init_scores, has_queries, nclasses, nthreads, omp_max_threads);
p_dataset->set_wait_for_manual_finish(true); p_dataset->set_wait_for_manual_finish(true);
API_END(); API_END();
} }
...@@ -1073,19 +1074,20 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset, ...@@ -1073,19 +1074,20 @@ int LGBM_DatasetPushRowsWithMetadata(DatasetHandle dataset,
if (!data) { if (!data) {
Log::Fatal("data cannot be null."); Log::Fatal("data cannot be null.");
} }
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
if (p_dataset->has_raw()) { if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
} }
const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id // convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid); const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, start_row + i, one_row); p_dataset->PushOneRow(internal_tid, start_row + i, one_row);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
...@@ -1154,19 +1156,21 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset, ...@@ -1154,19 +1156,21 @@ int LGBM_DatasetPushRowsByCSRWithMetadata(DatasetHandle dataset,
if (!data) { if (!data) {
Log::Fatal("data cannot be null."); Log::Fatal("data cannot be null.");
} }
const int num_omp_threads = OMP_NUM_THREADS();
auto p_dataset = reinterpret_cast<Dataset*>(dataset); auto p_dataset = reinterpret_cast<Dataset*>(dataset);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int32_t nrow = static_cast<int32_t>(nindptr - 1); int32_t nrow = static_cast<int32_t>(nindptr - 1);
if (p_dataset->has_raw()) { if (p_dataset->has_raw()) {
p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow); p_dataset->ResizeRaw(p_dataset->num_numeric_features() + nrow);
} }
const int max_omp_threads = p_dataset->omp_max_threads() > 0 ? p_dataset->omp_max_threads() : OMP_NUM_THREADS();
OMP_INIT_EX(); OMP_INIT_EX();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) { for (int i = 0; i < nrow; ++i) {
OMP_LOOP_EX_BEGIN(); OMP_LOOP_EX_BEGIN();
// convert internal thread id to be unique based on external thread id // convert internal thread id to be unique based on external thread id
const int internal_tid = omp_get_thread_num() + (num_omp_threads * tid); const int internal_tid = omp_get_thread_num() + (max_omp_threads * tid);
auto one_row = get_row_fun(i); auto one_row = get_row_fun(i);
p_dataset->PushOneRow(internal_tid, static_cast<data_size_t>(start_row + i), one_row); p_dataset->PushOneRow(internal_tid, static_cast<data_size_t>(start_row + i), one_row);
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
......
...@@ -81,10 +81,10 @@ class SparseBin : public Bin { ...@@ -81,10 +81,10 @@ class SparseBin : public Bin {
~SparseBin() {} ~SparseBin() {}
void InitStreaming(uint32_t num_thread) override { void InitStreaming(uint32_t num_thread, int32_t omp_max_threads) override {
// Each thread needs its own push buffer, so allocate external num_thread times the number of OMP threads // Each external thread needs its own set of OpenMP push buffers,
int num_omp_threads = OMP_NUM_THREADS(); // so allocate num_thread times the maximum number of OMP threads per external thread
push_buffers_.resize(num_omp_threads * num_thread); push_buffers_.resize(omp_max_threads * num_thread);
}; };
void ReSize(data_size_t num_data) override { num_data_ = num_data; } void ReSize(data_size_t num_data) override { num_data_ = num_data; }
......
...@@ -79,7 +79,7 @@ void test_stream_dense( ...@@ -79,7 +79,7 @@ void test_stream_dense(
&dataset_handle); &dataset_handle);
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result; EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;
result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1); result = LGBM_DatasetInitStreaming(dataset_handle, has_weights, has_init_scores, has_queries, nclasses, 1, -1);
EXPECT_EQ(0, result) << "LGBM_DatasetInitStreaming result code: " << result; EXPECT_EQ(0, result) << "LGBM_DatasetInitStreaming result code: " << result;
break; break;
} }
...@@ -197,7 +197,7 @@ void test_stream_sparse( ...@@ -197,7 +197,7 @@ void test_stream_sparse(
EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result; EXPECT_EQ(0, result) << "LGBM_DatasetCreateFromSampledColumn result code: " << result;
dataset = static_cast<Dataset*>(dataset_handle); dataset = static_cast<Dataset*>(dataset_handle);
dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2); dataset->InitStreaming(nrows, has_weights, has_init_scores, has_queries, nclasses, 2, -1);
break; break;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment