Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
f74875ed
Unverified
Commit
f74875ed
authored
Apr 18, 2023
by
James Lamb
Committed by
GitHub
Apr 18, 2023
Browse files
[python-package] move validation up earlier in cv() and train() (#5836)
parent
fd921d53
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
10 deletions
+51
-10
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+19
-10
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+32
-0
No files found.
python-package/lightgbm/engine.py
View file @
f74875ed
...
...
@@ -141,6 +141,20 @@ def train(
booster : Booster
The trained Booster model.
"""
if
not
isinstance
(
train_set
,
Dataset
):
raise
TypeError
(
f
"train() only accepts Dataset object, train_set has type '
{
type
(
train_set
).
__name__
}
'."
)
if
num_boost_round
<=
0
:
raise
ValueError
(
f
"num_boost_round must be greater than 0. Got
{
num_boost_round
}
."
)
if
isinstance
(
valid_sets
,
list
):
for
i
,
valid_item
in
enumerate
(
valid_sets
):
if
not
isinstance
(
valid_item
,
Dataset
):
raise
TypeError
(
"Every item in valid_sets must be a Dataset object. "
f
"Item
{
i
}
has type '
{
type
(
valid_item
).
__name__
}
'."
)
# create predictor first
params
=
copy
.
deepcopy
(
params
)
params
=
_choose_param_value
(
...
...
@@ -167,17 +181,12 @@ def train(
params
.
pop
(
"early_stopping_round"
)
first_metric_only
=
params
.
get
(
'first_metric_only'
,
False
)
if
num_boost_round
<=
0
:
raise
ValueError
(
"num_boost_round should be greater than zero."
)
predictor
:
Optional
[
_InnerPredictor
]
=
None
if
isinstance
(
init_model
,
(
str
,
Path
)):
predictor
=
_InnerPredictor
(
model_file
=
init_model
,
pred_parameter
=
params
)
elif
isinstance
(
init_model
,
Booster
):
predictor
=
init_model
.
_to_predictor
(
pred_parameter
=
dict
(
init_model
.
params
,
**
params
))
init_iteration
=
predictor
.
num_total_iteration
if
predictor
is
not
None
else
0
# check dataset
if
not
isinstance
(
train_set
,
Dataset
):
raise
TypeError
(
"Training only accepts Dataset object"
)
train_set
.
_update_params
(
params
)
\
.
_set_predictor
(
predictor
)
\
...
...
@@ -200,8 +209,6 @@ def train(
if
valid_names
is
not
None
:
train_data_name
=
valid_names
[
i
]
continue
if
not
isinstance
(
valid_data
,
Dataset
):
raise
TypeError
(
"Training only accepts Dataset object"
)
reduced_valid_sets
.
append
(
valid_data
.
_update_params
(
params
).
set_reference
(
train_set
))
if
valid_names
is
not
None
and
len
(
valid_names
)
>
i
:
name_valid_sets
.
append
(
valid_names
[
i
])
...
...
@@ -647,7 +654,11 @@ def cv(
If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key.
"""
if
not
isinstance
(
train_set
,
Dataset
):
raise
TypeError
(
"Training only accepts Dataset object"
)
raise
TypeError
(
f
"cv() only accepts Dataset object, train_set has type '
{
type
(
train_set
).
__name__
}
'."
)
if
num_boost_round
<=
0
:
raise
ValueError
(
f
"num_boost_round must be greater than 0. Got
{
num_boost_round
}
."
)
params
=
copy
.
deepcopy
(
params
)
params
=
_choose_param_value
(
main_param_name
=
'objective'
,
...
...
@@ -673,8 +684,6 @@ def cv(
params
.
pop
(
"early_stopping_round"
)
first_metric_only
=
params
.
get
(
'first_metric_only'
,
False
)
if
num_boost_round
<=
0
:
raise
ValueError
(
"num_boost_round should be greater than zero."
)
if
isinstance
(
init_model
,
(
str
,
Path
)):
predictor
=
_InnerPredictor
(
model_file
=
init_model
,
pred_parameter
=
params
)
elif
isinstance
(
init_model
,
Booster
):
...
...
tests/python_package_test/test_engine.py
View file @
f74875ed
...
...
@@ -4017,6 +4017,38 @@ def test_validate_features():
bst
.
refit
(
df2
,
y
,
validate_features
=
False
)
def
test_train_and_cv_raise_informative_error_for_train_set_of_wrong_type
():
with
pytest
.
raises
(
TypeError
,
match
=
r
"train\(\) only accepts Dataset object, train_set has type 'list'\."
):
lgb
.
train
({},
train_set
=
[])
with
pytest
.
raises
(
TypeError
,
match
=
r
"cv\(\) only accepts Dataset object, train_set has type 'list'\."
):
lgb
.
cv
({},
train_set
=
[])
@
pytest
.
mark
.
parametrize
(
'num_boost_round'
,
[
-
7
,
-
1
,
0
])
def
test_train_and_cv_raise_informative_error_for_impossible_num_boost_round
(
num_boost_round
):
X
,
y
=
make_synthetic_regression
(
n_samples
=
100
)
error_msg
=
rf
"num_boost_round must be greater than 0\. Got
{
num_boost_round
}
\."
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
lgb
.
train
({},
train_set
=
lgb
.
Dataset
(
X
,
y
),
num_boost_round
=
num_boost_round
)
with
pytest
.
raises
(
ValueError
,
match
=
error_msg
):
lgb
.
cv
({},
train_set
=
lgb
.
Dataset
(
X
,
y
),
num_boost_round
=
num_boost_round
)
def
test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_objects
():
X
,
y
=
make_synthetic_regression
(
n_samples
=
100
)
X_valid
=
X
*
2.0
with
pytest
.
raises
(
TypeError
,
match
=
r
"Every item in valid_sets must be a Dataset object\. Item 1 has type 'tuple'\."
):
lgb
.
train
(
params
=
{},
train_set
=
lgb
.
Dataset
(
X
,
y
),
valid_sets
=
[
lgb
.
Dataset
(
X_valid
,
y
),
([
1.0
],
[
2.0
]),
[
5.6
,
5.7
,
5.8
]
]
)
def
test_train_raises_informative_error_for_params_of_wrong_type
():
X
,
y
=
make_synthetic_regression
()
params
=
{
"early_stopping_round"
:
"too-many"
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment