# test_callback.py
# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import pickle_obj, unpickle_obj

# Serializer names passed through to ``pickle_obj`` / ``unpickle_obj``
# (via their ``serializer`` argument) by the parametrized tests below.
SERIALIZERS = [
    "pickle",
    "joblib",
    "cloudpickle",
]


def pickle_and_unpickle_object(obj, serializer):
    """Serialize ``obj`` to a temporary file and return what is read back.

    The temporary file comes from LightGBM's private ``lgb.basic._TempFile``
    helper; the object is written and re-read while that context is open.

    Parameters
    ----------
    obj : object
        Object to round-trip.
    serializer : str
        Backend name, forwarded unchanged to ``pickle_obj`` and
        ``unpickle_obj``.

    Returns
    -------
    object
        The object loaded back from the temporary file.
    """
    with lgb.basic._TempFile() as tmp_file:
        path = tmp_file.name
        pickle_obj(obj=obj, filepath=path, serializer=serializer)
        restored = unpickle_obj(filepath=path, serializer=serializer)
    return restored


def reset_feature_fraction(boosting_round):
    """Return the ``feature_fraction`` value for a given boosting round.

    Rounds before 15 get 0.6; round 15 and later get 0.8.
    """
    if boosting_round < 15:
        return 0.6
    return 0.8


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """An ``early_stopping`` callback keeps its settings after a round trip."""
    rounds = 5
    callback = lgb.early_stopping(stopping_rounds=rounds)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    # The restored copy must keep the ``order`` / ``before_iteration`` attributes.
    assert callback_from_disk.order == 30
    assert callback_from_disk.before_iteration is False
    # ...and the configured stopping_rounds value.
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
    assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """A ``log_evaluation`` callback keeps its settings after a round trip."""
    periods = 42
    callback = lgb.log_evaluation(period=periods)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    # The restored copy must keep the ``order`` / ``before_iteration`` attributes.
    assert callback_from_disk.order == 10
    assert callback_from_disk.before_iteration is False
    # ...and the configured logging period.
    assert callback.period == callback_from_disk.period
    assert callback.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """A ``record_evaluation`` callback keeps its settings after a round trip."""
    recorded = {}
    original_callback = lgb.record_evaluation(eval_result=recorded)
    restored_callback = pickle_and_unpickle_object(
        obj=original_callback,
        serializer=serializer,
    )
    # The ``order`` / ``before_iteration`` attributes must survive.
    assert restored_callback.order == 20
    assert restored_callback.before_iteration is False
    # The restored copy holds an equal dict, while the original callback
    # still references the exact dict object the caller passed in.
    assert original_callback.eval_result == restored_callback.eval_result
    assert original_callback.eval_result is recorded


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
    """A ``reset_parameter`` callback keeps its settings after a round trip.

    The schedule mixes a list of per-round values (``bagging_fraction``) with
    a module-level callable (``feature_fraction``).
    """
    schedule = {
        # Ten rounds: 0.7 for the first five, then 0.6.
        'bagging_fraction': [0.7] * 5 + [0.6] * 5,
        'feature_fraction': reset_feature_fraction,
    }
    original_callback = lgb.reset_parameter(**schedule)
    restored_callback = pickle_and_unpickle_object(
        obj=original_callback,
        serializer=serializer,
    )
    # The ``order`` / ``before_iteration`` attributes must survive.
    assert restored_callback.order == 10
    assert restored_callback.before_iteration is True
    # The stored kwargs must match both the restored copy and the input.
    assert original_callback.kwargs == restored_callback.kwargs
    assert original_callback.kwargs == schedule