# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import SERIALIZERS, pickle_and_unpickle_object


def reset_feature_fraction(boosting_round):
    """Return the feature_fraction value scheduled for ``boosting_round``.

    Rounds before 15 get 0.6; every later round gets 0.8. It is used as a
    callable schedule passed to ``lgb.reset_parameter`` in the pickling test
    below. It is kept as a named module-level function (not a lambda) so the
    callback holding it can be pickled.
    """
    if boosting_round < 15:
        return 0.6
    return 0.8


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """The early_stopping callback survives a pickle/unpickle round trip.

    For each serializer in SERIALIZERS, the restored callback must expose the
    expected ``order`` and ``before_iteration`` values and the same
    ``stopping_rounds`` it was created with.
    """
    rounds = 5
    callback = lgb.early_stopping(stopping_rounds=rounds)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    # The restored object must report the same order/before_iteration values.
    assert callback_from_disk.order == 30
    assert callback_from_disk.before_iteration is False
    # The configured stopping_rounds must be preserved by the round trip.
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
    assert callback.stopping_rounds == rounds


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
    """early_stopping raises ValueError with a descriptive message for bad stopping_rounds.

    Covers zero, a negative integer and a non-integer string. ``match`` is
    treated as a regular expression (searched in the message), so the literal
    ``.`` in the expected text is escaped. Otherwise it would match any
    character.
    """
    with pytest.raises(ValueError, match=r"stopping_rounds should be an integer and greater than 0\. got: 0"):
        lgb.early_stopping(stopping_rounds=0)

    with pytest.raises(ValueError, match=r"stopping_rounds should be an integer and greater than 0\. got: -1"):
        lgb.early_stopping(stopping_rounds=-1)

    with pytest.raises(ValueError, match=r"stopping_rounds should be an integer and greater than 0\. got: neverrrr"):
        lgb.early_stopping(stopping_rounds="neverrrr")


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """The log_evaluation callback survives a pickle/unpickle round trip.

    For each serializer in SERIALIZERS, the restored callback must expose the
    expected ``order`` and ``before_iteration`` values and the same ``period``
    it was created with.
    """
    periods = 42
    callback = lgb.log_evaluation(period=periods)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    # The restored object must report the same order/before_iteration values.
    assert callback_from_disk.order == 10
    assert callback_from_disk.before_iteration is False
    # The configured period must be preserved by the round trip.
    assert callback.period == callback_from_disk.period
    assert callback.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """The record_evaluation callback survives a pickle/unpickle round trip.

    The restored callback must expose the expected ``order`` and
    ``before_iteration`` values and an ``eval_result`` equal to the original's.
    """
    history = {}
    original_callback = lgb.record_evaluation(eval_result=history)
    restored_callback = pickle_and_unpickle_object(obj=original_callback, serializer=serializer)
    assert restored_callback.order == 20
    assert restored_callback.before_iteration is False
    assert original_callback.eval_result == restored_callback.eval_result
    # The callback keeps a reference to the caller's dict itself, not a copy.
    assert original_callback.eval_result is history


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
    """The reset_parameter callback survives a pickle/unpickle round trip.

    The callback is built from one list-valued schedule and one callable
    schedule (the module-level ``reset_feature_fraction``). The restored object
    must expose the expected ``order`` and ``before_iteration`` values and the
    same ``kwargs``.
    """
    bagging_schedule = [0.7] * 5 + [0.6] * 5
    schedules = {'bagging_fraction': bagging_schedule, 'feature_fraction': reset_feature_fraction}
    original_callback = lgb.reset_parameter(**schedules)
    restored_callback = pickle_and_unpickle_object(obj=original_callback, serializer=serializer)
    assert restored_callback.order == 10
    assert restored_callback.before_iteration is True
    assert original_callback.kwargs == restored_callback.kwargs
    assert original_callback.kwargs == schedules