"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bb5d5711075d5442944211fde8bdb8eea5a4c6b3"
Commit 6cc1dd94 authored by Tsukasa OMOTO's avatar Tsukasa OMOTO Committed by Guolin Ke
Browse files

python-package: fix creating eval_set in LGBMClassifier (#451)

* python-package: fix creating eval_set in LGBMClassifier

* replace elements in eval_set directly
parent fedf3971
...@@ -572,7 +572,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase): ...@@ -572,7 +572,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
feature_name='auto', categorical_feature='auto', feature_name='auto', categorical_feature='auto',
callbacks=None): callbacks=None):
self._le = LGBMLabelEncoder().fit(y) self._le = LGBMLabelEncoder().fit(y)
y = self._le.transform(y) _y = self._le.transform(y)
self.classes = self._le.classes_ self.classes = self._le.classes_
self.n_classes = len(self.classes_) self.n_classes = len(self.classes_)
...@@ -590,9 +590,15 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase): ...@@ -590,9 +590,15 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
eval_metric = 'binary_error' eval_metric = 'binary_error'
if eval_set is not None: if eval_set is not None:
eval_set = [(x[0], self._le.transform(x[1])) for x in eval_set] if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, (valid_x, valid_y) in enumerate(eval_set):
if valid_x is X and valid_y is y:
eval_set[i] = (valid_x, _y)
else:
eval_set[i] = (valid_x, self._le.transform(valid_y))
super(LGBMClassifier, self).fit(X, y, sample_weight=sample_weight, super(LGBMClassifier, self).fit(X, _y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set, init_score=init_score, eval_set=eval_set,
eval_names=eval_names, eval_names=eval_names,
eval_sample_weight=eval_sample_weight, eval_sample_weight=eval_sample_weight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment