"tests/vscode:/vscode.git/clone" did not exist on "2e962c779f0d4f44da46c5be9851604b9acb5708"
Unverified Commit 81e2485a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

add indices in shuffle model. (#1710)

* add indices in shuffle model.

* fix pep

* fix bug
parent 172caee1
......@@ -47,7 +47,7 @@ public:
/*!
* \brief Shuffle Existing Models
* \param start_iter Index of the first iteration to shuffle
* \param end_iter End of the iteration range to shuffle
*        (the GBDT implementation treats values <= 0 as "up to the total number of iterations")
*/
virtual void ShuffleModels() = 0;
virtual void ShuffleModels(int start_iter, int end_iter) = 0;
/*!
* \brief Pure virtual hook that receives a training dataset, an objective function and a list of training metrics.
*        NOTE(review): presumably rebinds the booster to new training data -- confirm against the implementations.
*/
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0;
......
......@@ -374,7 +374,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Shuffle Models
* \param handle Handle of the booster whose models are shuffled
* \param start_iter Index of the first iteration to shuffle
* \param end_iter End of the iteration range to shuffle
*        (the GBDT implementation treats values <= 0 as "up to the total number of iterations")
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle);
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter);
/*!
* \brief Merge model in two booster to first handle
......
......@@ -1945,15 +1945,26 @@ class Booster(object):
_save_pandas_categorical(filename, self.pandas_categorical)
return self
def shuffle_models(self, start_iteration=0, end_iteration=-1):
    """Shuffle models.

    Parameters
    ----------
    start_iteration : int, optional (default=0)
        The first iteration that will be shuffled.
    end_iteration : int, optional (default=-1)
        The last iteration that will be shuffled.
        If <= 0, means the last available iteration.

    Returns
    -------
    self : Booster
        Booster with shuffled models.
    """
    # Forward this method's own parameters: the C entry point takes
    # (handle, start_iter, end_iter), and the names start_iter / end_iter
    # do not exist in this scope.
    _safe_call(_LIB.LGBM_BoosterShuffleModels(
        self.handle,
        ctypes.c_int(start_iteration),
        ctypes.c_int(end_iteration)))
    return self
def model_from_string(self, model_str, verbose=True):
......
......@@ -70,16 +70,21 @@ public:
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
}
void ShuffleModels() override {
void ShuffleModels(int start_iter, int end_iter) override {
int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_;
start_iter = std::max(0, start_iter);
if (end_iter <= 0) {
end_iter = total_iter;
}
end_iter = std::min(total_iter, end_iter);
auto original_models = std::move(models_);
std::vector<int> indices(total_iter);
for (int i = 0; i < total_iter; ++i) {
indices[i] = i;
}
Random tmp_rand(17);
for (int i = 0; i < total_iter - 1; ++i) {
int j = tmp_rand.NextShort(i + 1, total_iter);
for (int i = start_iter; i < end_iter - 1; ++i) {
int j = tmp_rand.NextShort(i + 1, end_iter);
std::swap(indices[i], indices[j]);
}
models_ = std::vector<std::unique_ptr<Tree>>();
......
......@@ -294,9 +294,9 @@ public:
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
}
/*!
* \brief Shuffle the models held by boosting_ over the given iteration range
* \param start_iter Forwarded unchanged to boosting_->ShuffleModels
* \param end_iter Forwarded unchanged to boosting_->ShuffleModels
*/
void ShuffleModels() {
void ShuffleModels(int start_iter, int end_iter) {
// Hold mutex_ for the whole call so the shuffle runs under the booster's lock.
std::lock_guard<std::mutex> lock(mutex_);
boosting_->ShuffleModels();
boosting_->ShuffleModels(start_iter, end_iter);
}
int GetEvalCounts() const {
......@@ -919,10 +919,10 @@ int LGBM_BoosterFree(BoosterHandle handle) {
API_END();
}
/*!
* \brief C API entry point: shuffle the models of the booster behind handle
* \param handle Opaque handle; reinterpreted as a Booster*
* \param start_iter Forwarded unchanged to Booster::ShuffleModels
* \param end_iter Forwarded unchanged to Booster::ShuffleModels
*/
int LGBM_BoosterShuffleModels(BoosterHandle handle) {
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
// The body is bracketed by the API_BEGIN / API_END macros, like the other C API functions in this file.
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->ShuffleModels();
ref_booster->ShuffleModels(start_iter, end_iter);
API_END();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment