# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import SERIALIZERS, pickle_and_unpickle_object, pickle_obj, unpickle_obj


def reset_feature_fraction(boosting_round):
    """Return the ``feature_fraction`` value for iteration ``boosting_round``.

    Callable schedule passed to ``lgb.reset_parameter`` in
    ``test_reset_parameter_callback_is_picklable``: yields 0.6 while
    ``boosting_round < 15`` and 0.8 otherwise. It is presumably a named
    module-level function (not a lambda) so that pickle-style serializers
    can handle it by reference.
    """
    if boosting_round < 15:
        return 0.6
    return 0.8


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """Check that an ``lgb.early_stopping`` callback survives a serializer round trip.

    The restored callback must report ``order == 30`` and
    ``before_iteration is False``. It must also carry the same
    ``stopping_rounds`` as the original callback, which in turn must equal
    the value it was constructed with.
    """
    rounds = 5
    callback = lgb.early_stopping(stopping_rounds=rounds)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    assert callback_from_disk.order == 30
    assert callback_from_disk.before_iteration is False
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
    assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """Check that an ``lgb.log_evaluation`` callback survives a serializer round trip.

    The restored callback must report ``order == 10`` and
    ``before_iteration is False``. It must also carry the same ``period``
    as the original callback, which in turn must equal the value it was
    constructed with.
    """
    periods = 42
    callback = lgb.log_evaluation(period=periods)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    assert callback_from_disk.order == 10
    assert callback_from_disk.before_iteration is False
    assert callback.period == callback_from_disk.period
    assert callback.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """Check that an ``lgb.record_evaluation`` callback survives a serializer round trip.

    The restored callback must report ``order == 20`` and
    ``before_iteration is False``, and its ``eval_result`` must compare
    equal to the original's. The original callback must still reference
    the exact dict object that was handed to it.
    """
    eval_store = {}
    original = lgb.record_evaluation(eval_result=eval_store)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    # Scheduling metadata carried over by the round trip.
    assert restored.order == 20
    assert restored.before_iteration is False
    # Recorded results compare equal after deserialization...
    assert original.eval_result == restored.eval_result
    # ...and the original callback holds the caller's dict itself (identity, not equality).
    assert original.eval_result is eval_store


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
    """Check that an ``lgb.reset_parameter`` callback survives a serializer round trip.

    The callback gets two kinds of schedule: a list for ``bagging_fraction``
    and the module-level function ``reset_feature_fraction`` for
    ``feature_fraction``. After deserialization it must report
    ``order == 10`` and ``before_iteration is True``. Its ``kwargs`` must
    equal both the original callback's ``kwargs`` and the dict it was
    built from.
    """
    # Ten entries: 0.7 for the first five, then 0.6 for the next five.
    bagging_schedule = [0.7] * 5 + [0.6] * 5
    expected_kwargs = {
        'bagging_fraction': bagging_schedule,
        'feature_fraction': reset_feature_fraction,
    }
    original = lgb.reset_parameter(**expected_kwargs)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    assert restored.order == 10
    assert restored.before_iteration is True
    assert original.kwargs == restored.kwargs
    assert original.kwargs == expected_kwargs