Unverified Commit 81e2485a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

add indices in shuffle model. (#1710)

* add indexs in shuffle model.

* fix pep

* fix bug
parent 172caee1
...@@ -47,7 +47,7 @@ public: ...@@ -47,7 +47,7 @@ public:
/*! /*!
* \brief Shuffle Existing Models * \brief Shuffle Existing Models
*/ */
virtual void ShuffleModels() = 0; virtual void ShuffleModels(int start_iter, int end_iter) = 0;
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0; const std::vector<const Metric*>& training_metrics) = 0;
......
...@@ -374,7 +374,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle); ...@@ -374,7 +374,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
/*! /*!
* \brief Shuffle Models * \brief Shuffle Models
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle); LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter);
/*! /*!
* \brief Merge model in two booster to first handle * \brief Merge model in two booster to first handle
......
...@@ -1945,15 +1945,26 @@ class Booster(object): ...@@ -1945,15 +1945,26 @@ class Booster(object):
_save_pandas_categorical(filename, self.pandas_categorical) _save_pandas_categorical(filename, self.pandas_categorical)
return self return self
def shuffle_models(self): def shuffle_models(self, start_iteration=0, end_iteration=-1):
"""Shuffle models. """Shuffle models.
Parameters
----------
start_iteration : int, optional (default=0)
Index of the iteration that will start to shuffle.
end_iteration : int, optional (default=-1)
The last iteration that will be shuffled.
If <= 0, means the last iteration.
Returns Returns
------- -------
self : Booster self : Booster
Booster with shuffled models. Booster with shuffled models.
""" """
_safe_call(_LIB.LGBM_BoosterShuffleModels(self.handle)) _safe_call(_LIB.LGBM_BoosterShuffleModels(
self.handle,
ctypes.c_int(start_iter),
ctypes.c_int(end_iter)))
return self return self
def model_from_string(self, model_str, verbose=True): def model_from_string(self, model_str, verbose=True):
......
...@@ -70,16 +70,21 @@ public: ...@@ -70,16 +70,21 @@ public:
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
} }
void ShuffleModels() override { void ShuffleModels(int start_iter, int end_iter) override {
int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_; int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_;
start_iter = std::max(0, start_iter);
if (end_iter <= 0) {
end_iter = total_iter;
}
end_iter = std::min(total_iter, end_iter);
auto original_models = std::move(models_); auto original_models = std::move(models_);
std::vector<int> indices(total_iter); std::vector<int> indices(total_iter);
for (int i = 0; i < total_iter; ++i) { for (int i = 0; i < total_iter; ++i) {
indices[i] = i; indices[i] = i;
} }
Random tmp_rand(17); Random tmp_rand(17);
for (int i = 0; i < total_iter - 1; ++i) { for (int i = start_iter; i < end_iter - 1; ++i) {
int j = tmp_rand.NextShort(i + 1, total_iter); int j = tmp_rand.NextShort(i + 1, end_iter);
std::swap(indices[i], indices[j]); std::swap(indices[i], indices[j]);
} }
models_ = std::vector<std::unique_ptr<Tree>>(); models_ = std::vector<std::unique_ptr<Tree>>();
......
...@@ -294,9 +294,9 @@ public: ...@@ -294,9 +294,9 @@ public:
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val); dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
} }
void ShuffleModels() { void ShuffleModels(int start_iter, int end_iter) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
boosting_->ShuffleModels(); boosting_->ShuffleModels(start_iter, end_iter);
} }
int GetEvalCounts() const { int GetEvalCounts() const {
...@@ -919,10 +919,10 @@ int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -919,10 +919,10 @@ int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
int LGBM_BoosterShuffleModels(BoosterHandle handle) { int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ShuffleModels(); ref_booster->ShuffleModels(start_iter, end_iter);
API_END(); API_END();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment