test_callback.py 2.61 KB
Newer Older
1
2
3
4
5
# coding: utf-8
import pytest

import lightgbm as lgb

6
from .utils import SERIALIZERS, pickle_and_unpickle_object
7
8


9
10
11
12
def reset_feature_fraction(boosting_round):
    """Return the feature fraction to use for ``boosting_round``.

    Yields 0.6 for rounds below 15 and 0.8 from round 15 onwards. Defined at
    module level and used as a callable parameter value in the
    ``reset_parameter`` pickling test in this file.
    """
    if boosting_round < 15:
        return 0.6
    return 0.8


13
@pytest.mark.parametrize("serializer", SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """Round-trip an ``early_stopping`` callback through ``serializer``.

    Verifies that the restored callback reports the expected ``order`` and
    ``before_iteration`` values and that ``stopping_rounds`` is unchanged.
    """
    rounds = 5
    original = lgb.early_stopping(stopping_rounds=rounds)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    # Attributes checked on the deserialized copy.
    assert restored.order == 30
    assert restored.before_iteration is False

    # The configured value must match between both copies and the input.
    assert original.stopping_rounds == restored.stopping_rounds
    assert original.stopping_rounds == rounds


24
def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
    """Check that a non-integer ``stopping_rounds`` raises an informative ``TypeError``."""
    # ``match`` is applied as a regular expression, so the literal "." in the
    # expected message is escaped; left bare it would match any character and
    # make the assertion looser than the message it is meant to pin.
    expected_message = r"early_stopping_round should be an integer\. Got 'str'"
    with pytest.raises(TypeError, match=expected_message):
        lgb.early_stopping(stopping_rounds="neverrrr")
27
28


29
30
31
32
@pytest.mark.parametrize("stopping_rounds", [-10, -1, 0])
def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds):
    """Check that zero or negative ``stopping_rounds`` yields a callback with ``enabled`` set to False."""
    # Construction itself must succeed for these values; only then is the flag inspected.
    assert lgb.early_stopping(stopping_rounds=stopping_rounds).enabled is False
33
34


35
@pytest.mark.parametrize("serializer", SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """Round-trip a ``log_evaluation`` callback through ``serializer``.

    Verifies that the restored callback reports the expected ``order`` and
    ``before_iteration`` values and that ``period`` is unchanged.
    """
    periods = 42
    original = lgb.log_evaluation(period=periods)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    # Attributes checked on the deserialized copy.
    assert restored.order == 10
    assert restored.before_iteration is False

    # The configured value must match between both copies and the input.
    assert original.period == restored.period
    assert original.period == periods
44
45


46
@pytest.mark.parametrize("serializer", SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """Round-trip a ``record_evaluation`` callback through ``serializer``.

    Verifies that the restored callback reports the expected ``order`` and
    ``before_iteration`` values, that ``eval_result`` compares equal across
    the round trip, and that the original callback still references the
    exact dict it was constructed with.
    """
    results = {}
    original = lgb.record_evaluation(eval_result=results)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    # Attributes checked on the deserialized copy.
    assert restored.order == 20
    assert restored.before_iteration is False

    # Equality across the round trip; identity only for the un-pickled original.
    assert original.eval_result == restored.eval_result
    assert original.eval_result is results
55
56


57
@pytest.mark.parametrize("serializer", SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
    """Round-trip a ``reset_parameter`` callback through ``serializer``.

    The callback is built with one list-valued and one callable-valued
    parameter. Verifies that the restored callback reports the expected
    ``order`` and ``before_iteration`` values and that ``kwargs`` is unchanged.
    """
    params = {
        "bagging_fraction": [0.7] * 5 + [0.6] * 5,
        "feature_fraction": reset_feature_fraction,
    }
    original = lgb.reset_parameter(**params)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    # Attributes checked on the deserialized copy.
    assert restored.order == 10
    assert restored.before_iteration is True

    # The keyword arguments must match between both copies and the input.
    assert original.kwargs == restored.kwargs
    assert original.kwargs == params