Commit b0087754 authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

added kwargs in refit for passing them to predict method (#1629)

parent ca39d3f8
......@@ -2095,7 +2095,7 @@ class Booster(object):
num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
def refit(self, data, label, decay_rate=0.9):
def refit(self, data, label, decay_rate=0.9, **kwargs):
"""Refit the existing Booster by new data.
Parameters
......@@ -2107,6 +2107,8 @@ class Booster(object):
Label for refit.
decay_rate : float, optional (default=0.9)
Decay rate of refit, will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
**kwargs : other parameters for refit
These parameters will be passed to ``predict`` method.
Returns
-------
......@@ -2114,9 +2116,8 @@ class Booster(object):
Refitted Booster.
"""
predictor = self._to_predictor()
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow = leaf_preds.shape[0]
ncol = leaf_preds.shape[1]
leaf_preds = predictor.predict(data, -1, pred_leaf=True, **kwargs)
nrow, ncol = leaf_preds.shape
train_set = Dataset(data, label)
new_booster = Booster(self.params, train_set, silent=True)
# Copy models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment