Unverified Commit cbb9f4e7 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about validation sets in sklearn.py (#5724)

parent 90a4510c
...@@ -756,7 +756,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -756,7 +756,7 @@ class LGBMModel(_LGBMModelBase):
init_score=init_score, categorical_feature=categorical_feature, init_score=init_score, categorical_feature=categorical_feature,
params=params) params=params)
valid_sets = [] valid_sets: List[Dataset] = []
if eval_set is not None: if eval_set is not None:
def _get_meta_data(collection, name, i): def _get_meta_data(collection, name, i):
...@@ -1088,16 +1088,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1088,16 +1088,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
eval_metric = eval_metric_list eval_metric = eval_metric_list
# do not modify args, as it causes errors in model selection tools # do not modify args, as it causes errors in model selection tools
valid_sets = None valid_sets: Optional[List[Tuple]] = None
if eval_set is not None: if eval_set is not None:
if isinstance(eval_set, tuple): if isinstance(eval_set, tuple):
eval_set = [eval_set] eval_set = [eval_set]
valid_sets = [None] * len(eval_set) valid_sets = []
for i, (valid_x, valid_y) in enumerate(eval_set): for valid_x, valid_y in eval_set:
if valid_x is X and valid_y is y: if valid_x is X and valid_y is y:
valid_sets[i] = (valid_x, _y) valid_sets.append((valid_x, _y))
else: else:
valid_sets[i] = (valid_x, self._le.transform(valid_y)) valid_sets.append((valid_x, self._le.transform(valid_y)))
super().fit( super().fit(
X, X,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment