"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "1c8355c975c1e02388202c748308a1952d3fd571"
Unverified Commit 1c35c3b9 authored by Joan Fontanals's avatar Joan Fontanals Committed by GitHub
Browse files

Change locking strategy of Booster, allow for share and unique locks (#2760)



* Add capability to get possible max and min values for a model

* Change implementation to have return value in tree.cpp, change naming to upper and lower bound, move implementation to gdbt.cpp

* Update include/LightGBM/c_api.h
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Change iteration to avoid potential overflow, add bindings to R and Python and a basic test

* Adjust test values

* Consider const correctness and multithreading protection

* Put everything possible as const

* Include shared_mutex, for now as unique_lock

* Update test values

* Put everything possible as const

* Include shared_mutex, for now as unique_lock

* Make PredictSingleRow const and share the lock with other reading threads

* Update test values

* Add test to check that model is exactly the same in all platforms

* Try to parse the model to get the expected values

* Try to parse the model to get the expected values

* Fix implementation, num_leaves can be lower than the leaf_value_ size

* Do not check for num_leaves to be smaller than actual size and get back to test with hardcoded value

* Change test order

* Add gpu_use_dp option in test

* Remove helper test method

* Remove TODO

* Add preprocessing option to compile with c++17

* Update python-package/setup.py
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>

* Remove unwanted changes

* Move option

* Fix problems introduced by conflict fix

* Avoid switching to c++17 and use yamc mutex library to access shared lock functionality

* Add extra yamc include

* Change header order

* some lint fix

* change include order and remove some extra blank lines

* Further fix lint issues

* Update c_api.cpp

* Further fix lint issues

* Move yamc include files to a new yamc folder

* Use standard unique_lock

* Update windows/LightGBM.vcxproj
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Update windows/LightGBM.vcxproj.filters
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* Fix problems coming from merge conflict resolution
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarjoanfontanals <jfontanals@ntent.com>
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
parent f5f27ca8
/*
* alternate_shared_mutex.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_ALTERNATE_SHARED_MUTEX_HPP_
#define YAMC_ALTERNATE_SHARED_MUTEX_HPP_
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include "yamc_rwlock_sched.hpp"
namespace yamc {
/*
* alternate implementation of shared mutex variants
*
* - yamc::alternate::shared_mutex
* - yamc::alternate::shared_timed_mutex
* - yamc::alternate::basic_shared_mutex<RwLockPolicy>
* - yamc::alternate::basic_shared_timed_mutex<RwLockPolicy>
*/
namespace alternate {
namespace detail {
template <typename RwLockPolicy>
class shared_mutex_base {
protected:
typename RwLockPolicy::state state_;
std::condition_variable cv_;
std::mutex mtx_;
void lock() {
std::unique_lock<decltype(mtx_)> lk(mtx_);
RwLockPolicy::before_wait_wlock(state_);
while (RwLockPolicy::wait_wlock(state_)) {
cv_.wait(lk);
}
RwLockPolicy::after_wait_wlock(state_);
RwLockPolicy::acquire_wlock(&state_);
}
bool try_lock() {
std::lock_guard<decltype(mtx_)> lk(mtx_);
if (RwLockPolicy::wait_wlock(state_)) return false;
RwLockPolicy::acquire_wlock(state_);
return true;
}
void unlock() {
std::lock_guard<decltype(mtx_)> lk(mtx_);
RwLockPolicy::release_wlock(&state_);
cv_.notify_all();
}
void lock_shared() {
std::unique_lock<decltype(mtx_)> lk(mtx_);
while (RwLockPolicy::wait_rlock(state_)) {
cv_.wait(lk);
}
RwLockPolicy::acquire_rlock(&state_);
}
bool try_lock_shared() {
std::lock_guard<decltype(mtx_)> lk(mtx_);
if (RwLockPolicy::wait_rlock(state_)) return false;
RwLockPolicy::acquire_rlock(state_);
return true;
}
void unlock_shared() {
std::lock_guard<decltype(mtx_)> lk(mtx_);
if (RwLockPolicy::release_rlock(&state_)) {
cv_.notify_all();
}
}
};
} // namespace detail
template <typename RwLockPolicy>
class basic_shared_mutex : private detail::shared_mutex_base<RwLockPolicy> {
using base = detail::shared_mutex_base<RwLockPolicy>;
public:
basic_shared_mutex() = default;
~basic_shared_mutex() = default;
basic_shared_mutex(const basic_shared_mutex&) = delete;
basic_shared_mutex& operator=(const basic_shared_mutex&) = delete;
using base::lock;
using base::try_lock;
using base::unlock;
using base::lock_shared;
using base::try_lock_shared;
using base::unlock_shared;
};
using shared_mutex = basic_shared_mutex<YAMC_RWLOCK_SCHED_DEFAULT>;
template <typename RwLockPolicy>
class basic_shared_timed_mutex
: private detail::shared_mutex_base<RwLockPolicy> {
using base = detail::shared_mutex_base<RwLockPolicy>;
using base::cv_;
using base::mtx_;
using base::state_;
template <typename Clock, typename Duration>
bool do_try_lockwait(const std::chrono::time_point<Clock, Duration>& tp) {
std::unique_lock<decltype(mtx_)> lk(mtx_);
RwLockPolicy::before_wait_wlock(state_);
while (RwLockPolicy::wait_wlock(state_)) {
if (cv_.wait_until(lk, tp) == std::cv_status::timeout) {
if (!RwLockPolicy::wait_wlock(state_)) // re-check predicate
break;
RwLockPolicy::after_wait_wlock(state_);
return false;
}
}
RwLockPolicy::after_wait_wlock(state_);
RwLockPolicy::acquire_wlock(state_);
return true;
}
template <typename Clock, typename Duration>
bool do_try_lock_sharedwait(
const std::chrono::time_point<Clock, Duration>& tp) {
std::unique_lock<decltype(mtx_)> lk(mtx_);
while (RwLockPolicy::wait_rlock(state_)) {
if (cv_.wait_until(lk, tp) == std::cv_status::timeout) {
if (!RwLockPolicy::wait_rlock(state_)) // re-check predicate
break;
return false;
}
}
RwLockPolicy::acquire_rlock(state_);
return true;
}
public:
basic_shared_timed_mutex() = default;
~basic_shared_timed_mutex() = default;
basic_shared_timed_mutex(const basic_shared_timed_mutex&) = delete;
basic_shared_timed_mutex& operator=(const basic_shared_timed_mutex&) = delete;
using base::lock;
using base::try_lock;
using base::unlock;
template <typename Rep, typename Period>
bool try_lock_for(const std::chrono::duration<Rep, Period>& duration) {
const auto tp = std::chrono::steady_clock::now() + duration;
return do_try_lockwait(tp);
}
template <typename Clock, typename Duration>
bool try_lock_until(const std::chrono::time_point<Clock, Duration>& tp) {
return do_try_lockwait(tp);
}
using base::lock_shared;
using base::try_lock_shared;
using base::unlock_shared;
template <typename Rep, typename Period>
bool try_lock_shared_for(const std::chrono::duration<Rep, Period>& duration) {
const auto tp = std::chrono::steady_clock::now() + duration;
return do_try_lock_sharedwait(tp);
}
template <typename Clock, typename Duration>
bool try_lock_shared_until(
const std::chrono::time_point<Clock, Duration>& tp) {
return do_try_lock_sharedwait(tp);
}
};
using shared_timed_mutex = basic_shared_timed_mutex<YAMC_RWLOCK_SCHED_DEFAULT>;
} // namespace alternate
} // namespace yamc
#endif
/*
* yamc_rwlock_sched.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_RWLOCK_SCHED_HPP_
#define YAMC_RWLOCK_SCHED_HPP_
#include <cassert>
#include <cstddef>
/// default shared_mutex rwlock policy
#ifndef YAMC_RWLOCK_SCHED_DEFAULT
#define YAMC_RWLOCK_SCHED_DEFAULT yamc::rwlock::ReaderPrefer
#endif
namespace yamc {
/*
* readers-writer locking policy for basic_shared_(timed)_mutex<RwLockPolicy>
*
* - yamc::rwlock::ReaderPrefer
* - yamc::rwlock::WriterPrefer
*/
namespace rwlock {
/// Reader prefer scheduling
///
/// NOTE:
// This policy might introduce "Writer Starvation" if readers continuously
// hold shared lock. PThreads rwlock implementation in Linux use this
// scheduling policy as default. (see also PTHREAD_RWLOCK_PREFER_READER_NP)
//
struct ReaderPrefer {
static const std::size_t writer_mask = ~(~std::size_t(0u) >> 1); // MSB 1bit
static const std::size_t reader_mask = ~std::size_t(0u) >> 1;
struct state {
std::size_t rwcount = 0;
};
static void before_wait_wlock(const state&) {}
static void after_wait_wlock(const state&) {}
static bool wait_wlock(const state& s) { return (s.rwcount != 0); }
static void acquire_wlock(state* s) {
assert(!(s->rwcount & writer_mask));
s->rwcount |= writer_mask;
}
static void release_wlock(state* s) {
assert(s->rwcount & writer_mask);
s->rwcount &= ~writer_mask;
}
static bool wait_rlock(const state& s) { return (s.rwcount & writer_mask) != 0; }
static void acquire_rlock(state* s) {
assert((s->rwcount & reader_mask) < reader_mask);
++(s->rwcount);
}
static bool release_rlock(state* s) {
assert(0 < (s->rwcount & reader_mask));
return (--(s->rwcount) == 0);
}
};
/// Writer prefer scheduling
///
/// NOTE:
/// If there are waiting writer, new readers are blocked until all shared lock
/// are released,
// and the writer thread can get exclusive lock in preference to blocked
// reader threads. This policy might introduce "Reader Starvation" if writers
// continuously request exclusive lock.
/// (see also PTHREAD_RWLOCK_PREFER_WRITER_NONRECURSIVE_NP)
///
struct WriterPrefer {
static const std::size_t locked = ~(~std::size_t(0u) >> 1); // MSB 1bit
static const std::size_t wait_mask = ~std::size_t(0u) >> 1;
struct state {
std::size_t nwriter = 0;
std::size_t nreader = 0;
};
static void before_wait_wlock(state* s) {
assert((s->nwriter & wait_mask) < wait_mask);
++(s->nwriter);
}
static bool wait_wlock(const state& s) {
return ((s.nwriter & locked) || 0 < s.nreader);
}
static void after_wait_wlock(state* s) {
assert(0 < (s->nwriter & wait_mask));
--(s->nwriter);
}
static void acquire_wlock(state* s) {
assert(!(s->nwriter & locked));
s->nwriter |= locked;
}
static void release_wlock(state* s) {
assert(s->nwriter & locked);
s->nwriter &= ~locked;
}
static bool wait_rlock(const state& s) { return (s.nwriter != 0); }
static void acquire_rlock(state* s) {
assert(!(s->nwriter & locked));
++(s->nreader);
}
static bool release_rlock(state* s) {
assert(0 < s->nreader);
return (--(s->nreader) == 0);
}
};
} // namespace rwlock
} // namespace yamc
#endif
/*
* yamc_shared_lock.hpp
*
* MIT License
*
* Copyright (c) 2017 yohhoy
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef YAMC_SHARED_LOCK_HPP_
#define YAMC_SHARED_LOCK_HPP_
#include <cassert>
#include <chrono>
#include <mutex>
#include <system_error>
#include <utility> // std::swap
/*
* std::shared_lock in C++14 Standard Library
*
* - yamc::shared_lock<Mutex>
*/
namespace yamc {
template <typename Mutex>
class shared_lock {
void locking_precondition(const char* emsg) {
if (pm_ == nullptr) {
throw std::system_error(
std::make_error_code(std::errc::operation_not_permitted), emsg);
}
if (owns_) {
throw std::system_error(
std::make_error_code(std::errc::resource_deadlock_would_occur), emsg);
}
}
public:
using mutex_type = Mutex;
shared_lock() noexcept = default;
explicit shared_lock(mutex_type* m) {
m->lock_shared();
pm_ = m;
owns_ = true;
}
shared_lock(const mutex_type& m, std::defer_lock_t) noexcept {
pm_ = &m;
owns_ = false;
}
shared_lock(const mutex_type& m, std::try_to_lock_t) {
pm_ = &m;
owns_ = m.try_lock_shared();
}
shared_lock(const mutex_type& m, std::adopt_lock_t) {
pm_ = &m;
owns_ = true;
}
template <typename Clock, typename Duration>
shared_lock(const mutex_type& m,
const std::chrono::time_point<Clock, Duration>& abs_time) {
pm_ = &m;
owns_ = m.try_lock_shared_until(abs_time);
}
template <typename Rep, typename Period>
shared_lock(const mutex_type& m,
const std::chrono::duration<Rep, Period>& rel_time) {
pm_ = &m;
owns_ = m.try_lock_shared_for(rel_time);
}
~shared_lock() {
if (owns_) {
assert(pm_ != nullptr);
pm_->unlock_shared();
}
}
shared_lock(const shared_lock&) = delete;
shared_lock& operator=(const shared_lock&) = delete;
shared_lock(shared_lock&& rhs) noexcept {
if (pm_ && owns_) {
pm_->unlock_shared();
}
pm_ = rhs.pm_;
owns_ = rhs.owns_;
rhs.pm_ = nullptr;
rhs.owns_ = false;
}
shared_lock& operator=(shared_lock&& rhs) noexcept {
if (pm_ && owns_) {
pm_->unlock_shared();
}
pm_ = rhs.pm_;
owns_ = rhs.owns_;
rhs.pm_ = nullptr;
rhs.owns_ = false;
return *this;
}
void lock() {
locking_precondition("shared_lock::lock");
pm_->lock_shared();
owns_ = true;
}
bool try_lock() {
locking_precondition("shared_lock::try_lock");
return (owns_ = pm_->try_lock_shared());
}
template <typename Rep, typename Period>
bool try_lock_for(const std::chrono::duration<Rep, Period>& rel_time) {
locking_precondition("shared_lock::try_lock_for");
return (owns_ = pm_->try_lock_shared_for(rel_time));
}
template <typename Clock, typename Duration>
bool try_lock_until(
const std::chrono::time_point<Clock, Duration>& abs_time) {
locking_precondition("shared_lock::try_lock_until");
return (owns_ = pm_->try_lock_shared_until(abs_time));
}
void unlock() {
assert(pm_ != nullptr);
if (!owns_) {
throw std::system_error(
std::make_error_code(std::errc::operation_not_permitted),
"shared_lock::unlock");
}
pm_->unlock_shared();
owns_ = false;
}
void swap(shared_lock& sl) noexcept {
std::swap(pm_, sl.pm_);
std::swap(owns_, sl.owns_);
}
mutex_type* release() noexcept {
mutex_type* result = pm_;
pm_ = nullptr;
owns_ = false;
return result;
}
bool owns_lock() const noexcept { return owns_; }
explicit operator bool() const noexcept { return owns_; }
mutex_type* mutex() const noexcept { return pm_; }
private:
mutex_type* pm_ = nullptr;
bool owns_ = false;
};
} // namespace yamc
namespace std {
/// std::swap() specialization for yamc::shared_lock<Mutex> type
template <typename Mutex>
void swap(yamc::shared_lock<Mutex>& lhs,
yamc::shared_lock<Mutex>& rhs) noexcept {
lhs.swap(rhs);
}
} // namespace std
#endif
......@@ -27,6 +27,8 @@
#include <vector>
#include "application/predictor.hpp"
#include <LightGBM/utils/yamc/alternate_shared_mutex.hpp>
#include <LightGBM/utils/yamc/yamc_shared_lock.hpp>
namespace LightGBM {
......@@ -46,6 +48,12 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;
#define UNIQUE_LOCK(mtx) \
std::unique_lock<yamc::alternate::shared_mutex> lock(mtx);
#define SHARED_LOCK(mtx) \
yamc::shared_lock<yamc::alternate::shared_mutex> lock(&mtx);
const int PREDICTOR_TYPES = 4;
// Single row predictor to abstract away caching logic
......@@ -133,7 +141,7 @@ class Booster {
}
void MergeFrom(const Booster* other) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->MergeFrom(other->boosting_.get());
}
......@@ -166,7 +174,7 @@ class Booster {
void ResetTrainingData(const Dataset* train_data) {
if (train_data != train_data_) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
train_data_ = train_data;
CreateObjectiveAndMetrics();
// reset the boosting
......@@ -284,7 +292,7 @@ class Booster {
}
void ResetConfig(const char* parameters) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
auto param = Config::Str2Map(parameters);
if (param.count("num_class")) {
Log::Fatal("Cannot change num_class during training");
......@@ -322,7 +330,7 @@ class Booster {
}
void AddValidData(const Dataset* valid_data) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
......@@ -336,12 +344,12 @@ class Booster {
}
bool TrainOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
return boosting_->TrainOneIter(nullptr, nullptr);
}
void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
......@@ -352,37 +360,42 @@ class Booster {
}
bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
return boosting_->TrainOneIter(gradients, hessians);
}
void RollbackOneIter() {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->RollbackOneIter();
}
void PredictSingleRow(int num_iteration, int predict_type, int ncol,
void SetSingleRowPredictor(int num_iteration, int predict_type, const Config& config) {
UNIQUE_LOCK(mutex_)
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration));
}
}
void PredictSingleRow(int predict_type, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
double* out_result, int64_t* out_len) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n"\
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration));
}
SHARED_LOCK(mutex_)
const auto& single_row_predictor = single_row_predictor_[predict_type];
auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result;
single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
single_row_predictor->predict_function(one_row, pred_wrt_ptr);
*out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
*out_len = single_row_predictor->num_pred_in_one_row;
}
Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) {
Predictor CreatePredictor(int num_iteration, int predict_type, int ncol, const Config& config) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
......@@ -408,8 +421,8 @@ class Booster {
void Predict(int num_iteration, int predict_type, int nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config,
double* out_result, int64_t* out_len) {
std::lock_guard<std::mutex> lock(mutex_);
double* out_result, int64_t* out_len) const {
SHARED_LOCK(mutex_);
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
bool is_predict_leaf = false;
bool predict_contrib = false;
......@@ -438,7 +451,7 @@ class Booster {
const Config& config, int64_t* out_elements_size,
std::vector<std::vector<std::unordered_map<int, double>>>* agg_ptr,
int32_t** out_indices, void** out_data, int data_type,
bool* is_data_float32_ptr, int num_matrices) {
bool* is_data_float32_ptr, int num_matrices) const {
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction();
std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
......@@ -479,8 +492,8 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config,
int64_t* out_len, void** out_indptr, int indptr_type,
int32_t** out_indices, void** out_data, int data_type) {
std::lock_guard<std::mutex> lock(mutex_);
int32_t** out_indices, void** out_data, int data_type) const {
SHARED_LOCK(mutex_);
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration();
bool is_indptr_int32 = false;
......@@ -563,8 +576,8 @@ class Booster {
std::function<std::vector<std::pair<int, double>>(int64_t row_idx)> get_row_fun,
const Config& config,
int64_t* out_len, void** out_col_ptr, int col_ptr_type,
int32_t** out_indices, void** out_data, int data_type) {
std::lock_guard<std::mutex> lock(mutex_);
int32_t** out_indices, void** out_data, int data_type) const {
SHARED_LOCK(mutex_);
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration();
auto predictor = CreatePredictor(num_iteration, predict_type, ncol, config);
......@@ -665,8 +678,8 @@ class Booster {
void Predict(int num_iteration, int predict_type, const char* data_filename,
int data_has_header, const Config& config,
const char* result_filename) {
std::lock_guard<std::mutex> lock(mutex_);
const char* result_filename) const {
SHARED_LOCK(mutex_)
bool is_predict_leaf = false;
bool is_raw_score = false;
bool predict_contrib = false;
......@@ -685,11 +698,11 @@ class Booster {
predictor.Predict(data_filename, result_filename, bool_data_has_header, config.predict_disable_shape_check);
}
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) const {
boosting_->GetPredictAt(data_idx, out_result, out_len);
}
void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) {
void SaveModelToFile(int start_iteration, int num_iteration, int feature_importance_type, const char* filename) const {
boosting_->SaveModelToFile(start_iteration, num_iteration, feature_importance_type, filename);
}
......@@ -699,46 +712,48 @@ class Booster {
}
std::string SaveModelToString(int start_iteration, int num_iteration,
int feature_importance_type) {
int feature_importance_type) const {
return boosting_->SaveModelToString(start_iteration,
num_iteration, feature_importance_type);
}
std::string DumpModel(int start_iteration, int num_iteration,
int feature_importance_type) {
int feature_importance_type) const {
return boosting_->DumpModel(start_iteration, num_iteration,
feature_importance_type);
}
std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
std::vector<double> FeatureImportance(int num_iteration, int importance_type) const {
return boosting_->FeatureImportance(num_iteration, importance_type);
}
double UpperBoundValue() const {
std::lock_guard<std::mutex> lock(mutex_);
SHARED_LOCK(mutex_)
return boosting_->GetUpperBoundValue();
}
double LowerBoundValue() const {
std::lock_guard<std::mutex> lock(mutex_);
SHARED_LOCK(mutex_)
return boosting_->GetLowerBoundValue();
}
double GetLeafValue(int tree_idx, int leaf_idx) const {
SHARED_LOCK(mutex_)
return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
}
void SetLeafValue(int tree_idx, int leaf_idx, double val) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
}
void ShuffleModels(int start_iter, int end_iter) {
std::lock_guard<std::mutex> lock(mutex_);
UNIQUE_LOCK(mutex_)
boosting_->ShuffleModels(start_iter, end_iter);
}
int GetEvalCounts() const {
SHARED_LOCK(mutex_)
int ret = 0;
for (const auto& metric : train_metric_) {
ret += static_cast<int>(metric->GetName().size());
......@@ -747,6 +762,7 @@ class Booster {
}
int GetEvalNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
SHARED_LOCK(mutex_)
*out_buffer_len = 0;
int idx = 0;
for (const auto& metric : train_metric_) {
......@@ -763,6 +779,7 @@ class Booster {
}
int GetFeatureNames(char** out_strs, const int len, const size_t buffer_len, size_t *out_buffer_len) const {
SHARED_LOCK(mutex_)
*out_buffer_len = 0;
int idx = 0;
for (const auto& name : boosting_->FeatureNames()) {
......@@ -792,7 +809,7 @@ class Booster {
/*! \brief Training objective function */
std::unique_ptr<ObjectiveFunction> objective_fun_;
/*! \brief mutex for threading safe call */
mutable std::mutex mutex_;
mutable yamc::alternate::shared_mutex mutex_;
};
} // namespace LightGBM
......@@ -1916,7 +1933,8 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
API_END();
}
......@@ -1960,7 +1978,7 @@ int LGBM_BoosterPredictForCSRSingleRowFast(FastConfigHandle fastConfig_handle,
API_BEGIN();
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
auto get_row_fun = RowFunctionFromCSR<int>(indptr, indptr_type, indices, data, fastConfig->data_type, nindptr, nelem);
fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config, out_result, out_len);
API_END();
}
......@@ -2058,7 +2076,8 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
}
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len);
ref_booster->SetSingleRowPredictor(num_iteration, predict_type, config);
ref_booster->PredictSingleRow(predict_type, ncol, get_row_fun, config, out_result, out_len);
API_END();
}
......@@ -2092,7 +2111,7 @@ int LGBM_BoosterPredictForMatSingleRowFast(FastConfigHandle fastConfig_handle,
FastConfig *fastConfig = reinterpret_cast<FastConfig*>(fastConfig_handle);
// Single row in row-major format:
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, fastConfig->ncol, fastConfig->data_type, 1);
fastConfig->booster->PredictSingleRow(num_iteration, predict_type, fastConfig->ncol,
fastConfig->booster->PredictSingleRow(predict_type, fastConfig->ncol,
get_row_fun, fastConfig->config,
out_result, out_len);
API_END();
......
......@@ -244,6 +244,7 @@
<ClInclude Include="..\include\LightGBM\prediction_early_stop.h" />
<ClInclude Include="..\include\LightGBM\tree.h" />
<ClInclude Include="..\include\LightGBM\tree_learner.h" />
<ClInclude Include="..\include\LightGBM\utils\yamc\alternate_shared_mutex.hpp" />
<ClInclude Include="..\include\LightGBM\utils\array_args.h" />
<ClInclude Include="..\include\LightGBM\utils\common.h" />
<ClInclude Include="..\include\LightGBM\utils\file_io.h" />
......@@ -255,6 +256,8 @@
<ClInclude Include="..\include\LightGBM\utils\random.h" />
<ClInclude Include="..\include\LightGBM\utils\text_reader.h" />
<ClInclude Include="..\include\LightGBM\utils\threading.h" />
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_rwlock_sched.hpp" />
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_shared_lock.hpp" />
<ClInclude Include="..\src\application\predictor.hpp" />
<ClInclude Include="..\src\boosting\gbdt.h" />
<ClInclude Include="..\src\boosting\dart.hpp" />
......
......@@ -210,6 +210,15 @@
<ClInclude Include="..\src\treelearner\col_sampler.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\alternate_shared_mutex.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_rwlock_sched.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\utils\yamc\yamc_shared_lock.hpp">
<Filter>include\LightGBM\utils\yamc</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment