Unverified Commit 44014015 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add a few type hints in LGBMModel.fit() (#6470)

parent 8579d5e3
......@@ -454,6 +454,30 @@ _lgbmmodel_doc_predict = """
"""
def _extract_evaluation_meta_data(
*,
collection: Optional[Union[Dict[Any, Any], List[Any]]],
name: str,
i: int,
) -> Optional[Any]:
"""Try to extract the ith element of one of the ``eval_*`` inputs."""
if collection is None:
return None
elif isinstance(collection, list):
# It's possible, for example, to pass 3 eval sets through `eval_set`,
# but only 1 init_score through `eval_init_score`.
#
# This if-else accounts for that possiblity.
if len(collection) > i:
return collection[i]
else:
return None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")
class LGBMModel(_LGBMModelBase):
"""Implementation of the scikit-learn API for LightGBM."""
......@@ -869,17 +893,6 @@ class LGBMModel(_LGBMModelBase):
valid_sets: List[Dataset] = []
if eval_set is not None:
def _get_meta_data(collection, name, i):
if collection is None:
return None
elif isinstance(collection, list):
return collection[i] if len(collection) > i else None
elif isinstance(collection, dict):
return collection.get(i, None)
else:
raise TypeError(f"{name} should be dict or list")
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
......@@ -887,8 +900,16 @@ class LGBMModel(_LGBMModelBase):
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _get_meta_data(eval_sample_weight, "eval_sample_weight", i)
valid_class_weight = _get_meta_data(eval_class_weight, "eval_class_weight", i)
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
......@@ -897,8 +918,16 @@ class LGBMModel(_LGBMModelBase):
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _get_meta_data(eval_init_score, "eval_init_score", i)
valid_group = _get_meta_data(eval_group, "eval_group", i)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment