Unverified Commit c9bcba44 authored by Tsukasa OMOTO's avatar Tsukasa OMOTO Committed by GitHub
Browse files

[python] fix creating train_set in fit (#1916)

* [python] fix creating train_set in fit

https://github.com/Microsoft/LightGBM/blob/cc99f0d36ae929eb02b22a072823ab7c6d3155ab/python-package/lightgbm/sklearn.py#L519
may be False even if valid_data[0] is actually X and valid_data[1] is actually y, because `check_X_y` might return copies of X and y.
https://scikit-learn.org/0.20/modules/generated/sklearn.utils.check_X_y.html

cf. https://github.com/Microsoft/LightGBM/pull/451

* use assertIn
parent cba82447
......@@ -480,8 +480,10 @@ class LGBMModel(_LGBMModelBase):
params['metric'] = set(original_metric + eval_metric)
if not isinstance(X, DataFrame):
X, y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(X, y, sample_weight)
_X, _y = _LGBMCheckXY(X, y, accept_sparse=True, force_all_finite=False, ensure_min_samples=2)
_LGBMCheckConsistentLength(_X, _y, sample_weight)
else:
_X, _y = X, y
if self.class_weight is not None:
class_sample_weight = _LGBMComputeSampleWeight(self.class_weight, y)
......@@ -490,13 +492,13 @@ class LGBMModel(_LGBMModelBase):
else:
sample_weight = np.multiply(sample_weight, class_sample_weight)
self._n_features = X.shape[1]
self._n_features = _X.shape[1]
def _construct_dataset(X, y, sample_weight, init_score, group, params):
    """Build a Dataset from the given data and apply the initial score.

    The Dataset is created with ``y`` as its label, ``sample_weight`` as its
    weight, plus ``group`` and ``params``; the value returned is whatever
    ``set_init_score(init_score)`` returns for that Dataset.
    """
    dataset = Dataset(X, label=y, weight=sample_weight, group=group, params=params)
    return dataset.set_init_score(init_score)
train_set = _construct_dataset(X, y, sample_weight, init_score, group, params)
train_set = _construct_dataset(_X, _y, sample_weight, init_score, group, params)
valid_sets = []
if eval_set is not None:
......
......@@ -279,3 +279,13 @@ class TestSklearn(unittest.TestCase):
self.assertRaises(AssertionError,
np.testing.assert_allclose,
res_engine, res_sklearn_params)
def test_evaluate_train_set(self):
    """Fit with the training pair included in eval_set and check the results.

    The training data is passed as the first eval_set entry and the held-out
    split as the second; the recorded results must then contain an 'l2'
    metric under both 'training' and 'valid_1'.
    """
    X, y = load_boston(True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    regressor = lgb.LGBMRegressor(n_estimators=10, silent=True)
    eval_pairs = [(X_train, y_train), (X_test, y_test)]
    regressor.fit(X_train, y_train, eval_set=eval_pairs, verbose=False)
    # Same assertion order as before: set name first, then its 'l2' metric.
    for set_name in ('training', 'valid_1'):
        self.assertIn(set_name, regressor.evals_result_)
        self.assertIn('l2', regressor.evals_result_[set_name])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment