Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
a5285985
Unverified
Commit
a5285985
authored
Mar 29, 2023
by
James Lamb
Committed by
GitHub
Mar 29, 2023
Browse files
[python-package] fix mypy errors about custom eval and metric functions (#5790)
parent
9f035100
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
20 deletions
+44
-20
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+10
-4
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+34
-16
No files found.
python-package/lightgbm/engine.py
View file @
a5285985
...
@@ -11,7 +11,7 @@ import numpy as np
...
@@ -11,7 +11,7 @@ import numpy as np
from
.
import
callback
from
.
import
callback
from
.basic
import
(
Booster
,
Dataset
,
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_InnerPredictor
,
from
.basic
import
(
Booster
,
Dataset
,
LightGBMError
,
_choose_param_value
,
_ConfigAliases
,
_InnerPredictor
,
_LGBM_CategoricalFeatureConfiguration
,
_LGBM_CustomObjectiveFunction
,
_LGBM_CategoricalFeatureConfiguration
,
_LGBM_CustomObjectiveFunction
,
_LGBM_EvalFunctionResultType
,
_LGBM_FeatureNameConfiguration
,
_log_warning
)
_LGBM_FeatureNameConfiguration
,
_log_warning
)
from
.compat
import
SKLEARN_INSTALLED
,
_LGBMBaseCrossValidator
,
_LGBMGroupKFold
,
_LGBMStratifiedKFold
from
.compat
import
SKLEARN_INSTALLED
,
_LGBMBaseCrossValidator
,
_LGBMGroupKFold
,
_LGBMStratifiedKFold
...
@@ -22,9 +22,15 @@ __all__ = [
...
@@ -22,9 +22,15 @@ __all__ = [
]
]
_LGBM_CustomMetricFunction
=
Callable
[
_LGBM_CustomMetricFunction
=
Union
[
Callable
[
[
np
.
ndarray
,
Dataset
],
[
np
.
ndarray
,
Dataset
],
Union
[
Tuple
[
str
,
float
,
bool
],
List
[
Tuple
[
str
,
float
,
bool
]]]
_LGBM_EvalFunctionResultType
,
],
Callable
[
[
np
.
ndarray
,
Dataset
],
List
[
_LGBM_EvalFunctionResultType
]
],
]
]
_LGBM_PreprocFunction
=
Callable
[
_LGBM_PreprocFunction
=
Callable
[
...
...
python-package/lightgbm/sklearn.py
View file @
a5285985
...
@@ -33,32 +33,50 @@ _LGBM_ScikitMatrixLike = Union[
...
@@ -33,32 +33,50 @@ _LGBM_ScikitMatrixLike = Union[
scipy
.
sparse
.
spmatrix
scipy
.
sparse
.
spmatrix
]
]
_LGBM_ScikitCustomObjectiveFunction
=
Union
[
_LGBM_ScikitCustomObjectiveFunction
=
Union
[
# f(labels, preds)
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
],
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
],
],
# f(labels, preds, weights)
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
,
Optional
[
np
.
ndarray
]
]
,
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
],
],
# f(labels, preds, weights, group)
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
,
Optional
[
np
.
ndarray
]
,
Optional
[
np
.
ndarray
]
]
,
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
Tuple
[
np
.
ndarray
,
np
.
ndarray
]
],
],
]
]
_LGBM_ScikitCustomEvalFunction
=
Union
[
_LGBM_ScikitCustomEvalFunction
=
Union
[
# f(labels, preds)
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
],
Union
[
_LGBM_EvalFunctionResultType
,
List
[
_LGBM_EvalFunctionResultType
]]
_LGBM_EvalFunctionResultType
],
],
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
],
Union
[
_LGBM_EvalFunctionResultType
,
List
[
_LGBM_EvalFunctionResultType
]
]
List
[
_LGBM_EvalFunctionResultType
]
],
],
# f(labels, preds, weights)
Callable
[
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
[
Optional
[
np
.
ndarray
]
,
np
.
ndarray
,
Optional
[
np
.
ndarray
]
]
,
Union
[
_LGBM_EvalFunctionResultType
,
List
[
_LGBM_EvalFunctionResultType
]]
_LGBM_EvalFunctionResultType
],
],
Callable
[
[
Optional
[
np
.
ndarray
],
np
.
ndarray
,
Optional
[
np
.
ndarray
]],
List
[
_LGBM_EvalFunctionResultType
]
],
# f(labels, preds, weights, group)
Callable
[
[
Optional
[
np
.
ndarray
],
np
.
ndarray
,
Optional
[
np
.
ndarray
],
Optional
[
np
.
ndarray
]],
_LGBM_EvalFunctionResultType
],
Callable
[
[
Optional
[
np
.
ndarray
],
np
.
ndarray
,
Optional
[
np
.
ndarray
],
Optional
[
np
.
ndarray
]],
List
[
_LGBM_EvalFunctionResultType
]
]
]
]
_LGBM_ScikitEvalMetricType
=
Union
[
_LGBM_ScikitEvalMetricType
=
Union
[
str
,
str
,
...
@@ -135,11 +153,11 @@ class _ObjectiveFunctionWrapper:
...
@@ -135,11 +153,11 @@ class _ObjectiveFunctionWrapper:
labels
=
dataset
.
get_label
()
labels
=
dataset
.
get_label
()
argc
=
len
(
signature
(
self
.
func
).
parameters
)
argc
=
len
(
signature
(
self
.
func
).
parameters
)
if
argc
==
2
:
if
argc
==
2
:
grad
,
hess
=
self
.
func
(
labels
,
preds
)
grad
,
hess
=
self
.
func
(
labels
,
preds
)
# type: ignore[call-arg]
elif
argc
==
3
:
elif
argc
==
3
:
grad
,
hess
=
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
())
grad
,
hess
=
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
())
# type: ignore[call-arg]
elif
argc
==
4
:
elif
argc
==
4
:
grad
,
hess
=
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
(),
dataset
.
get_group
())
grad
,
hess
=
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
(),
dataset
.
get_group
())
# type: ignore [call-arg]
else
:
else
:
raise
TypeError
(
f
"Self-defined objective function should have 2, 3 or 4 arguments, got
{
argc
}
"
)
raise
TypeError
(
f
"Self-defined objective function should have 2, 3 or 4 arguments, got
{
argc
}
"
)
return
grad
,
hess
return
grad
,
hess
...
@@ -213,11 +231,11 @@ class _EvalFunctionWrapper:
...
@@ -213,11 +231,11 @@ class _EvalFunctionWrapper:
labels
=
dataset
.
get_label
()
labels
=
dataset
.
get_label
()
argc
=
len
(
signature
(
self
.
func
).
parameters
)
argc
=
len
(
signature
(
self
.
func
).
parameters
)
if
argc
==
2
:
if
argc
==
2
:
return
self
.
func
(
labels
,
preds
)
return
self
.
func
(
labels
,
preds
)
# type: ignore[call-arg]
elif
argc
==
3
:
elif
argc
==
3
:
return
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
())
return
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
())
# type: ignore[call-arg]
elif
argc
==
4
:
elif
argc
==
4
:
return
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
(),
dataset
.
get_group
())
return
self
.
func
(
labels
,
preds
,
dataset
.
get_weight
(),
dataset
.
get_group
())
# type: ignore[call-arg]
else
:
else
:
raise
TypeError
(
f
"Self-defined eval function should have 2, 3 or 4 arguments, got
{
argc
}
"
)
raise
TypeError
(
f
"Self-defined eval function should have 2, 3 or 4 arguments, got
{
argc
}
"
)
...
@@ -819,7 +837,7 @@ class LGBMModel(_LGBMModelBase):
...
@@ -819,7 +837,7 @@ class LGBMModel(_LGBMModelBase):
num_boost_round
=
self
.
n_estimators
,
num_boost_round
=
self
.
n_estimators
,
valid_sets
=
valid_sets
,
valid_sets
=
valid_sets
,
valid_names
=
eval_names
,
valid_names
=
eval_names
,
feval
=
eval_metrics_callable
,
feval
=
eval_metrics_callable
,
# type: ignore[arg-type]
init_model
=
init_model
,
init_model
=
init_model
,
feature_name
=
feature_name
,
feature_name
=
feature_name
,
callbacks
=
callbacks
callbacks
=
callbacks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment