Unverified Commit f74875ed authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] move validation up earlier in cv() and train() (#5836)

parent fd921d53
......@@ -141,6 +141,20 @@ def train(
booster : Booster
The trained Booster model.
"""
if not isinstance(train_set, Dataset):
raise TypeError(f"train() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
if num_boost_round <= 0:
raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.")
if isinstance(valid_sets, list):
for i, valid_item in enumerate(valid_sets):
if not isinstance(valid_item, Dataset):
raise TypeError(
"Every item in valid_sets must be a Dataset object. "
f"Item {i} has type '{type(valid_item).__name__}'."
)
# create predictor first
params = copy.deepcopy(params)
params = _choose_param_value(
......@@ -167,17 +181,12 @@ def train(
params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
predictor: Optional[_InnerPredictor] = None
if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params))
init_iteration = predictor.num_total_iteration if predictor is not None else 0
# check dataset
if not isinstance(train_set, Dataset):
raise TypeError("Training only accepts Dataset object")
train_set._update_params(params) \
._set_predictor(predictor) \
......@@ -200,8 +209,6 @@ def train(
if valid_names is not None:
train_data_name = valid_names[i]
continue
if not isinstance(valid_data, Dataset):
raise TypeError("Training only accepts Dataset object")
reduced_valid_sets.append(valid_data._update_params(params).set_reference(train_set))
if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i])
......@@ -647,7 +654,11 @@ def cv(
If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key.
"""
if not isinstance(train_set, Dataset):
raise TypeError("Training only accepts Dataset object")
raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
if num_boost_round <= 0:
raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.")
params = copy.deepcopy(params)
params = _choose_param_value(
main_param_name='objective',
......@@ -673,8 +684,6 @@ def cv(
params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False)
if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
elif isinstance(init_model, Booster):
......
......@@ -4017,6 +4017,38 @@ def test_validate_features():
bst.refit(df2, y, validate_features=False)
def test_train_and_cv_raise_informative_error_for_train_set_of_wrong_type():
    # Both entry points should fail fast with a TypeError that names the
    # offending type when train_set is not a lightgbm.Dataset.
    checks = [
        (lgb.train, r"train\(\) only accepts Dataset object, train_set has type 'list'\."),
        (lgb.cv, r"cv\(\) only accepts Dataset object, train_set has type 'list'\.")
    ]
    for entry_point, expected_msg in checks:
        with pytest.raises(TypeError, match=expected_msg):
            entry_point({}, train_set=[])
@pytest.mark.parametrize('num_boost_round', [-7, -1, 0])
def test_train_and_cv_raise_informative_error_for_impossible_num_boost_round(num_boost_round):
    # num_boost_round must be strictly positive; the raised ValueError
    # should echo the invalid value back to the user.
    X, y = make_synthetic_regression(n_samples=100)
    expected_msg = rf"num_boost_round must be greater than 0\. Got {num_boost_round}\."
    for entry_point in (lgb.train, lgb.cv):
        with pytest.raises(ValueError, match=expected_msg):
            entry_point({}, train_set=lgb.Dataset(X, y), num_boost_round=num_boost_round)
def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_objects():
    # The error should identify the first offending element by index and
    # type (index 1, a tuple here) even when later elements are also invalid.
    X, y = make_synthetic_regression(n_samples=100)
    X_valid = X * 2.0
    bad_valid_sets = [
        lgb.Dataset(X_valid, y),
        ([1.0], [2.0]),
        [5.6, 5.7, 5.8]
    ]
    expected_msg = r"Every item in valid_sets must be a Dataset object\. Item 1 has type 'tuple'\."
    with pytest.raises(TypeError, match=expected_msg):
        lgb.train(params={}, train_set=lgb.Dataset(X, y), valid_sets=bad_valid_sets)
def test_train_raises_informative_error_for_params_of_wrong_type():
X, y = make_synthetic_regression()
params = {"early_stopping_round": "too-many"}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment