Unverified Commit 2db6377a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

add support of refit-decay (#1603)

* add support of refit-decay

* add refit into c_api

* add test

* update document

* Update basic.py

* Update test_engine.py

* Update basic.py

* Update test_engine.py

* fix comments

* update test

* fix the comments

* Update test_engine.py
parent b1bbebaa
......@@ -49,7 +49,7 @@ Core Parameters
- ``refit``, for refitting existing models with new data, aliases: ``refit_tree``
- **Note**: can be used only in CLI version
- **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent functions
- ``objective`` :raw-html:`<a id="objective" title="Permalink to this parameter" href="#objective">&#x1F517;&#xFE0E;</a>`, default = ``regression``, type = enum, options: ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``, ``gammma``, ``tweedie``, ``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``, aliases: ``objective_type``, ``app``, ``application``
......@@ -364,6 +364,12 @@ Learning Control Parameters
- see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example
- ``refit_decay_rate`` :raw-html:`<a id="refit_decay_rate" title="Permalink to this parameter" href="#refit_decay_rate">&#x1F517;&#xFE0E;</a>`, default = ``0.9``, type = double, constraints: ``0.0 <= refit_decay_rate <= 1.0``
- decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees
- used only in ``refit`` task in CLI version or as argument in ``refit`` function in language-specific package
IO Parameters
-------------
......
......@@ -427,6 +427,16 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_l
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished);
/*!
* \brief Refit the tree model using the new data (online learning)
* \param handle handle
* \param leaf_preds
* \param nrow number of rows of leaf_preds
* \param ncol number of columns of leaf_preds
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
* this can be used to support customized loss function
......
......@@ -93,7 +93,7 @@ public:
// desc = ``predict``, for prediction, aliases: ``prediction``, ``test``
// desc = ``convert_model``, for converting model file into if-else format, see more information in `IO Parameters <#io-parameters>`__
// desc = ``refit``, for refitting existing models with new data, aliases: ``refit_tree``
// desc = **Note**: can be used only in CLI version
// desc = **Note**: can be used only in CLI version; for language-specific packages you can use the correspondent functions
TaskType task = TaskType::kTrain;
// [doc-only]
......@@ -368,6 +368,12 @@ public:
// desc = see `this file <https://github.com/Microsoft/LightGBM/tree/master/examples/binary_classification/forced_splits.json>`__ as an example
std::string forcedsplits_filename = "";
// check = >=0.0
// check = <=1.0
// desc = decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees
// desc = used only in ``refit`` task in CLI version or as argument in ``refit`` function in language-specific package
double refit_decay_rate = 0.9;
#pragma endregion
#pragma region IO Parameters
......
......@@ -1459,6 +1459,7 @@ class Booster(object):
self.model_from_string(params['model_str'])
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')
self.params = params.copy()
def __del__(self):
try:
......@@ -1624,6 +1625,7 @@ class Booster(object):
_safe_call(_LIB.LGBM_BoosterResetParameter(
self.handle,
c_str(params_str)))
self.params.update(params)
return self
def update(self, train_set=None, fobj=None):
......@@ -2019,6 +2021,43 @@ class Booster(object):
num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
def refit(self, data, label, decay_rate=0.9):
"""Refit the existing Booster by new data.
Parameters
----------
data : string, numpy array or scipy.sparse
Data source for refit.
If string, it represents the path to txt file.
label : list or numpy 1-D array
Label for refit.
decay_rate : float, optional (default=0.9)
Decay rate of refit, will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
Returns
-------
result : Booster
Refitted Booster.
"""
predictor = self._to_predictor()
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow = leaf_preds.shape[0]
ncol = leaf_preds.shape[1]
train_set = Dataset(data, label)
new_booster = Booster(self.params, train_set, silent=True)
# Copy models
_safe_call(_LIB.LGBM_BoosterMerge(
new_booster.handle,
predictor.handle))
leaf_preds = leaf_preds.reshape(-1)
ptr_data, type_ptr_data, _ = c_int_array(leaf_preds)
_safe_call(_LIB.LGBM_BoosterRefit(
new_booster.handle,
ptr_data,
ctypes.c_int(nrow),
ctypes.c_int(ncol)))
return new_booster
def get_leaf_output(self, tree_id, leaf_id):
"""Get the output of a leaf.
......
......@@ -348,6 +348,7 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_data_; ++i) {
leaf_pred[i] = tree_leaf_prediction[i][model_index];
CHECK(leaf_pred[i] < models_[model_index]->num_leaves());
}
size_t bias = static_cast<size_t>(tree_id) * num_data_;
auto grad = gradients_.data() + bias;
......
......@@ -182,6 +182,17 @@ public:
return boosting_->TrainOneIter(nullptr, nullptr);
}
void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
}
}
boosting_->RefitTree(v_leaf_preds);
}
bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
std::lock_guard<std::mutex> lock(mutex_);
return boosting_->TrainOneIter(gradients, hessians);
......@@ -956,6 +967,13 @@ int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
API_END();
}
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->Refit(leaf_preds, nrow, ncol);
API_END();
}
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
......
......@@ -199,6 +199,7 @@ std::unordered_set<std::string> Config::parameter_set({
"monotone_constraints",
"feature_contri",
"forcedsplits_filename",
"refit_decay_rate",
"verbosity",
"max_bin",
"min_data_in_bin",
......@@ -368,6 +369,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetString(params, "forcedsplits_filename", &forcedsplits_filename);
GetDouble(params, "refit_decay_rate", &refit_decay_rate);
CHECK(refit_decay_rate >=0.0);
CHECK(refit_decay_rate <=1.0);
GetInt(params, "verbosity", &verbosity);
GetInt(params, "max_bin", &max_bin);
......@@ -554,6 +559,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast<int8_t, int>(monotone_constraints),",") << "]\n";
str_buf << "[feature_contri: " << Common::Join(feature_contri,",") << "]\n";
str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n";
str_buf << "[refit_decay_rate: " << refit_decay_rate << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
......
......@@ -238,7 +238,9 @@ Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t*
}
double output = FeatureHistogram::CalculateSplittedLeafOutput(sum_grad, sum_hess,
config_->lambda_l1, config_->lambda_l2, config_->max_delta_step);
tree->SetLeafOutput(i, output* tree->shrinkage());
auto old_leaf_output = tree->LeafOutput(i);
auto new_leaf_output = output * tree->shrinkage();
tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output);
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
......
......@@ -633,3 +633,21 @@ class TestEngine(unittest.TestCase):
}
constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model)
def test_refit(self):
X, y = load_breast_cancer(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1,
'min_data': 10
}
lgb_train = lgb.Dataset(X_train, y_train)
gbm = lgb.train(params, lgb_train,
num_boost_round=20,
verbose_eval=False)
err_pred = log_loss(y_test, gbm.predict(X_test))
new_gbm = gbm.refit(X_test, y_test)
new_err_pred = log_loss(y_test, new_gbm.predict(X_test))
self.assertGreater(err_pred, new_err_pred)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment