Unverified Commit 1c35c3b9 authored by Joan Fontanals's avatar Joan Fontanals Committed by GitHub
Browse files

Change locking strategy of Booster, allow for share and unique locks (#2760)



* Add capability to get possible max and min values for a model

* Change implementation to have return value in tree.cpp, change naming to upper and lower bound, move implementation to gbdt.cpp

* Update include/LightGBM/c_api.h
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Change iteration to avoid potential overflow, add bindings to R and Python and a basic test

* Adjust test values

* Consider const correctness and multithreading protection

* Put everything possible as const

* Include shared_mutex, for now as unique_lock

* Update test values

* Put everything possible as const

* Include shared_mutex, for now as unique_lock

* Make PredictSingleRow const and share the lock with other reading threads

* Update test values

* Add test to check that model is exactly the same in all platforms

* Try to parse the model to get the expected values

* Try to parse the model to get the expected values

* Fix implementation, num_leaves can be lower than the leaf_value_ size

* Do not check for num_leaves to be smaller than actual size and get back to test with hardcoded value

* Change test order

* Add gpu_use_dp option in test

* Remove helper test method

* Remove TODO

* Add preprocessing option to compile with c++17

* Update python-package/setup.py
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Remove unwanted changes

* Move option

* Fix problems introduced by conflict fix

* Avoid switching to c++17 and use yamc mutex library to access shared lock functionality

* Add extra yamc include

* Change header order

* some lint fix

* change include order and remove some extra blank lines

* Further fix lint issues

* Update c_api.cpp

* Further fix lint issues

* Move yamc include files to a new yamc folder

* Use standard unique_lock

* Update windows/LightGBM.vcxproj
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Fix problems coming from merge conflict resolution
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarjoanfontanals <jfontanals@ntent.com>
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
parent f5f27ca8
/*
* alternate_shared_mutex.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_ALTERNATE_SHARED_MUTEX_HPP_
#define YAMC_ALTERNATE_SHARED_MUTEX_HPP_
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include "yamc_rwlock_sched.hpp"
namespace yamc {
/*
* alternate implementation of shared mutex variants
*
* - yamc::alternate::shared_mutex
* - yamc::alternate::shared_timed_mutex
* - yamc::alternate::basic_shared_mutex<RwLockPolicy>
* - yamc::alternate::basic_shared_timed_mutex<RwLockPolicy>
*/
namespace alternate {
namespace detail {
template <typename RwLockPolicy>
class shared_mutex_base {
 protected:
  typename RwLockPolicy::state state_;  // reader/writer bookkeeping, guarded by mtx_
  std::condition_variable cv_;          // signalled on every lock-state change
  std::mutex mtx_;                      // protects state_

  /// Blocks until the exclusive (writer) lock is acquired.
  void lock() {
    std::unique_lock<decltype(mtx_)> lk(mtx_);
    RwLockPolicy::before_wait_wlock(state_);
    while (RwLockPolicy::wait_wlock(state_)) {
      cv_.wait(lk);
    }
    RwLockPolicy::after_wait_wlock(state_);
    RwLockPolicy::acquire_wlock(&state_);
  }

  /// Non-blocking attempt at the exclusive lock; returns true on success.
  bool try_lock() {
    std::lock_guard<decltype(mtx_)> lk(mtx_);
    if (RwLockPolicy::wait_wlock(state_)) return false;
    // FIX: acquire_wlock() takes state* (as in lock() above); passing
    // state_ by value does not match the policy signature and fails to
    // compile once this member is instantiated.
    RwLockPolicy::acquire_wlock(&state_);
    return true;
  }

  /// Releases the exclusive lock and wakes all waiters.
  void unlock() {
    std::lock_guard<decltype(mtx_)> lk(mtx_);
    RwLockPolicy::release_wlock(&state_);
    cv_.notify_all();
  }

  /// Blocks until a shared (reader) lock is acquired.
  void lock_shared() {
    std::unique_lock<decltype(mtx_)> lk(mtx_);
    while (RwLockPolicy::wait_rlock(state_)) {
      cv_.wait(lk);
    }
    RwLockPolicy::acquire_rlock(&state_);
  }

  /// Non-blocking attempt at a shared lock; returns true on success.
  bool try_lock_shared() {
    std::lock_guard<decltype(mtx_)> lk(mtx_);
    if (RwLockPolicy::wait_rlock(state_)) return false;
    // FIX: acquire_rlock() takes state*, mirroring lock_shared() above.
    RwLockPolicy::acquire_rlock(&state_);
    return true;
  }

  /// Releases one shared lock; wakes waiters when the last reader leaves.
  void unlock_shared() {
    std::lock_guard<decltype(mtx_)> lk(mtx_);
    if (RwLockPolicy::release_rlock(&state_)) {
      cv_.notify_all();
    }
  }
};
} // namespace detail
/// Non-timed shared mutex: a std::shared_mutex-compatible surface backed by
/// detail::shared_mutex_base with a pluggable scheduling policy.
template <typename RwLockPolicy>
class basic_shared_mutex : private detail::shared_mutex_base<RwLockPolicy> {
  using impl = detail::shared_mutex_base<RwLockPolicy>;

 public:
  basic_shared_mutex() = default;
  ~basic_shared_mutex() = default;

  // Non-copyable, like std::shared_mutex.
  basic_shared_mutex(const basic_shared_mutex&) = delete;
  basic_shared_mutex& operator=(const basic_shared_mutex&) = delete;

  // Exclusive (writer) operations.
  using impl::lock;
  using impl::try_lock;
  using impl::unlock;

  // Shared (reader) operations.
  using impl::lock_shared;
  using impl::try_lock_shared;
  using impl::unlock_shared;
};

/// Ready-to-use shared mutex with the library-default scheduling policy.
using shared_mutex = basic_shared_mutex<YAMC_RWLOCK_SCHED_DEFAULT>;
/// Timed shared mutex: adds try_lock_for/until and the shared variants on
/// top of detail::shared_mutex_base, with a pluggable scheduling policy.
template <typename RwLockPolicy>
class basic_shared_timed_mutex
    : private detail::shared_mutex_base<RwLockPolicy> {
  using base = detail::shared_mutex_base<RwLockPolicy>;
  using base::cv_;
  using base::mtx_;
  using base::state_;

  /// Waits for the exclusive lock until `tp`; returns false on timeout.
  template <typename Clock, typename Duration>
  bool do_try_lockwait(const std::chrono::time_point<Clock, Duration>& tp) {
    std::unique_lock<decltype(mtx_)> lk(mtx_);
    RwLockPolicy::before_wait_wlock(state_);
    while (RwLockPolicy::wait_wlock(state_)) {
      if (cv_.wait_until(lk, tp) == std::cv_status::timeout) {
        if (!RwLockPolicy::wait_wlock(state_))  // re-check predicate
          break;
        RwLockPolicy::after_wait_wlock(state_);
        return false;
      }
    }
    RwLockPolicy::after_wait_wlock(state_);
    // FIX: acquire_wlock() takes state* (see shared_mutex_base); passing
    // state_ by value does not match the policy signature and fails to
    // compile once this member is instantiated.
    RwLockPolicy::acquire_wlock(&state_);
    return true;
  }

  /// Waits for a shared lock until `tp`; returns false on timeout.
  template <typename Clock, typename Duration>
  bool do_try_lock_sharedwait(
      const std::chrono::time_point<Clock, Duration>& tp) {
    std::unique_lock<decltype(mtx_)> lk(mtx_);
    while (RwLockPolicy::wait_rlock(state_)) {
      if (cv_.wait_until(lk, tp) == std::cv_status::timeout) {
        if (!RwLockPolicy::wait_rlock(state_))  // re-check predicate
          break;
        return false;
      }
    }
    // FIX: acquire_rlock() also takes state*.
    RwLockPolicy::acquire_rlock(&state_);
    return true;
  }

 public:
  basic_shared_timed_mutex() = default;
  ~basic_shared_timed_mutex() = default;
  basic_shared_timed_mutex(const basic_shared_timed_mutex&) = delete;
  basic_shared_timed_mutex& operator=(const basic_shared_timed_mutex&) = delete;

  using base::lock;
  using base::try_lock;
  using base::unlock;

  /// Tries to take the exclusive lock for at most `duration`.
  template <typename Rep, typename Period>
  bool try_lock_for(const std::chrono::duration<Rep, Period>& duration) {
    const auto tp = std::chrono::steady_clock::now() + duration;
    return do_try_lockwait(tp);
  }

  /// Tries to take the exclusive lock until the absolute time `tp`.
  template <typename Clock, typename Duration>
  bool try_lock_until(const std::chrono::time_point<Clock, Duration>& tp) {
    return do_try_lockwait(tp);
  }

  using base::lock_shared;
  using base::try_lock_shared;
  using base::unlock_shared;

  /// Tries to take a shared lock for at most `duration`.
  template <typename Rep, typename Period>
  bool try_lock_shared_for(const std::chrono::duration<Rep, Period>& duration) {
    const auto tp = std::chrono::steady_clock::now() + duration;
    return do_try_lock_sharedwait(tp);
  }

  /// Tries to take a shared lock until the absolute time `tp`.
  template <typename Clock, typename Duration>
  bool try_lock_shared_until(
      const std::chrono::time_point<Clock, Duration>& tp) {
    return do_try_lock_sharedwait(tp);
  }
};

using shared_timed_mutex = basic_shared_timed_mutex<YAMC_RWLOCK_SCHED_DEFAULT>;
} // namespace alternate
} // namespace yamc
#endif
/*
* yamc_rwlock_sched.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_RWLOCK_SCHED_HPP_
#define YAMC_RWLOCK_SCHED_HPP_
#include <cassert>
#include <cstddef>
/// default shared_mutex rwlock policy
#ifndef YAMC_RWLOCK_SCHED_DEFAULT
#define YAMC_RWLOCK_SCHED_DEFAULT yamc::rwlock::ReaderPrefer
#endif
namespace yamc {
/*
* readers-writer locking policy for basic_shared_(timed)_mutex<RwLockPolicy>
*
* - yamc::rwlock::ReaderPrefer
* - yamc::rwlock::WriterPrefer
*/
namespace rwlock {
/// Reader-preference scheduling policy.
///
/// NOTE:
/// Readers are never blocked by waiting writers, so a steady stream of
/// shared-lock holders can starve writers ("Writer Starvation").  Linux
/// pthreads rwlocks default to this behaviour
/// (see PTHREAD_RWLOCK_PREFER_READER_NP).
///
struct ReaderPrefer {
  // Single counter: the MSB flags an exclusive holder, the remaining bits
  // count shared holders.
  static const std::size_t writer_mask = ~(~std::size_t(0u) >> 1);  // MSB 1bit
  static const std::size_t reader_mask = ~std::size_t(0u) >> 1;

  struct state {
    std::size_t rwcount = 0;
  };

  // Under reader preference, writers take no special action around waiting.
  static void before_wait_wlock(const state&) {}
  static void after_wait_wlock(const state&) {}

  // A writer must wait while anyone (reader or writer) holds the lock.
  static bool wait_wlock(const state& s) { return s.rwcount != 0; }

  static void acquire_wlock(state* s) {
    assert(!(s->rwcount & writer_mask));
    s->rwcount |= writer_mask;
  }

  static void release_wlock(state* s) {
    assert(s->rwcount & writer_mask);
    s->rwcount &= ~writer_mask;
  }

  // A reader only waits while a writer holds the lock.
  static bool wait_rlock(const state& s) { return (s.rwcount & writer_mask) != 0; }

  static void acquire_rlock(state* s) {
    assert((s->rwcount & reader_mask) < reader_mask);
    s->rwcount += 1;
  }

  // Returns true when the last reader leaves (caller then wakes writers).
  static bool release_rlock(state* s) {
    assert(0 < (s->rwcount & reader_mask));
    s->rwcount -= 1;
    return s->rwcount == 0;
  }
};
/// Writer-preference scheduling policy.
///
/// NOTE:
/// Once a writer starts waiting, new readers are held back until every
/// shared lock has been released and the writer has taken the exclusive
/// lock ahead of the blocked reader threads.  A steady stream of writers
/// can therefore starve readers ("Reader Starvation").
/// (see also PTHREAD_RWLOCK_PREFER_WRITER_NONRECURSIVE_NP)
///
struct WriterPrefer {
  // nwriter: MSB = exclusive lock held; low bits = number of waiting writers.
  // nreader: number of active shared holders.
  static const std::size_t locked = ~(~std::size_t(0u) >> 1);  // MSB 1bit
  static const std::size_t wait_mask = ~std::size_t(0u) >> 1;

  struct state {
    std::size_t nwriter = 0;
    std::size_t nreader = 0;
  };

  // Register this thread as a waiting writer so new readers start yielding.
  static void before_wait_wlock(state* s) {
    assert((s->nwriter & wait_mask) < wait_mask);
    s->nwriter += 1;
  }

  // A writer waits while the lock is held exclusively or readers are active.
  static bool wait_wlock(const state& s) {
    return (s.nwriter & locked) != 0 || s.nreader > 0;
  }

  // Deregister this thread from the waiting-writer count.
  static void after_wait_wlock(state* s) {
    assert(0 < (s->nwriter & wait_mask));
    s->nwriter -= 1;
  }

  static void acquire_wlock(state* s) {
    assert(!(s->nwriter & locked));
    s->nwriter |= locked;
  }

  static void release_wlock(state* s) {
    assert(s->nwriter & locked);
    s->nwriter &= ~locked;
  }

  // Readers wait whenever any writer is active *or* waiting.
  static bool wait_rlock(const state& s) { return s.nwriter != 0; }

  static void acquire_rlock(state* s) {
    assert(!(s->nwriter & locked));
    s->nreader += 1;
  }

  // Returns true when the last reader leaves (caller then wakes writers).
  static bool release_rlock(state* s) {
    assert(0 < s->nreader);
    s->nreader -= 1;
    return s->nreader == 0;
  }
};
} // namespace rwlock
} // namespace yamc
#endif
/*
* yamc_shared_lock.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_SHARED_LOCK_HPP_
#define YAMC_SHARED_LOCK_HPP_
#include <cassert>
#include <chrono>
#include <mutex>
#include <system_error>
#include <utility> // std::swap
/*
* std::shared_lock in C++14 Standard Library
*
* - yamc::shared_lock<Mutex>
*/
namespace yamc {
/// C++14-style std::shared_lock work-alike (RAII holder of a shared lock).
/// Differs from the standard in that the immediately-locking constructor
/// takes a pointer (matching the SHARED_LOCK(mtx) usage with `&mtx`).
template <typename Mutex>
class shared_lock {
  /// Throws std::system_error if this lock cannot legally acquire:
  /// operation_not_permitted when no mutex is associated,
  /// resource_deadlock_would_occur when the shared lock is already owned.
  void locking_precondition(const char* emsg) {
    if (pm_ == nullptr) {
      throw std::system_error(
          std::make_error_code(std::errc::operation_not_permitted), emsg);
    }
    if (owns_) {
      throw std::system_error(
          std::make_error_code(std::errc::resource_deadlock_would_occur), emsg);
    }
  }

 public:
  using mutex_type = Mutex;

  shared_lock() noexcept = default;

  /// Blocks until a shared lock on *m is acquired.
  explicit shared_lock(mutex_type* m) {
    m->lock_shared();
    pm_ = m;
    owns_ = true;
  }

  /// Associates the mutex without locking it.
  /// FIX: these tag/timed constructors previously took `const mutex_type&`,
  /// which cannot be stored in the non-const `pm_` (and the locking variants
  /// call non-const members) — a compile error once instantiated.
  shared_lock(mutex_type& m, std::defer_lock_t) noexcept {
    pm_ = &m;
    owns_ = false;
  }

  /// Attempts to lock without blocking; check owns_lock() afterwards.
  shared_lock(mutex_type& m, std::try_to_lock_t) {
    pm_ = &m;
    owns_ = m.try_lock_shared();
  }

  /// Adopts a shared lock the caller already holds on m.
  shared_lock(mutex_type& m, std::adopt_lock_t) {
    pm_ = &m;
    owns_ = true;
  }

  /// Attempts to lock until the absolute time `abs_time`.
  template <typename Clock, typename Duration>
  shared_lock(mutex_type& m,
              const std::chrono::time_point<Clock, Duration>& abs_time) {
    pm_ = &m;
    owns_ = m.try_lock_shared_until(abs_time);
  }

  /// Attempts to lock for at most `rel_time`.
  template <typename Rep, typename Period>
  shared_lock(mutex_type& m,
              const std::chrono::duration<Rep, Period>& rel_time) {
    pm_ = &m;
    owns_ = m.try_lock_shared_for(rel_time);
  }

  ~shared_lock() {
    if (owns_) {
      assert(pm_ != nullptr);
      pm_->unlock_shared();
    }
  }

  shared_lock(const shared_lock&) = delete;
  shared_lock& operator=(const shared_lock&) = delete;

  /// Move-construct: steal rhs's mutex and ownership, leaving rhs empty.
  /// (A freshly constructed lock owns nothing, so — unlike move assignment —
  /// there is nothing to release first; the old `if (pm_ && owns_)` guard
  /// here was dead code because the member initializers run before the body.)
  shared_lock(shared_lock&& rhs) noexcept {
    pm_ = rhs.pm_;
    owns_ = rhs.owns_;
    rhs.pm_ = nullptr;
    rhs.owns_ = false;
  }

  /// Move-assign: release any currently owned lock, then steal rhs's state.
  shared_lock& operator=(shared_lock&& rhs) noexcept {
    if (pm_ && owns_) {
      pm_->unlock_shared();
    }
    pm_ = rhs.pm_;
    owns_ = rhs.owns_;
    rhs.pm_ = nullptr;
    rhs.owns_ = false;
    return *this;
  }

  /// Blocks until the shared lock is acquired; throws on misuse.
  void lock() {
    locking_precondition("shared_lock::lock");
    pm_->lock_shared();
    owns_ = true;
  }

  /// Non-blocking acquire; returns (and records) success.
  bool try_lock() {
    locking_precondition("shared_lock::try_lock");
    return (owns_ = pm_->try_lock_shared());
  }

  /// Acquire with a relative timeout.
  template <typename Rep, typename Period>
  bool try_lock_for(const std::chrono::duration<Rep, Period>& rel_time) {
    locking_precondition("shared_lock::try_lock_for");
    return (owns_ = pm_->try_lock_shared_for(rel_time));
  }

  /// Acquire with an absolute deadline.
  template <typename Clock, typename Duration>
  bool try_lock_until(
      const std::chrono::time_point<Clock, Duration>& abs_time) {
    locking_precondition("shared_lock::try_lock_until");
    return (owns_ = pm_->try_lock_shared_until(abs_time));
  }

  /// Releases the shared lock; throws operation_not_permitted if not owned.
  void unlock() {
    assert(pm_ != nullptr);
    if (!owns_) {
      throw std::system_error(
          std::make_error_code(std::errc::operation_not_permitted),
          "shared_lock::unlock");
    }
    pm_->unlock_shared();
    owns_ = false;
  }

  /// Exchanges state with another shared_lock.
  void swap(shared_lock& sl) noexcept {
    std::swap(pm_, sl.pm_);
    std::swap(owns_, sl.owns_);
  }

  /// Disassociates the mutex WITHOUT unlocking it; returns the old mutex.
  mutex_type* release() noexcept {
    mutex_type* result = pm_;
    pm_ = nullptr;
    owns_ = false;
    return result;
  }

  bool owns_lock() const noexcept { return owns_; }
  explicit operator bool() const noexcept { return owns_; }
  mutex_type* mutex() const noexcept { return pm_; }

 private:
  mutex_type* pm_ = nullptr;  // associated mutex (non-owning), null if none
  bool owns_ = false;         // true while we hold a shared lock on *pm_
};
} // namespace yamc
namespace std {
/// std::swap() specialization for yamc::shared_lock<Mutex> type
///
/// NOTE(review): formally this is an *overload* added to namespace std, not a
/// specialization of a std template, which the standard does not permit. It
/// mirrors upstream yamc and mainstream compilers accept it; kept as-is, but
/// confirm before relying on unqualified swap() finding it elsewhere.
template <typename Mutex>
void swap(yamc::shared_lock<Mutex>& lhs,
yamc::shared_lock<Mutex>& rhs) noexcept {
lhs.swap(rhs);
}
} // namespace std
#endif
......@@ -27,6 +27,8 @@
#include <vector>
#include "application/predictor.hpp"
#include <LightGBM/utils/yamc/alternate_shared_mutex.hpp>
#include <LightGBM/utils/yamc/yamc_shared_lock.hpp>
namespace LightGBM {
......@@ -46,6 +48,12 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;
#define UNIQUE_LOCK(mtx) \
std::unique_lock<yamc::alternate::shared_mutex> lock(mtx);
#define SHARED_LOCK(mtx) \
yamc::shared_lock<yamc::alternate::shared_mutex> lock(&mtx);
const int PREDICTOR_TYPES = 4;
// Single row predictor to abstract away caching logic
......@@ -133,7 +141,7 @@ class Booster {
}
void MergeFrom(const Booster* other) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->MergeFrom(other->boosting_.get());
}
......@@ -166,7 +174,7 @@ class Booster {
void ResetTrainingData(const Dataset* train_data) {
if (train_data != train_data_) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
train_data_ = train_data;
CreateObjectiveAndMetrics();
// reset the boosting
......@@ -284,7 +292,7 @@ class Booster {
}
void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
auto param = Config::Str2Map(parameters);
if (param.count("num_class")) {
Log::Fatal("Cannot change num_class during training");
......@@ -322,7 +330,7 @@ class Booster {
}
void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
......@@ -336,12 +344,12 @@ class Booster {
}
bool TrainOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
return boosting_->TrainOneIter(nullptr, nullptr);
}
void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
......@@ -352,37 +360,42 @@ class Booster {
}
bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
return boosting_->TrainOneIter(gradients, hessians);
}
void RollbackOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->RollbackOneIter();
}
void PredictSingleRow(int num_iteration, int predict_type, int ncol,
void SetSingleRowPredictor(int num_iteration, int predict_type, const Config& config) {
UNIQUE_LOCK(mutex_)
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration));
}
}
void PredictSingleRow(int predict_type, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
double* out_result, int64_t* out_len) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration));
}
SHARED_LOCK(mutex_)
const auto& single_row_predictor = single_row_predictor_[predict_type];
auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result;
single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
single_row_predictor->predict_function(one_row, pred_wrt_ptr);
*out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
*out_len = single_row_predictor->num_pred_in_one_row;
}
Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) {
Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
......@@ -408,8 +421,8 @@ class Booster {
void Predict(int num_iteration, int predict_type, int nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_);
double* out_result, int64_t* out_len) const {
SHARED_LOCK(mutex_);
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
bool is_predict_leaf = false;
bool predict_contrib = false;
......@@ -438,7 +451,7 @@ class Booster {
const Config& config, int64_t* out_elements_size,
std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr,
int32_t** out_indices, void** out_data, int data_type,
bool* is_data_float32_ptr, int num_matrices) {
bool* is_data_float32_ptr, int num_matrices) const {
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction();
std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
......@@ -479,8 +492,8 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config,
int64_t* out_len, void** out_indptr, int indptr_type,
int32_t** out_indices, void** out_data, int data_type) {
std::lock_guard<std::mutex> lock(mutex_);
int32_t** out_indices, void** out_data, int data_type) const {
SHARED_LOCK(mutex_);
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration();
bool is_indptr_int32 = false;
......@@ -563,8 +576,8 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config,
int64_t* out_len, void** out_col_ptr, int col_ptr_type,
int32_t** out_indices, void** out_data, int data_type) {
std::lock_guard<std::mutex> lock(mutex_);
int32_t** out_indices, void** out_data, int data_type) const {
SHARED_LOCK(mutex_);
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration();
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
......@@ -665,8 +678,8 @@ class Booster {
void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const Config& config,
const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_);
const char* result_filename) const {
SHARED_LOCK(mutex_)
bool is_predict_leaf = false;
bool is_raw_score = false;
bool predict_contrib = false;
......@@ -685,11 +698,11 @@ class Booster {
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
}
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const {
boosting_->GetPredictAt(data_idx, out_result, out_len);
}
void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) {
void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) const {
boosting_->SaveModelToFile(start_iteration, num_iteration, feature_importance_type, filename);
}
......@@ -699,46 +712,48 @@ class Booster {
}
std::string SaveModelToString(int start_iteration, int num_iteration,
int feature_importance_type) {
int feature_importance_type) const {
return boosting_->SaveModelToString(start_iteration,
num_iteration, feature_importance_type);
}
std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) {
int feature_importance_type) const {
return boosting_->DumpModel(start_iteration, num_iteration,
feature_importance_type);
}
std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const {
return boosting_->FeatureImportance(num_iteration, importance_type);
}
double UpperBoundValue() const {
std::lock_guard<std::mutex> lock(mutex_);
SHARED_LOCK(mutex_)
return boosting_->GetUpperBoundValue();
}
double LowerBoundValue() const {
std::lock_guard<std::mutex> lock(mutex_);
SHARED_LOCK(mutex_)
return boosting_->GetLowerBoundValue();
}
double GetLeafValue(int tree_idx, int leaf_idx) const {
SHARED_LOCK(mutex_)
return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
void SetLeafValue(int tree_idx, int leaf_idx, double val) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
}
void ShuffleModels(int start_iter, int end_iter) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->ShuffleModels(start_iter, end_iter);
}
int GetEvalCounts() const {
SHARED_LOCK(mutex_)
int ret = 0;
for (const auto& metric : train_metric_) {
ret += static_cast<int>(metric->GetName().size());
......@@ -747,6 +762,7 @@ class Booster {
}
int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
SHARED_LOCK(mutex_)
*out_buffer_len = 0;
int idx = 0;
for (const auto& metric : train_metric_) {
......@@ -763,6 +779,7 @@ class Booster {
}
int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
SHARED_LOCK(mutex_)
*out_buffer_len = 0;
int idx = 0;
for (const auto& name : boosting_->FeatureNames()) {
......@@ -792,7 +809,7 @@ class Booster {
/*! \brief Training objective function */
std::unique_ptr<ObjectiveFunction> objective_fun_;
/*! \brief mutex for threading safe call */
mutable std::mutex mutex_;
mutable yamc::alternate::shared_mutex mutex_;
};
} // namespace LightGBM
......@@ -1916,7 +1933,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
API_END();
}
......@@ -1960,7 +1978,7 @@ int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem);
fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, out_result, out_len);
API_END();
}
......@@ -2058,7 +2076,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len);
API_END();
}
......@@ -2092,7 +2111,7 @@ int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
// Single row in row-major format:
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1);
fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config,
out_result, out_len);
API_END();
......
......@@ -244,6 +244,7 @@
<ClInclude Include="..\include\LightGBM\prediction_early_stop.h" />
<ClInclude Include="..\include\LightGBM\tree.h" />
<ClInclude Include="..\include\LightGBM\tree_learner.h" />
<ClInclude Include="..\include\LightGBM\utils\yamc\alternate_shared_mutex.hpp" />
<ClInclude Include="..\include\LightGBM\utils\array_args.h" />
<ClInclude Include="..\include\LightGBM\utils\common.h" />
<ClInclude Include="..\include\LightGBM\utils\file_io.h" />
......@@ -255,6 +256,8 @@
<ClInclude Include="..\include\LightGBM\utils\random.h" />
<ClInclude Include="..\include\LightGBM\utils\text_reader.h" />
<ClInclude Include="..\include\LightGBM\utils\threading.h" />
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_rwlock_sched.hpp" />
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_shared_lock.hpp" />
<ClInclude Include="..\src\application\predictor.hpp" />
<ClInclude Include="..\src\boosting\gbdt.h" />
<ClInclude Include="..\src\boosting\dart.hpp" />
......
......@@ -210,6 +210,15 @@
<ClInclude Include="..\src\treelearner\col_sampler.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\alternate_shared_mutex.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_rwlock_sched.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_shared_lock.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment