Unverified Commit 9fc348af authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package] make record_evaluation compatible with cv (fixes #4943) (#4947)

* make record_evaluation compatible with cv

* test multiple metrics in cv

* lint

* fix cv with train metric. save stdv as well

* always add dataset prefix to cv_agg

* remove unused function
parent 2d1caf14
...@@ -131,15 +131,30 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: ...@@ -131,15 +131,30 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
def _init(env: CallbackEnv) -> None: def _init(env: CallbackEnv) -> None:
eval_result.clear() eval_result.clear()
for data_name, eval_name, _, _ in env.evaluation_result_list: for item in env.evaluation_result_list:
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
eval_result.setdefault(data_name, collections.OrderedDict()) eval_result.setdefault(data_name, collections.OrderedDict())
eval_result[data_name].setdefault(eval_name, []) if len(item) == 4:
eval_result[data_name].setdefault(eval_name, [])
else:
eval_result[data_name].setdefault(f'{eval_name}-mean', [])
eval_result[data_name].setdefault(f'{eval_name}-stdv', [])
def _callback(env: CallbackEnv) -> None: def _callback(env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration: if env.iteration == env.begin_iteration:
_init(env) _init(env)
for data_name, eval_name, result, _ in env.evaluation_result_list: for item in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result) if len(item) == 4:
data_name, eval_name, result = item[:3]
eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
_callback.order = 20 # type: ignore _callback.order = 20 # type: ignore
return _callback return _callback
......
...@@ -361,16 +361,13 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi ...@@ -361,16 +361,13 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
return ret return ret
def _agg_cv_result(raw_results, eval_train_metric=False): def _agg_cv_result(raw_results):
"""Aggregate cross-validation results.""" """Aggregate cross-validation results."""
cvmap = collections.OrderedDict() cvmap = collections.OrderedDict()
metric_type = {} metric_type = {}
for one_result in raw_results: for one_result in raw_results:
for one_line in one_result: for one_line in one_result:
if eval_train_metric: key = f"{one_line[0]} {one_line[1]}"
key = f"{one_line[0]} {one_line[1]}"
else:
key = one_line[1]
metric_type[key] = one_line[3] metric_type[key] = one_line[3]
cvmap.setdefault(key, []) cvmap.setdefault(key, [])
cvmap[key].append(one_line[2]) cvmap[key].append(one_line[2])
...@@ -573,7 +570,7 @@ def cv(params, train_set, num_boost_round=100, ...@@ -573,7 +570,7 @@ def cv(params, train_set, num_boost_round=100,
end_iteration=num_boost_round, end_iteration=num_boost_round,
evaluation_result_list=None)) evaluation_result_list=None))
cvfolds.update(fobj=fobj) cvfolds.update(fobj=fobj)
res = _agg_cv_result(cvfolds.eval_valid(feval), eval_train_metric) res = _agg_cv_result(cvfolds.eval_valid(feval))
for _, key, mean, _, std in res: for _, key, mean, _, std in res:
results[f'{key}-mean'].append(mean) results[f'{key}-mean'].append(mean)
results[f'{key}-stdv'].append(std) results[f'{key}-stdv'].append(std)
......
...@@ -971,15 +971,15 @@ def test_cv(): ...@@ -971,15 +971,15 @@ def test_cv():
params_with_metric = {'metric': 'l2', 'verbose': -1} params_with_metric = {'metric': 'l2', 'verbose': -1}
cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10,
nfold=3, stratified=False, shuffle=False, metrics='l1') nfold=3, stratified=False, shuffle=False, metrics='l1')
assert 'l1-mean' in cv_res assert 'valid l1-mean' in cv_res
assert 'l2-mean' not in cv_res assert 'valid l2-mean' not in cv_res
assert len(cv_res['l1-mean']) == 10 assert len(cv_res['valid l1-mean']) == 10
# shuffle = True, callbacks # shuffle = True, callbacks
cv_res = lgb.cv(params, lgb_train, num_boost_round=10, nfold=3, cv_res = lgb.cv(params, lgb_train, num_boost_round=10, nfold=3,
stratified=False, shuffle=True, metrics='l1', stratified=False, shuffle=True, metrics='l1',
callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)]) callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 - 0.001 * i)])
assert 'l1-mean' in cv_res assert 'valid l1-mean' in cv_res
assert len(cv_res['l1-mean']) == 10 assert len(cv_res['valid l1-mean']) == 10
# enable display training loss # enable display training loss
cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, cv_res = lgb.cv(params_with_metric, lgb_train, num_boost_round=10,
nfold=3, stratified=False, shuffle=False, nfold=3, stratified=False, shuffle=False,
...@@ -995,7 +995,7 @@ def test_cv(): ...@@ -995,7 +995,7 @@ def test_cv():
folds = tss.split(X_train) folds = tss.split(X_train)
cv_res_gen = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=folds) cv_res_gen = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=folds)
cv_res_obj = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=tss) cv_res_obj = lgb.cv(params_with_metric, lgb_train, num_boost_round=10, folds=tss)
np.testing.assert_allclose(cv_res_gen['l2-mean'], cv_res_obj['l2-mean']) np.testing.assert_allclose(cv_res_gen['valid l2-mean'], cv_res_obj['valid l2-mean'])
# LambdaRank # LambdaRank
rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank' rank_example_dir = Path(__file__).absolute().parents[2] / 'examples' / 'lambdarank'
X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train')) X_train, y_train = load_svmlight_file(str(rank_example_dir / 'rank.train'))
...@@ -1005,15 +1005,15 @@ def test_cv(): ...@@ -1005,15 +1005,15 @@ def test_cv():
# ... with l2 metric # ... with l2 metric
cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, metrics='l2') cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3, metrics='l2')
assert len(cv_res_lambda) == 2 assert len(cv_res_lambda) == 2
assert not np.isnan(cv_res_lambda['l2-mean']).any() assert not np.isnan(cv_res_lambda['valid l2-mean']).any()
# ... with NDCG (default) metric # ... with NDCG (default) metric
cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3) cv_res_lambda = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, nfold=3)
assert len(cv_res_lambda) == 2 assert len(cv_res_lambda) == 2
assert not np.isnan(cv_res_lambda['ndcg@3-mean']).any() assert not np.isnan(cv_res_lambda['valid ndcg@3-mean']).any()
# self defined folds with lambdarank # self defined folds with lambdarank
cv_res_lambda_obj = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10, cv_res_lambda_obj = lgb.cv(params_lambdarank, lgb_train, num_boost_round=10,
folds=GroupKFold(n_splits=3)) folds=GroupKFold(n_splits=3))
np.testing.assert_allclose(cv_res_lambda['ndcg@3-mean'], cv_res_lambda_obj['ndcg@3-mean']) np.testing.assert_allclose(cv_res_lambda['valid ndcg@3-mean'], cv_res_lambda_obj['valid ndcg@3-mean'])
def test_cvbooster(): def test_cvbooster():
...@@ -1859,8 +1859,8 @@ def test_fpreproc(): ...@@ -1859,8 +1859,8 @@ def test_fpreproc():
dataset = lgb.Dataset(X, y, free_raw_data=False) dataset = lgb.Dataset(X, y, free_raw_data=False)
params = {'objective': 'multiclass', 'num_class': 3, 'verbose': -1} params = {'objective': 'multiclass', 'num_class': 3, 'verbose': -1}
results = lgb.cv(params, dataset, num_boost_round=10, fpreproc=preprocess_data) results = lgb.cv(params, dataset, num_boost_round=10, fpreproc=preprocess_data)
assert 'multi_logloss-mean' in results assert 'valid multi_logloss-mean' in results
assert len(results['multi_logloss-mean']) == 10 assert len(results['valid multi_logloss-mean']) == 10
def test_metrics(): def test_metrics():
...@@ -1902,39 +1902,39 @@ def test_metrics(): ...@@ -1902,39 +1902,39 @@ def test_metrics():
# default metric # default metric
res = get_cv_result() res = get_cv_result()
assert len(res) == 2 assert len(res) == 2
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
# non-default metric in params # non-default metric in params
res = get_cv_result(params=params_obj_metric_err_verbose) res = get_cv_result(params=params_obj_metric_err_verbose)
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# default metric in args # default metric in args
res = get_cv_result(metrics='binary_logloss') res = get_cv_result(metrics='binary_logloss')
assert len(res) == 2 assert len(res) == 2
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
# non-default metric in args # non-default metric in args
res = get_cv_result(metrics='binary_error') res = get_cv_result(metrics='binary_error')
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# metric in args overwrites one in params # metric in args overwrites one in params
res = get_cv_result(params=params_obj_metric_inv_verbose, metrics='binary_error') res = get_cv_result(params=params_obj_metric_inv_verbose, metrics='binary_error')
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# multiple metrics in params # multiple metrics in params
res = get_cv_result(params=params_obj_metric_multi_verbose) res = get_cv_result(params=params_obj_metric_multi_verbose)
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# multiple metrics in args # multiple metrics in args
res = get_cv_result(metrics=['binary_logloss', 'binary_error']) res = get_cv_result(metrics=['binary_logloss', 'binary_error'])
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# remove default metric by 'None' in list # remove default metric by 'None' in list
res = get_cv_result(metrics=['None']) res = get_cv_result(metrics=['None'])
...@@ -1953,126 +1953,126 @@ def test_metrics(): ...@@ -1953,126 +1953,126 @@ def test_metrics():
# metric in params # metric in params
res = get_cv_result(params=params_metric_err_verbose, fobj=dummy_obj) res = get_cv_result(params=params_metric_err_verbose, fobj=dummy_obj)
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# metric in args # metric in args
res = get_cv_result(params=params_verbose, fobj=dummy_obj, metrics='binary_error') res = get_cv_result(params=params_verbose, fobj=dummy_obj, metrics='binary_error')
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# metric in args overwrites its' alias in params # metric in args overwrites its' alias in params
res = get_cv_result(params=params_metric_inv_verbose, fobj=dummy_obj, metrics='binary_error') res = get_cv_result(params=params_metric_inv_verbose, fobj=dummy_obj, metrics='binary_error')
assert len(res) == 2 assert len(res) == 2
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# multiple metrics in params # multiple metrics in params
res = get_cv_result(params=params_metric_multi_verbose, fobj=dummy_obj) res = get_cv_result(params=params_metric_multi_verbose, fobj=dummy_obj)
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# multiple metrics in args # multiple metrics in args
res = get_cv_result(params=params_verbose, fobj=dummy_obj, res = get_cv_result(params=params_verbose, fobj=dummy_obj,
metrics=['binary_logloss', 'binary_error']) metrics=['binary_logloss', 'binary_error'])
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
# no fobj, feval # no fobj, feval
# default metric with custom one # default metric with custom one
res = get_cv_result(feval=constant_metric) res = get_cv_result(feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# non-default metric in params with custom one # non-default metric in params with custom one
res = get_cv_result(params=params_obj_metric_err_verbose, feval=constant_metric) res = get_cv_result(params=params_obj_metric_err_verbose, feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# default metric in args with custom one # default metric in args with custom one
res = get_cv_result(metrics='binary_logloss', feval=constant_metric) res = get_cv_result(metrics='binary_logloss', feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# non-default metric in args with custom one # non-default metric in args with custom one
res = get_cv_result(metrics='binary_error', feval=constant_metric) res = get_cv_result(metrics='binary_error', feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# metric in args overwrites one in params, custom one is evaluated too # metric in args overwrites one in params, custom one is evaluated too
res = get_cv_result(params=params_obj_metric_inv_verbose, metrics='binary_error', feval=constant_metric) res = get_cv_result(params=params_obj_metric_inv_verbose, metrics='binary_error', feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# multiple metrics in params with custom one # multiple metrics in params with custom one
res = get_cv_result(params=params_obj_metric_multi_verbose, feval=constant_metric) res = get_cv_result(params=params_obj_metric_multi_verbose, feval=constant_metric)
assert len(res) == 6 assert len(res) == 6
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# multiple metrics in args with custom one # multiple metrics in args with custom one
res = get_cv_result(metrics=['binary_logloss', 'binary_error'], feval=constant_metric) res = get_cv_result(metrics=['binary_logloss', 'binary_error'], feval=constant_metric)
assert len(res) == 6 assert len(res) == 6
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# custom metric is evaluated despite 'None' is passed # custom metric is evaluated despite 'None' is passed
res = get_cv_result(metrics=['None'], feval=constant_metric) res = get_cv_result(metrics=['None'], feval=constant_metric)
assert len(res) == 2 assert len(res) == 2
assert 'error-mean' in res assert 'valid error-mean' in res
# fobj, feval # fobj, feval
# no default metric, only custom one # no default metric, only custom one
res = get_cv_result(params=params_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params=params_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 2 assert len(res) == 2
assert 'error-mean' in res assert 'valid error-mean' in res
# metric in params with custom one # metric in params with custom one
res = get_cv_result(params=params_metric_err_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params=params_metric_err_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# metric in args with custom one # metric in args with custom one
res = get_cv_result(params=params_verbose, fobj=dummy_obj, res = get_cv_result(params=params_verbose, fobj=dummy_obj,
feval=constant_metric, metrics='binary_error') feval=constant_metric, metrics='binary_error')
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# metric in args overwrites one in params, custom one is evaluated too # metric in args overwrites one in params, custom one is evaluated too
res = get_cv_result(params=params_metric_inv_verbose, fobj=dummy_obj, res = get_cv_result(params=params_metric_inv_verbose, fobj=dummy_obj,
feval=constant_metric, metrics='binary_error') feval=constant_metric, metrics='binary_error')
assert len(res) == 4 assert len(res) == 4
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# multiple metrics in params with custom one # multiple metrics in params with custom one
res = get_cv_result(params=params_metric_multi_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params=params_metric_multi_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 6 assert len(res) == 6
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# multiple metrics in args with custom one # multiple metrics in args with custom one
res = get_cv_result(params=params_verbose, fobj=dummy_obj, feval=constant_metric, res = get_cv_result(params=params_verbose, fobj=dummy_obj, feval=constant_metric,
metrics=['binary_logloss', 'binary_error']) metrics=['binary_logloss', 'binary_error'])
assert len(res) == 6 assert len(res) == 6
assert 'binary_logloss-mean' in res assert 'valid binary_logloss-mean' in res
assert 'binary_error-mean' in res assert 'valid binary_error-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# custom metric is evaluated despite 'None' is passed # custom metric is evaluated despite 'None' is passed
res = get_cv_result(params=params_metric_none_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params=params_metric_none_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 2 assert len(res) == 2
assert 'error-mean' in res assert 'valid error-mean' in res
# no fobj, no feval # no fobj, no feval
# default metric # default metric
...@@ -2184,23 +2184,23 @@ def test_metrics(): ...@@ -2184,23 +2184,23 @@ def test_metrics():
# multiclass default metric # multiclass default metric
res = get_cv_result(params_obj_class_3_verbose) res = get_cv_result(params_obj_class_3_verbose)
assert len(res) == 2 assert len(res) == 2
assert 'multi_logloss-mean' in res assert 'valid multi_logloss-mean' in res
# multiclass default metric with custom one # multiclass default metric with custom one
res = get_cv_result(params_obj_class_3_verbose, feval=constant_metric) res = get_cv_result(params_obj_class_3_verbose, feval=constant_metric)
assert len(res) == 4 assert len(res) == 4
assert 'multi_logloss-mean' in res assert 'valid multi_logloss-mean' in res
assert 'error-mean' in res assert 'valid error-mean' in res
# multiclass metric alias with custom one for custom objective # multiclass metric alias with custom one for custom objective
res = get_cv_result(params_obj_class_3_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params_obj_class_3_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 2 assert len(res) == 2
assert 'error-mean' in res assert 'valid error-mean' in res
# no metric for invalid class_num # no metric for invalid class_num
res = get_cv_result(params_obj_class_1_verbose, fobj=dummy_obj) res = get_cv_result(params_obj_class_1_verbose, fobj=dummy_obj)
assert len(res) == 0 assert len(res) == 0
# custom metric for invalid class_num # custom metric for invalid class_num
res = get_cv_result(params_obj_class_1_verbose, fobj=dummy_obj, feval=constant_metric) res = get_cv_result(params_obj_class_1_verbose, fobj=dummy_obj, feval=constant_metric)
assert len(res) == 2 assert len(res) == 2
assert 'error-mean' in res assert 'valid error-mean' in res
# multiclass metric alias with custom one with invalid class_num # multiclass metric alias with custom one with invalid class_num
with pytest.raises(lgb.basic.LightGBMError): with pytest.raises(lgb.basic.LightGBMError):
get_cv_result(params_obj_class_1_verbose, metrics=obj_multi_alias, get_cv_result(params_obj_class_1_verbose, metrics=obj_multi_alias,
...@@ -2212,11 +2212,11 @@ def test_metrics(): ...@@ -2212,11 +2212,11 @@ def test_metrics():
# multiclass metric alias # multiclass metric alias
res = get_cv_result(params_obj_class_3_verbose, metrics=metric_multi_alias) res = get_cv_result(params_obj_class_3_verbose, metrics=metric_multi_alias)
assert len(res) == 2 assert len(res) == 2
assert 'multi_logloss-mean' in res assert 'valid multi_logloss-mean' in res
# multiclass metric # multiclass metric
res = get_cv_result(params_obj_class_3_verbose, metrics='multi_error') res = get_cv_result(params_obj_class_3_verbose, metrics='multi_error')
assert len(res) == 2 assert len(res) == 2
assert 'multi_error-mean' in res assert 'valid multi_error-mean' in res
# non-valid metric for multiclass objective # non-valid metric for multiclass objective
with pytest.raises(lgb.basic.LightGBMError): with pytest.raises(lgb.basic.LightGBMError):
get_cv_result(params_obj_class_3_verbose, metrics='binary_logloss') get_cv_result(params_obj_class_3_verbose, metrics='binary_logloss')
...@@ -2231,11 +2231,11 @@ def test_metrics(): ...@@ -2231,11 +2231,11 @@ def test_metrics():
# multiclass metric alias for custom objective # multiclass metric alias for custom objective
res = get_cv_result(params_class_3_verbose, metrics=metric_multi_alias, fobj=dummy_obj) res = get_cv_result(params_class_3_verbose, metrics=metric_multi_alias, fobj=dummy_obj)
assert len(res) == 2 assert len(res) == 2
assert 'multi_logloss-mean' in res assert 'valid multi_logloss-mean' in res
# multiclass metric for custom objective # multiclass metric for custom objective
res = get_cv_result(params_class_3_verbose, metrics='multi_error', fobj=dummy_obj) res = get_cv_result(params_class_3_verbose, metrics='multi_error', fobj=dummy_obj)
assert len(res) == 2 assert len(res) == 2
assert 'multi_error-mean' in res assert 'valid multi_error-mean' in res
# binary metric with non-default num_class for custom objective # binary metric with non-default num_class for custom objective
with pytest.raises(lgb.basic.LightGBMError): with pytest.raises(lgb.basic.LightGBMError):
get_cv_result(params_class_3_verbose, metrics='binary_error', fobj=dummy_obj) get_cv_result(params_class_3_verbose, metrics='binary_error', fobj=dummy_obj)
...@@ -2281,12 +2281,12 @@ def test_multiple_feval_cv(): ...@@ -2281,12 +2281,12 @@ def test_multiple_feval_cv():
# Expect three metrics but mean and stdv for each metric # Expect three metrics but mean and stdv for each metric
assert len(cv_results) == 6 assert len(cv_results) == 6
assert 'binary_logloss-mean' in cv_results assert 'valid binary_logloss-mean' in cv_results
assert 'error-mean' in cv_results assert 'valid error-mean' in cv_results
assert 'decreasing_metric-mean' in cv_results assert 'valid decreasing_metric-mean' in cv_results
assert 'binary_logloss-stdv' in cv_results assert 'valid binary_logloss-stdv' in cv_results
assert 'error-stdv' in cv_results assert 'valid error-stdv' in cv_results
assert 'decreasing_metric-stdv' in cv_results assert 'valid decreasing_metric-stdv' in cv_results
def test_default_objective_and_metric(): def test_default_objective_and_metric():
...@@ -3252,3 +3252,42 @@ def test_force_split_with_feature_fraction(tmp_path): ...@@ -3252,3 +3252,42 @@ def test_force_split_with_feature_fraction(tmp_path):
for tree in tree_info: for tree in tree_info:
tree_structure = tree["tree_structure"] tree_structure = tree["tree_structure"]
assert tree_structure['split_feature'] == 0 assert tree_structure['split_feature'] == 0
def test_record_evaluation_with_train():
    """record_evaluation should capture the per-iteration training metric from lgb.train."""
    X, y = make_synthetic_regression()
    train_data = lgb.Dataset(X, y)
    history = {}
    n_rounds = 5
    booster = lgb.train(
        {'objective': 'l2', 'num_leaves': 3},
        train_data,
        num_boost_round=n_rounds,
        valid_sets=[train_data],
        callbacks=[lgb.record_evaluation(history)],
    )
    # only the training set was evaluated, so it is the only recorded key
    assert list(history.keys()) == ['training']
    # recompute the train MSE at every iteration and compare with the callback's record
    expected_mses = [
        mean_squared_error(y, booster.predict(X, num_iteration=it + 1))
        for it in range(n_rounds)
    ]
    np.testing.assert_allclose(history['training']['l2'], expected_mses)
@pytest.mark.parametrize('train_metric', [False, True])
def test_record_evaluation_with_cv(train_metric):
    """record_evaluation should mirror lgb.cv's aggregated (mean/stdv) history per dataset."""
    X, y = make_synthetic_regression()
    data = lgb.Dataset(X, y)
    history = {}
    metric_names = ['l2', 'rmse']
    cv_results = lgb.cv(
        {'objective': 'l2', 'num_leaves': 3, 'metric': metric_names},
        data,
        num_boost_round=5,
        stratified=False,
        callbacks=[lgb.record_evaluation(history)],
        eval_train_metric=train_metric,
    )
    # 'train' entries only appear when eval_train_metric is enabled
    datasets = {'train', 'valid'} if train_metric else {'valid'}
    assert set(history.keys()) == datasets
    # every dataset/metric/aggregation combination must match cv's own output
    for ds_name in datasets:
        for metric_name in metric_names:
            for suffix in ('mean', 'stdv'):
                np.testing.assert_allclose(
                    cv_results[f'{ds_name} {metric_name}-{suffix}'],
                    history[ds_name][f'{metric_name}-{suffix}'],
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment