Unverified Commit 941068ee authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

add start_iteration in model saving (#1565)

* add start_iteration in model saving

* fix test

* shuffle models ability

* fix bug

* update document

* refine

* Update engine.py

* Update basic.py

* fix comments

* fix comment
parent 3400e389
...@@ -44,6 +44,11 @@ public: ...@@ -44,6 +44,11 @@ public:
*/ */
virtual void MergeFrom(const Boosting* other) = 0; virtual void MergeFrom(const Boosting* other) = 0;
/*!
* \brief Shuffle Existing Models
*/
virtual void ShuffleModels() = 0;
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function, virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
const std::vector<const Metric*>& training_metrics) = 0; const std::vector<const Metric*>& training_metrics) = 0;
...@@ -163,10 +168,11 @@ public: ...@@ -163,10 +168,11 @@ public:
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
* \param start_iteration Index (0-based) of the first iteration to dump; iterations before it are skipped
* \param num_iteration Number of iterations that want to dump, -1 means dump all * \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model * \return Json format string of model
*/ */
virtual std::string DumpModel(int num_iteration) const = 0; virtual std::string DumpModel(int start_iteration, int num_iteration) const = 0;
/*! /*!
* \brief Translate model to if-else statement * \brief Translate model to if-else statement
...@@ -185,19 +191,21 @@ public: ...@@ -185,19 +191,21 @@ public:
/*! /*!
* \brief Save model to file * \brief Save model to file
* \param start_iteration Index (0-based) of the first iteration to save; iterations before it are skipped
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not * \param is_finish Is training finished or not
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return true if succeeded * \return true if succeeded
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const = 0; virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const = 0;
/*! /*!
* \brief Save model to string * \brief Save model to string
* \param start_iteration Index (0-based) of the first iteration to save; iterations before it are skipped
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded * \return Non-empty string if succeeded
*/ */
virtual std::string SaveModelToString(int num_iterations) const = 0; virtual std::string SaveModelToString(int start_iteration, int num_iterations) const = 0;
/*! /*!
* \brief Restore from a serialized string * \brief Restore from a serialized string
......
...@@ -371,6 +371,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString( ...@@ -371,6 +371,11 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle); LIGHTGBM_C_EXPORT int LGBM_BoosterFree(BoosterHandle handle);
/*!
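* \brief Shuffle the order of the models stored in the booster
* \param handle Handle of the booster whose models are shuffled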
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterShuffleModels(BoosterHandle handle);
/*! /*!
* \brief Merge model in two booster to first handle * \brief Merge model in two booster to first handle
* \param handle handle, will merge other handle to this * \param handle handle, will merge other handle to this
...@@ -682,6 +687,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -682,6 +687,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
const char* filename); const char* filename);
...@@ -695,6 +701,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -695,6 +701,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModel(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
...@@ -710,6 +717,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -710,6 +717,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSaveModelToString(BoosterHandle handle,
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle, LIGHTGBM_C_EXPORT int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
......
...@@ -1401,7 +1401,7 @@ class Booster(object): ...@@ -1401,7 +1401,7 @@ class Booster(object):
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_file) self.pandas_categorical = _load_pandas_categorical(model_file)
elif 'model_str' in params: elif 'model_str' in params:
self._load_model_from_string(params['model_str']) self.model_from_string(params['model_str'])
else: else:
raise TypeError('Need at least one training dataset or model file to create booster instance') raise TypeError('Need at least one training dataset or model file to create booster instance')
...@@ -1421,7 +1421,7 @@ class Booster(object): ...@@ -1421,7 +1421,7 @@ class Booster(object):
return self.__deepcopy__(None) return self.__deepcopy__(None)
def __deepcopy__(self, _): def __deepcopy__(self, _):
model_str = self._save_model_to_string(num_iteration=-1) model_str = self.model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str}) booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical booster.pandas_categorical = self.pandas_categorical
return booster return booster
...@@ -1432,7 +1432,7 @@ class Booster(object): ...@@ -1432,7 +1432,7 @@ class Booster(object):
this.pop('train_set', None) this.pop('train_set', None)
this.pop('valid_sets', None) this.pop('valid_sets', None)
if handle is not None: if handle is not None:
this["handle"] = self._save_model_to_string(num_iteration=-1) this["handle"] = self.model_to_string(num_iteration=-1)
return this return this
def __setstate__(self, state): def __setstate__(self, state):
...@@ -1710,7 +1710,7 @@ class Booster(object): ...@@ -1710,7 +1710,7 @@ class Booster(object):
return [item for i in range_(1, self.__num_dataset) return [item for i in range_(1, self.__num_dataset)
for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)] for item in self.__inner_eval(self.name_valid_sets[i - 1], i, feval)]
def save_model(self, filename, num_iteration=None): def save_model(self, filename, num_iteration=None, start_iteration=0):
"""Save Booster to file. """Save Booster to file.
Parameters Parameters
...@@ -1721,17 +1721,38 @@ class Booster(object): ...@@ -1721,17 +1721,38 @@ class Booster(object):
Index of the iteration that should be saved. Index of the iteration that should be saved.
If None, if the best iteration exists, it is saved; otherwise, all iterations are saved. If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
If <= 0, all iterations are saved. If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Index (0-based) of the first iteration that should be saved.
""" """
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
_safe_call(_LIB.LGBM_BoosterSaveModel( _safe_call(_LIB.LGBM_BoosterSaveModel(
self.handle, self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(filename))) c_str(filename)))
_save_pandas_categorical(filename, self.pandas_categorical) _save_pandas_categorical(filename, self.pandas_categorical)
def _load_model_from_string(self, model_str, verbose=True): def shuffle_models(self):
"""[Private] Load model from string""" """Shuffle models.
"""
_safe_call(_LIB.LGBM_BoosterShuffleModels(self.handle))
def model_from_string(self, model_str, verbose=True):
"""Load Booster from a string.
Parameters
----------
model_str: string
Model will be loaded from this string.
verbose: bool, optional (default=True)
Set to False to disable log when loading model.
Returns
-------
result: Booster
Loaded Booster object.
"""
if self.handle is not None: if self.handle is not None:
_safe_call(_LIB.LGBM_BoosterFree(self.handle)) _safe_call(_LIB.LGBM_BoosterFree(self.handle))
self._free_buffer() self._free_buffer()
...@@ -1748,9 +1769,25 @@ class Booster(object): ...@@ -1748,9 +1769,25 @@ class Booster(object):
if verbose: if verbose:
print('Finished loading model, total used %d iterations' % (int(out_num_iterations.value))) print('Finished loading model, total used %d iterations' % (int(out_num_iterations.value)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
return self
def model_to_string(self, num_iteration=None, start_iteration=0):
"""Save Booster to string.
def _save_model_to_string(self, num_iteration=None): Parameters
"""[Private] Save model to string""" ----------
num_iteration : int or None, optional (default=None)
Index of the iteration that should be saved.
If None, if the best iteration exists, it is saved; otherwise, all iterations are saved.
If <= 0, all iterations are saved.
start_iteration : int, optional (default=0)
Index (0-based) of the first iteration that should be saved.
Returns
-------
result: string
String representation of Booster.
"""
if num_iteration is None: if num_iteration is None:
num_iteration = self.best_iteration num_iteration = self.best_iteration
buffer_len = 1 << 20 buffer_len = 1 << 20
...@@ -1759,6 +1796,7 @@ class Booster(object): ...@@ -1759,6 +1796,7 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterSaveModelToString( _safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle, self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int64(buffer_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
...@@ -1770,13 +1808,14 @@ class Booster(object): ...@@ -1770,13 +1808,14 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterSaveModelToString( _safe_call(_LIB.LGBM_BoosterSaveModelToString(
self.handle, self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
return string_buffer.value.decode() return string_buffer.value.decode()
def dump_model(self, num_iteration=None): def dump_model(self, num_iteration=None, start_iteration=0):
"""Dump Booster to json format. """Dump Booster to json format.
Parameters Parameters
...@@ -1785,6 +1824,8 @@ class Booster(object): ...@@ -1785,6 +1824,8 @@ class Booster(object):
Index of the iteration that should be dumped. Index of the iteration that should be dumped.
If None, if the best iteration exists, it is dumped; otherwise, all iterations are dumped. If None, if the best iteration exists, it is dumped; otherwise, all iterations are dumped.
If <= 0, all iterations are dumped. If <= 0, all iterations are dumped.
start_iteration : int, optional (default=0)
Index (0-based) of the first iteration that should be dumped.
Returns Returns
------- -------
...@@ -1799,6 +1840,7 @@ class Booster(object): ...@@ -1799,6 +1840,7 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel( _safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle, self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int64(buffer_len), ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
...@@ -1810,6 +1852,7 @@ class Booster(object): ...@@ -1810,6 +1852,7 @@ class Booster(object):
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)]) ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterDumpModel( _safe_call(_LIB.LGBM_BoosterDumpModel(
self.handle, self.handle,
ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
......
...@@ -230,7 +230,7 @@ def train(params, train_set, num_boost_round=100, ...@@ -230,7 +230,7 @@ def train(params, train_set, num_boost_round=100,
for dataset_name, eval_name, score, _ in evaluation_result_list: for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster: if not keep_training_booster:
booster._load_model_from_string(booster._save_model_to_string(), False) booster.model_from_string(booster.model_to_string(), False)
booster.free_dataset() booster.free_dataset()
return booster return booster
......
...@@ -203,7 +203,7 @@ void Application::InitTrain() { ...@@ -203,7 +203,7 @@ void Application::InitTrain() {
void Application::Train() { void Application::Train() {
Log::Info("Started training..."); Log::Info("Started training...");
boosting_->Train(config_.snapshot_freq, config_.output_model); boosting_->Train(config_.snapshot_freq, config_.output_model);
boosting_->SaveModelToFile(-1, config_.output_model.c_str()); boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
// convert model to if-else statement code // convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) { if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str()); boosting_->SaveModelToIfElse(-1, config_.convert_model.c_str());
...@@ -237,7 +237,7 @@ void Application::Predict() { ...@@ -237,7 +237,7 @@ void Application::Predict() {
boosting_->Init(&config_, train_data_.get(), objective_fun_.get(), boosting_->Init(&config_, train_data_.get(), objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
boosting_->RefitTree(pred_leaf); boosting_->RefitTree(pred_leaf);
boosting_->SaveModelToFile(-1, config_.output_model.c_str()); boosting_->SaveModelToFile(0, -1, config_.output_model.c_str());
Log::Info("Finished RefitTree"); Log::Info("Finished RefitTree");
} else { } else {
// create predictor // create predictor
......
...@@ -330,7 +330,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) { ...@@ -330,7 +330,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
if (snapshot_freq > 0 if (snapshot_freq > 0
&& (iter + 1) % snapshot_freq == 0) { && (iter + 1) % snapshot_freq == 0) {
std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1); std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
SaveModelToFile(-1, snapshot_out.c_str()); SaveModelToFile(0, -1, snapshot_out.c_str());
} }
} }
} }
......
...@@ -70,6 +70,28 @@ public: ...@@ -70,6 +70,28 @@ public:
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
} }
/*!
* \brief Shuffle the order of the existing boosting iterations in place.
*
* Iterations are permuted as whole units: the num_tree_per_iteration_ trees
* that belong to one iteration stay adjacent and keep their relative order.
* The random generator is constructed with the fixed seed 17.
* Trees are moved (not copied) into their new positions.
*/
void ShuffleModels() override {
  const int total_iter = static_cast<int>(models_.size()) / num_tree_per_iteration_;
  // Take ownership of the current trees; models_ is rebuilt from them below.
  auto original_models = std::move(models_);
  // indices[k] is the original iteration that ends up at position k.
  std::vector<int> indices(total_iter);
  for (int i = 0; i < total_iter; ++i) {
    indices[i] = i;
  }
  Random tmp_rand(17);
  // NOTE(review): NextShort(i + 1, total_iter) presumably draws from
  // [i + 1, total_iter) - confirm against the Random class.
  for (int i = 0; i < total_iter - 1; ++i) {
    const int j = tmp_rand.NextShort(i + 1, total_iter);
    std::swap(indices[i], indices[j]);
  }
  // models_ was moved from above, so reset it to a known-empty state first.
  models_.clear();
  models_.reserve(original_models.size());
  for (int i = 0; i < total_iter; ++i) {
    for (int j = 0; j < num_tree_per_iteration_; ++j) {
      const int tree_idx = indices[i] * num_tree_per_iteration_ + j;
      // indices is a permutation, so every source slot is moved from exactly
      // once; transferring ownership avoids deep-copying each Tree.
      models_.push_back(std::move(original_models[tree_idx]));
    }
  }
}
/*! /*!
* \brief Reset the training data * \brief Reset the training data
* \param train_data New Training data * \param train_data New Training data
...@@ -211,10 +233,11 @@ public: ...@@ -211,10 +233,11 @@ public:
/*! /*!
* \brief Dump model to json format string * \brief Dump model to json format string
* \param start_iteration Index (0-based) of the first iteration to dump; iterations before it are skipped
* \param num_iteration Number of iterations that want to dump, -1 means dump all * \param num_iteration Number of iterations that want to dump, -1 means dump all
* \return Json format string of model * \return Json format string of model
*/ */
std::string DumpModel(int num_iteration) const override; std::string DumpModel(int start_iteration, int num_iteration) const override;
/*! /*!
* \brief Translate model to if-else statement * \brief Translate model to if-else statement
...@@ -233,18 +256,20 @@ public: ...@@ -233,18 +256,20 @@ public:
/*! /*!
* \brief Save model to file * \brief Save model to file
* \param start_iteration Index (0-based) of the first iteration to save; iterations before it are skipped
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to * \param filename Filename that want to save to
* \return is_finish Is training finished or not * \return is_finish Is training finished or not
*/ */
virtual bool SaveModelToFile(int num_iterations, const char* filename) const override; virtual bool SaveModelToFile(int start_iteration, int num_iterations, const char* filename) const override;
/*! /*!
* \brief Save model to string * \brief Save model to string
* \param start_iteration Index (0-based) of the first iteration to save; iterations before it are skipped
* \param num_iterations Number of model that want to save, -1 means save all * \param num_iterations Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded * \return Non-empty string if succeeded
*/ */
virtual std::string SaveModelToString(int num_iterations) const override; virtual std::string SaveModelToString(int start_iteration, int num_iterations) const override;
/*! /*!
* \brief Restore from a serialized buffer * \brief Restore from a serialized buffer
......
...@@ -12,7 +12,7 @@ namespace LightGBM { ...@@ -12,7 +12,7 @@ namespace LightGBM {
const std::string kModelVersion = "v2"; const std::string kModelVersion = "v2";
std::string GBDT::DumpModel(int num_iteration) const { std::string GBDT::DumpModel(int start_iteration, int num_iteration) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "{"; str_buf << "{";
...@@ -29,11 +29,16 @@ std::string GBDT::DumpModel(int num_iteration) const { ...@@ -29,11 +29,16 @@ std::string GBDT::DumpModel(int num_iteration) const {
str_buf << "\"tree_info\":["; str_buf << "\"tree_info\":[";
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
int total_iteration = num_used_model / num_tree_per_iteration_;
start_iteration = std::max(start_iteration, 0);
start_iteration = std::min(start_iteration, total_iteration);
if (num_iteration > 0) { if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model); int end_iteration = start_iteration + num_iteration;
num_used_model = std::min(end_iteration * num_tree_per_iteration_ , num_used_model);
} }
for (int i = 0; i < num_used_model; ++i) { int start_model = start_iteration * num_tree_per_iteration_;
if (i > 0) { for (int i = start_model; i < num_used_model; ++i) {
if (i > start_model) {
str_buf << ","; str_buf << ",";
} }
str_buf << "{"; str_buf << "{";
...@@ -232,7 +237,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const { ...@@ -232,7 +237,7 @@ bool GBDT::SaveModelToIfElse(int num_iteration, const char* filename) const {
return (bool)output_file; return (bool)output_file;
} }
std::string GBDT::SaveModelToString(int num_iteration) const { std::string GBDT::SaveModelToString(int start_iteration, int num_iteration) const {
std::stringstream ss; std::stringstream ss;
// output model type // output model type
...@@ -259,24 +264,31 @@ std::string GBDT::SaveModelToString(int num_iteration) const { ...@@ -259,24 +264,31 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n'; ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n';
int num_used_model = static_cast<int>(models_.size()); int num_used_model = static_cast<int>(models_.size());
int total_iteration = num_used_model / num_tree_per_iteration_;
start_iteration = std::max(start_iteration, 0);
start_iteration = std::min(start_iteration, total_iteration);
if (num_iteration > 0) { if (num_iteration > 0) {
num_used_model = std::min(num_iteration * num_tree_per_iteration_, num_used_model); int end_iteration = start_iteration + num_iteration;
num_used_model = std::min(end_iteration * num_tree_per_iteration_, num_used_model);
} }
std::vector<std::string> tree_strs(num_used_model); int start_model = start_iteration * num_tree_per_iteration_;
std::vector<size_t> tree_sizes(num_used_model);
std::vector<std::string> tree_strs(num_used_model - start_model);
std::vector<size_t> tree_sizes(num_used_model - start_model);
// output tree models // output tree models
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < num_used_model; ++i) { for (int i = start_model; i < num_used_model; ++i) {
tree_strs[i] = "Tree=" + std::to_string(i) + '\n'; const int idx = i - start_model;
tree_strs[i] += models_[i]->ToString() + '\n'; tree_strs[idx] = "Tree=" + std::to_string(idx) + '\n';
tree_sizes[i] = tree_strs[i].size(); tree_strs[idx] += models_[i]->ToString() + '\n';
tree_sizes[idx] = tree_strs[idx].size();
} }
ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n'; ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n';
ss << '\n'; ss << '\n';
for (int i = 0; i < num_used_model; ++i) { for (int i = 0; i < num_used_model - start_model; ++i) {
ss << tree_strs[i]; ss << tree_strs[i];
tree_strs[i].clear(); tree_strs[i].clear();
} }
...@@ -313,11 +325,11 @@ std::string GBDT::SaveModelToString(int num_iteration) const { ...@@ -313,11 +325,11 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
return ss.str(); return ss.str();
} }
bool GBDT::SaveModelToFile(int num_iteration, const char* filename) const { bool GBDT::SaveModelToFile(int start_iteration, int num_iteration, const char* filename) const {
/*! \brief File to write models */ /*! \brief File to write models */
std::ofstream output_file; std::ofstream output_file;
output_file.open(filename, std::ios::out | std::ios::binary); output_file.open(filename, std::ios::out | std::ios::binary);
std::string str_to_write = SaveModelToString(num_iteration); std::string str_to_write = SaveModelToString(start_iteration, num_iteration);
output_file.write(str_to_write.c_str(), str_to_write.size()); output_file.write(str_to_write.c_str(), str_to_write.size());
output_file.close(); output_file.close();
......
...@@ -253,8 +253,8 @@ public: ...@@ -253,8 +253,8 @@ public:
boosting_->GetPredictAt(data_idx, out_result, out_len); boosting_->GetPredictAt(data_idx, out_result, out_len);
} }
void SaveModelToFile(int num_iteration, const char* filename) { void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
boosting_->SaveModelToFile(num_iteration, filename); boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
} }
void LoadModelFromString(const char* model_str) { void LoadModelFromString(const char* model_str) {
...@@ -262,12 +262,12 @@ public: ...@@ -262,12 +262,12 @@ public:
boosting_->LoadModelFromString(model_str, len); boosting_->LoadModelFromString(model_str, len);
} }
std::string SaveModelToString(int num_iteration) { std::string SaveModelToString(int start_iteration, int num_iteration) {
return boosting_->SaveModelToString(num_iteration); return boosting_->SaveModelToString(start_iteration, num_iteration);
} }
std::string DumpModel(int num_iteration) { std::string DumpModel(int start_iteration,int num_iteration) {
return boosting_->DumpModel(num_iteration); return boosting_->DumpModel(start_iteration, num_iteration);
} }
std::vector<double> FeatureImportance(int num_iteration, int importance_type) { std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
...@@ -283,6 +283,11 @@ public: ...@@ -283,6 +283,11 @@ public:
dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val); dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
} }
/*!
* \brief Shuffle the models of the wrapped boosting object.
*
* mutex_ is held for the duration of the call.
*/
void ShuffleModels() {
  const std::lock_guard<std::mutex> guard(mutex_);
  boosting_->ShuffleModels();
}
int GetEvalCounts() const { int GetEvalCounts() const {
int ret = 0; int ret = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
...@@ -903,6 +908,13 @@ int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -903,6 +908,13 @@ int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
/*!
* \brief C API entry point: shuffle the models held by a booster.
* \param handle Opaque handle, reinterpreted as a Booster pointer
* \return Status produced by the API_BEGIN/API_END macros (presumably 0 on
*         success and -1 on failure, as documented for the other C API calls)
*/
int LGBM_BoosterShuffleModels(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels();
  API_END();
}
int LGBM_BoosterMerge(BoosterHandle handle, int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) { BoosterHandle other_handle) {
API_BEGIN(); API_BEGIN();
...@@ -1188,22 +1200,24 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1188,22 +1200,24 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
} }
int LGBM_BoosterSaveModel(BoosterHandle handle, int LGBM_BoosterSaveModel(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
const char* filename) { const char* filename) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->SaveModelToFile(num_iteration, filename); ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
API_END(); API_END();
} }
int LGBM_BoosterSaveModelToString(BoosterHandle handle, int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->SaveModelToString(num_iteration); std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len); std::memcpy(out_str, model.c_str(), *out_len);
...@@ -1212,13 +1226,14 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1212,13 +1226,14 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
} }
int LGBM_BoosterDumpModel(BoosterHandle handle, int LGBM_BoosterDumpModel(BoosterHandle handle,
int start_iteration,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
int64_t* out_len, int64_t* out_len,
char* out_str) { char* out_str) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string model = ref_booster->DumpModel(num_iteration); std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::memcpy(out_str, model.c_str(), *out_len); std::memcpy(out_str, model.c_str(), *out_len);
......
...@@ -597,7 +597,7 @@ LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle, ...@@ -597,7 +597,7 @@ LGBM_SE LGBM_BoosterSaveModel_R(LGBM_SE handle,
LGBM_SE filename, LGBM_SE filename,
LGBM_SE call_state) { LGBM_SE call_state) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_CHAR_PTR(filename))); CHECK_CALL(LGBM_BoosterSaveModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_CHAR_PTR(filename)));
R_API_END(); R_API_END();
} }
...@@ -610,7 +610,7 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -610,7 +610,7 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len)); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END(); R_API_END();
} }
...@@ -624,7 +624,7 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -624,7 +624,7 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
R_API_BEGIN(); R_API_BEGIN();
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), 0, R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len)); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len, static_cast<size_t>(out_len));
R_API_END(); R_API_END();
} }
...@@ -225,7 +225,7 @@ def test_booster(): ...@@ -225,7 +225,7 @@ def test_booster():
LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) LIB.LGBM_BoosterGetEval(booster, 0, ctypes.byref(out_len), result.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
if i % 10 == 0: if i % 10 == 0:
print('%d Iteration test AUC %f' % (i, result[0])) print('%d Iteration test AUC %f' % (i, result[0]))
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, 0, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
test_free_dataset(train) test_free_dataset(train)
test_free_dataset(test) test_free_dataset(test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment