"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "ce3e31219c568323d66fa6e918fffcb96b3df92a"
Unverified commit 0ec55c97, authored by Nick Miller, committed by GitHub
Browse files

[ci][python-package] Enforce and fix ruff flake8-pytest-style (PT) linting codes (#6959)



* Add "PT" ruff code to pyproject.toml plus some exceptions

* Extend PT011 for Exception, IndexError, TypeError

* Fix PT001: Use `@pytest.fixture` over `@pytest.fixture()`

* Fix PT011: pytest.raises({exception}) is too broad, set the match parameter or use a more specific exception

* Fix PT012: `pytest.raises()` block should contain a single simple statement

* Fix PT017: Found assertion on exception {name} in except block, use pytest.raises() instead

* Fix PT018: Assertion should be broken down into multiple parts

* Remove PT006 and PT007 from ignore list and fixup

* Use *Exception and *Error glob to apply PT011 and ignore AssertionErrors

---------
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 54f481ef
...@@ -136,7 +136,9 @@ ignore = [ ...@@ -136,7 +136,9 @@ ignore = [
# (pylint) for loop variable overwritten by assignment target # (pylint) for loop variable overwritten by assignment target
"PLW2901", "PLW2901",
# (pylint) use 'elif' instead of 'else' then 'if', to reduce indentation # (pylint) use 'elif' instead of 'else' then 'if', to reduce indentation
"PLR5501" "PLR5501",
# (flake8-pytest-style) `scope='function'` is implied in `@pytest.fixture()`
"PT003"
] ]
select = [ select = [
# flake8-bugbear # flake8-bugbear
...@@ -155,6 +157,8 @@ select = [ ...@@ -155,6 +157,8 @@ select = [
"NPY", "NPY",
# pylint # pylint
"PL", "PL",
# flake8-pytest-style
"PT",
# flake8-return: unnecessary assignment before return # flake8-return: unnecessary assignment before return
"RET504", "RET504",
# flake8-return: superfluous-else-raise # flake8-return: superfluous-else-raise
...@@ -167,6 +171,9 @@ select = [ ...@@ -167,6 +171,9 @@ select = [
"W", "W",
] ]
[tool.ruff.lint.flake8-pytest-style]
raises-extend-require-match-for = ["*Exception", "*Error"]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"docs/conf.py" = [ "docs/conf.py" = [
# (flake8-bugbear) raise exceptions with "raise ... from err" # (flake8-bugbear) raise exceptions with "raise ... from err"
......
...@@ -211,7 +211,7 @@ def test_dataset_construct_fields_fuzzy(): ...@@ -211,7 +211,7 @@ def test_dataset_construct_fields_fuzzy():
@pytest.mark.parametrize( @pytest.mark.parametrize(
["array_type", "label_data"], ("array_type", "label_data"),
[ [
(pa.array, [0, 1, 0, 0, 1]), (pa.array, [0, 1, 0, 0, 1]),
(pa.chunked_array, [[0], [1, 0, 0, 1]]), (pa.chunked_array, [[0], [1, 0, 0, 1]]),
...@@ -231,7 +231,7 @@ def test_dataset_construct_labels(array_type, label_data, arrow_type): ...@@ -231,7 +231,7 @@ def test_dataset_construct_labels(array_type, label_data, arrow_type):
@pytest.mark.parametrize( @pytest.mark.parametrize(
["array_type", "label_data"], ("array_type", "label_data"),
[ [
(pa.array, [False, True, False, False, True]), (pa.array, [False, True, False, False, True]),
(pa.chunked_array, [[False], [True, False, False, True]]), (pa.chunked_array, [[False], [True, False, False, True]]),
...@@ -262,7 +262,7 @@ def test_dataset_construct_weights_none(): ...@@ -262,7 +262,7 @@ def test_dataset_construct_weights_none():
@pytest.mark.parametrize( @pytest.mark.parametrize(
["array_type", "weight_data"], ("array_type", "weight_data"),
[ [
(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.array, [3, 0.7, 1.5, 0.5, 0.1]),
(pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]]),
...@@ -285,7 +285,7 @@ def test_dataset_construct_weights(array_type, weight_data, arrow_type): ...@@ -285,7 +285,7 @@ def test_dataset_construct_weights(array_type, weight_data, arrow_type):
@pytest.mark.parametrize( @pytest.mark.parametrize(
["array_type", "group_data"], ("array_type", "group_data"),
[ [
(pa.array, [2, 3]), (pa.array, [2, 3]),
(pa.chunked_array, [[2], [3]]), (pa.chunked_array, [[2], [3]]),
...@@ -308,7 +308,7 @@ def test_dataset_construct_groups(array_type, group_data, arrow_type): ...@@ -308,7 +308,7 @@ def test_dataset_construct_groups(array_type, group_data, arrow_type):
@pytest.mark.parametrize( @pytest.mark.parametrize(
["array_type", "init_score_data"], ("array_type", "init_score_data"),
[ [
(pa.array, [0, 1, 2, 3, 3]), (pa.array, [0, 1, 2, 3, 3]),
(pa.chunked_array, [[0, 1, 2], [3, 3]]), (pa.chunked_array, [[0, 1, 2], [3, 3]]),
......
...@@ -281,17 +281,17 @@ def test_add_features_throws_if_datasets_unconstructed(rng): ...@@ -281,17 +281,17 @@ def test_add_features_throws_if_datasets_unconstructed(rng):
X1 = rng.uniform(size=(100, 1)) X1 = rng.uniform(size=(100, 1))
X2 = rng.uniform(size=(100, 1)) X2 = rng.uniform(size=(100, 1))
err_msg = "Both source and target Datasets must be constructed before adding features" err_msg = "Both source and target Datasets must be constructed before adding features"
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2)
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2)
d1.add_features_from(d2) d1.add_features_from(d2)
d1 = lgb.Dataset(X1).construct()
d2 = lgb.Dataset(X2)
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
d1 = lgb.Dataset(X1).construct()
d2 = lgb.Dataset(X2)
d1.add_features_from(d2) d1.add_features_from(d2)
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2).construct()
with pytest.raises(ValueError, match=err_msg): with pytest.raises(ValueError, match=err_msg):
d1 = lgb.Dataset(X1)
d2 = lgb.Dataset(X2).construct()
d1.add_features_from(d2) d1.add_features_from(d2)
...@@ -980,14 +980,16 @@ def test_no_copy_in_dataset_from_numpy_2d(rng, order, dtype): ...@@ -980,14 +980,16 @@ def test_no_copy_in_dataset_from_numpy_2d(rng, order, dtype):
def test_equal_datasets_from_row_major_and_col_major_data(tmp_path): def test_equal_datasets_from_row_major_and_col_major_data(tmp_path):
# row-major dataset # row-major dataset
X_row, y = make_blobs(n_samples=1_000, n_features=3, centers=2) X_row, y = make_blobs(n_samples=1_000, n_features=3, centers=2)
assert X_row.flags["C_CONTIGUOUS"] and not X_row.flags["F_CONTIGUOUS"] assert X_row.flags["C_CONTIGUOUS"]
assert not X_row.flags["F_CONTIGUOUS"]
ds_row = lgb.Dataset(X_row, y) ds_row = lgb.Dataset(X_row, y)
ds_row_path = tmp_path / "ds_row.txt" ds_row_path = tmp_path / "ds_row.txt"
ds_row._dump_text(ds_row_path) ds_row._dump_text(ds_row_path)
# col-major dataset # col-major dataset
X_col = np.asfortranarray(X_row) X_col = np.asfortranarray(X_row)
assert X_col.flags["F_CONTIGUOUS"] and not X_col.flags["C_CONTIGUOUS"] assert X_col.flags["F_CONTIGUOUS"]
assert not X_col.flags["C_CONTIGUOUS"]
ds_col = lgb.Dataset(X_col, y) ds_col = lgb.Dataset(X_col, y)
ds_col_path = tmp_path / "ds_col.txt" ds_col_path = tmp_path / "ds_col.txt"
ds_col._dump_text(ds_col_path) ds_col._dump_text(ds_col_path)
......
...@@ -80,7 +80,7 @@ def cluster_three_workers(): ...@@ -80,7 +80,7 @@ def cluster_three_workers():
dask_cluster.close() dask_cluster.close()
@pytest.fixture() @pytest.fixture
def listen_port(): def listen_port():
listen_port.port += 10 listen_port.port += 10
return listen_port.port return listen_port.port
...@@ -296,10 +296,10 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster): ...@@ -296,10 +296,10 @@ def test_classifier(output, task, boosting_type, tree_learner, cluster):
assert_eq(p1_local, y) assert_eq(p1_local, y)
# extra predict() parameters should be passed through correctly # extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(p1_raw, p1_first_iter_raw) assert_eq(p1_raw, p1_first_iter_raw)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(p1_raw, p1_early_stop_raw) assert_eq(p1_raw, p1_early_stop_raw)
# pref_leaf values should have the right shape # pref_leaf values should have the right shape
...@@ -551,7 +551,7 @@ def test_regressor(output, boosting_type, tree_learner, cluster): ...@@ -551,7 +551,7 @@ def test_regressor(output, boosting_type, tree_learner, cluster):
assert_eq(p2, y, rtol=0.5, atol=50.0) assert_eq(p2, y, rtol=0.5, atol=50.0)
# extra predict() parameters should be passed through correctly # extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(p1_raw, p1_first_iter_raw) assert_eq(p1_raw, p1_first_iter_raw)
# be sure LightGBM actually used at least one categorical column, # be sure LightGBM actually used at least one categorical column,
...@@ -731,10 +731,10 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): ...@@ -731,10 +731,10 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
assert_eq(rnkvec_dask, rnkvec_dask_local) assert_eq(rnkvec_dask, rnkvec_dask_local)
# extra predict() parameters should be passed through correctly # extra predict() parameters should be passed through correctly
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(p1_raw, p1_first_iter_raw) assert_eq(p1_raw, p1_first_iter_raw)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(p1_raw, p1_early_stop_raw) assert_eq(p1_raw, p1_early_stop_raw)
# pref_leaf values should have the right shape # pref_leaf values should have the right shape
...@@ -1023,6 +1023,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c ...@@ -1023,6 +1023,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
with pytest.raises(AttributeError, match=no_client_attr_msg): with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client local_model.client
with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client_ local_model.client_
# should be able to set client after construction # should be able to set client after construction
...@@ -1047,6 +1048,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c ...@@ -1047,6 +1048,7 @@ def test_training_works_if_client_not_provided_or_set_after_construction(task, c
local_model = dask_model.to_local() local_model = dask_model.to_local()
with pytest.raises(AttributeError, match=no_client_attr_msg): with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client local_model.client
with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client_ local_model.client_
...@@ -1136,6 +1138,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -1136,6 +1138,7 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
) )
with pytest.raises(AttributeError, match=no_client_attr_msg): with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client local_model.client
with pytest.raises(AttributeError, match=no_client_attr_msg):
local_model.client_ local_model.client_
tmp_file2 = tmp_path / "model-2.pkl" tmp_file2 = tmp_path / "model-2.pkl"
...@@ -1233,7 +1236,7 @@ def test_errors(cluster): ...@@ -1233,7 +1236,7 @@ def test_errors(cluster):
df = dd.demo.make_timeseries() df = dd.demo.make_timeseries()
df = df.map_partitions(f, meta=df._meta) df = df.map_partitions(f, meta=df._meta)
with pytest.raises(Exception) as info: with pytest.raises(Exception) as info: # noqa: PT011, PT012 # error message needs to be coerced to a string
lgb.dask._train(client=client, data=df, label=df.x, params={}, model_factory=lgb.LGBMClassifier) lgb.dask._train(client=client, data=df, label=df.x, params={}, model_factory=lgb.LGBMClassifier)
assert "foo" in str(info.value) assert "foo" in str(info.value)
...@@ -1362,7 +1365,7 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1362,7 +1365,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
# test that "machines" is actually respected by creating a socket that uses # test that "machines" is actually respected by creating a socket that uses
# one of the ports mentioned in "machines" # one of the ports mentioned in "machines"
error_msg = f"Binding port {open_ports[0]} failed" error_msg = f"Binding port {open_ports[0]} failed"
with pytest.raises(lgb.basic.LightGBMError, match=error_msg): with pytest.raises(lgb.basic.LightGBMError, match=error_msg): # noqa: PT012
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((workers_hostname, open_ports[0])) s.bind((workers_hostname, open_ports[0]))
dask_model.fit(dX, dy, group=dg) dask_model.fit(dX, dy, group=dg)
...@@ -1378,7 +1381,7 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1378,7 +1381,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dask_est,sklearn_est", ("dask_est", "sklearn_est"),
[ [
(lgb.DaskLGBMClassifier, lgb.LGBMClassifier), (lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
(lgb.DaskLGBMRegressor, lgb.LGBMRegressor), (lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
...@@ -1453,7 +1456,8 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task ...@@ -1453,7 +1456,8 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
dy = dy.to_dask_array(lengths=True) dy = dy.to_dask_array(lengths=True)
dy_col_array = dy.reshape(-1, 1) dy_col_array = dy.reshape(-1, 1)
assert len(dy_col_array.shape) == 2 and dy_col_array.shape[1] == 1 assert len(dy_col_array.shape) == 2
assert dy_col_array.shape[1] == 1
params = {"n_estimators": 1, "num_leaves": 3, "random_state": 0, "time_out": 5} params = {"n_estimators": 1, "num_leaves": 3, "random_state": 0, "time_out": 5}
model = model_factory(**params) model = model_factory(**params)
...@@ -1499,7 +1503,7 @@ def test_init_score(task, output, cluster, rng): ...@@ -1499,7 +1503,7 @@ def test_init_score(task, output, cluster, rng):
pred_init_score = model_init_score.predict(dX, raw_score=True) pred_init_score = model_init_score.predict(dX, raw_score=True)
# check if init score changes predictions # check if init score changes predictions
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
assert_eq(pred, pred_init_score) assert_eq(pred, pred_init_score)
......
...@@ -662,7 +662,7 @@ def test_ranking_prediction_early_stopping(): ...@@ -662,7 +662,7 @@ def test_ranking_prediction_early_stopping():
pred_parameter["pred_early_stop_margin"] = 5.5 pred_parameter["pred_early_stop_margin"] = 5.5
ret_early_more_strict = gbm.predict(X_test, **pred_parameter) ret_early_more_strict = gbm.predict(X_test, **pred_parameter)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(ret_early, ret_early_more_strict) np.testing.assert_allclose(ret_early, ret_early_more_strict)
...@@ -1828,18 +1828,18 @@ def test_pandas_categorical(rng_fixed_seed, tmp_path): ...@@ -1828,18 +1828,18 @@ def test_pandas_categorical(rng_fixed_seed, tmp_path):
gbm7 = lgb.train(params, lgb_train, num_boost_round=10) gbm7 = lgb.train(params, lgb_train, num_boost_round=10)
pred8 = gbm7.predict(X_test) pred8 = gbm7.predict(X_test)
assert lgb_train.categorical_feature == [] assert lgb_train.categorical_feature == []
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred1) np.testing.assert_allclose(pred0, pred1)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred2) np.testing.assert_allclose(pred0, pred2)
np.testing.assert_allclose(pred1, pred2) np.testing.assert_allclose(pred1, pred2)
np.testing.assert_allclose(pred0, pred3) np.testing.assert_allclose(pred0, pred3)
np.testing.assert_allclose(pred0, pred4) np.testing.assert_allclose(pred0, pred4)
np.testing.assert_allclose(pred0, pred5) np.testing.assert_allclose(pred0, pred5)
np.testing.assert_allclose(pred0, pred6) np.testing.assert_allclose(pred0, pred6)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred7) # ordered cat features aren't treated as cat features by default np.testing.assert_allclose(pred0, pred7) # ordered cat features aren't treated as cat features by default
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred8) np.testing.assert_allclose(pred0, pred8)
assert gbm0.pandas_categorical == cat_values assert gbm0.pandas_categorical == cat_values
assert gbm1.pandas_categorical == cat_values assert gbm1.pandas_categorical == cat_values
...@@ -4794,14 +4794,16 @@ def test_bagging_by_query_in_lambdarank(): ...@@ -4794,14 +4794,16 @@ def test_bagging_by_query_in_lambdarank():
def test_equal_predict_from_row_major_and_col_major_data(): def test_equal_predict_from_row_major_and_col_major_data():
X_row, y = make_synthetic_regression() X_row, y = make_synthetic_regression()
assert X_row.flags["C_CONTIGUOUS"] and not X_row.flags["F_CONTIGUOUS"] assert X_row.flags["C_CONTIGUOUS"]
assert not X_row.flags["F_CONTIGUOUS"]
ds = lgb.Dataset(X_row, y) ds = lgb.Dataset(X_row, y)
params = {"num_leaves": 8, "verbose": -1} params = {"num_leaves": 8, "verbose": -1}
bst = lgb.train(params, ds, num_boost_round=5) bst = lgb.train(params, ds, num_boost_round=5)
preds_row = bst.predict(X_row) preds_row = bst.predict(X_row)
X_col = np.asfortranarray(X_row) X_col = np.asfortranarray(X_row)
assert X_col.flags["F_CONTIGUOUS"] and not X_col.flags["C_CONTIGUOUS"] assert X_col.flags["F_CONTIGUOUS"]
assert not X_col.flags["C_CONTIGUOUS"]
preds_col = bst.predict(X_col) preds_col = bst.predict(X_col)
np.testing.assert_allclose(preds_row, preds_col) np.testing.assert_allclose(preds_row, preds_col)
...@@ -506,7 +506,7 @@ def test_clone_and_property(): ...@@ -506,7 +506,7 @@ def test_clone_and_property():
assert isinstance(clf.feature_importances_, np.ndarray) assert isinstance(clf.feature_importances_, np.ndarray)
@pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker)) @pytest.mark.parametrize("estimator", (lgb.LGBMClassifier, lgb.LGBMRegressor, lgb.LGBMRanker)) # noqa: PT007
def test_estimators_all_have_the_same_kwargs_and_defaults(estimator): def test_estimators_all_have_the_same_kwargs_and_defaults(estimator):
base_spec = inspect.getfullargspec(lgb.LGBMModel) base_spec = inspect.getfullargspec(lgb.LGBMModel)
subclass_spec = inspect.getfullargspec(estimator) subclass_spec = inspect.getfullargspec(estimator)
...@@ -760,7 +760,7 @@ def test_random_state_object(rng_constructor): ...@@ -760,7 +760,7 @@ def test_random_state_object(rng_constructor):
df3 = clf1.booster_.model_to_string(num_iteration=0) df3 = clf1.booster_.model_to_string(num_iteration=0)
assert clf1.random_state is state1 assert clf1.random_state is state1
assert clf2.random_state is state2 assert clf2.random_state is state2
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(y_pred1, y_pred1_refit) np.testing.assert_allclose(y_pred1, y_pred1_refit)
assert df1 != df3 assert df1 != df3
...@@ -832,16 +832,16 @@ def test_pandas_categorical(rng_fixed_seed, tmp_path): ...@@ -832,16 +832,16 @@ def test_pandas_categorical(rng_fixed_seed, tmp_path):
pred5 = gbm5.predict(X_test, raw_score=True) pred5 = gbm5.predict(X_test, raw_score=True)
gbm6 = lgb.sklearn.LGBMClassifier(n_estimators=10).fit(X, y, categorical_feature=[]) gbm6 = lgb.sklearn.LGBMClassifier(n_estimators=10).fit(X, y, categorical_feature=[])
pred6 = gbm6.predict(X_test, raw_score=True) pred6 = gbm6.predict(X_test, raw_score=True)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred1) np.testing.assert_allclose(pred0, pred1)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred2) np.testing.assert_allclose(pred0, pred2)
np.testing.assert_allclose(pred1, pred2) np.testing.assert_allclose(pred1, pred2)
np.testing.assert_allclose(pred0, pred3) np.testing.assert_allclose(pred0, pred3)
np.testing.assert_allclose(pred_prob, pred4) np.testing.assert_allclose(pred_prob, pred4)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred5) # ordered cat features aren't treated as cat features by default np.testing.assert_allclose(pred0, pred5) # ordered cat features aren't treated as cat features by default
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(pred0, pred6) np.testing.assert_allclose(pred0, pred6)
assert gbm0.booster_.pandas_categorical == cat_values assert gbm0.booster_.pandas_categorical == cat_values
assert gbm1.booster_.pandas_categorical == cat_values assert gbm1.booster_.pandas_categorical == cat_values
...@@ -916,7 +916,7 @@ def test_predict(): ...@@ -916,7 +916,7 @@ def test_predict():
# Tests other parameters for the prediction works # Tests other parameters for the prediction works
res_engine = gbm.predict(X_test) res_engine = gbm.predict(X_test)
res_sklearn_params = clf.predict_proba(X_test, pred_early_stop=True, pred_early_stop_margin=1.0) res_sklearn_params = clf.predict_proba(X_test, pred_early_stop=True, pred_early_stop_margin=1.0)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(res_engine, res_sklearn_params) np.testing.assert_allclose(res_engine, res_sklearn_params)
# Tests start_iteration # Tests start_iteration
...@@ -948,7 +948,7 @@ def test_predict(): ...@@ -948,7 +948,7 @@ def test_predict():
# Tests other parameters for the prediction works, starting from iteration 10 # Tests other parameters for the prediction works, starting from iteration 10
res_engine = gbm.predict(X_test, start_iteration=10) res_engine = gbm.predict(X_test, start_iteration=10)
res_sklearn_params = clf.predict_proba(X_test, pred_early_stop=True, pred_early_stop_margin=1.0, start_iteration=10) res_sklearn_params = clf.predict_proba(X_test, pred_early_stop=True, pred_early_stop_margin=1.0, start_iteration=10)
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(res_engine, res_sklearn_params) np.testing.assert_allclose(res_engine, res_sklearn_params)
# Test multiclass binary classification # Test multiclass binary classification
...@@ -982,7 +982,7 @@ def test_predict_with_params_from_init(): ...@@ -982,7 +982,7 @@ def test_predict_with_params_from_init():
y_preds_params_in_predict = ( y_preds_params_in_predict = (
lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(X_test, raw_score=True, **predict_params) lgb.LGBMClassifier(verbose=-1).fit(X_train, y_train).predict(X_test, raw_score=True, **predict_params)
) )
with pytest.raises(AssertionError): with pytest.raises(AssertionError): # noqa: PT011
np.testing.assert_allclose(y_preds_no_params, y_preds_params_in_predict) np.testing.assert_allclose(y_preds_no_params, y_preds_params_in_predict)
y_preds_params_in_set_params_before_fit = ( y_preds_params_in_set_params_before_fit = (
...@@ -1645,7 +1645,8 @@ def test_getting_feature_names_in_np_input(estimator_class): ...@@ -1645,7 +1645,8 @@ def test_getting_feature_names_in_np_input(estimator_class):
def test_getting_feature_names_in_pd_input(estimator_class): def test_getting_feature_names_in_pd_input(estimator_class):
X, y = load_digits(n_class=2, return_X_y=True, as_frame=True) X, y = load_digits(n_class=2, return_X_y=True, as_frame=True)
col_names = X.columns.to_list() col_names = X.columns.to_list()
assert isinstance(col_names, list) and all(isinstance(c, str) for c in col_names), ( assert isinstance(col_names, list)
assert all(isinstance(c, str) for c in col_names), (
"input data must have feature names for this test to cover the expected functionality" "input data must have feature names for this test to cover the expected functionality"
) )
params = {"n_estimators": 2, "num_leaves": 7} params = {"n_estimators": 2, "num_leaves": 7}
...@@ -1703,9 +1704,10 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato ...@@ -1703,9 +1704,10 @@ def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimato
# minimum supported scikit-learn version is at least 1.6 # minimum supported scikit-learn version is at least 1.6
try: try:
sklearn_tags = est.__sklearn_tags__() sklearn_tags = est.__sklearn_tags__()
except AttributeError as err: except AttributeError:
# only the exact error we expected to be raised should be raised # only the exact error we expected to be raised should be raised
assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) with pytest.raises(AttributeError, match=r"__sklearn_tags__.* should not be called"):
est.__sklearn_tags__()
else: else:
# if no AttributeError was thrown, we must be using scikit-learn>=1.6, # if no AttributeError was thrown, we must be using scikit-learn>=1.6,
# and so the actual effects of __sklearn_tags__() should be tested # and so the actual effects of __sklearn_tags__() should be tested
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment