# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import pickle_obj, unpickle_obj

# Serialization backends every picklability test below is parametrized over;
# each name is forwarded unchanged to the pickle_obj / unpickle_obj helpers.
SERIALIZERS = ["pickle", "joblib", "cloudpickle"]


def pickle_and_unpickle_object(obj, serializer):
    """Serialize ``obj`` to a temporary file and load it back.

    The object is written with ``pickle_obj`` and then read back with
    ``unpickle_obj``. Both calls target the same temporary path and use the
    same ``serializer``.

    :param obj: the object to round-trip.
    :param serializer: backend name, passed through unchanged to both helpers.
    :return: the object reconstructed from the temporary file.
    """
    # NOTE(review): relies on the private ``lgb.basic._TempFile`` helper for
    # the temporary file's lifetime; confirm it still exists in the tested
    # LightGBM version.
    with lgb.basic._TempFile() as tmp:
        path = tmp.name
        pickle_obj(obj=obj, filepath=path, serializer=serializer)
        return unpickle_obj(filepath=path, serializer=serializer)


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
    """Check that an ``lgb.early_stopping`` callback survives a serialization round trip.

    The reloaded callback must report ``order == 30`` and
    ``before_iteration is False``. It must also keep the ``stopping_rounds``
    value the original was built with.
    """
    rounds = 5
    callback = lgb.early_stopping(stopping_rounds=rounds)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    assert callback_from_disk.order == 30
    assert callback_from_disk.before_iteration is False
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
    assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
    """Check that an ``lgb.log_evaluation`` callback survives a serialization round trip.

    The reloaded callback must report ``order == 10`` and
    ``before_iteration is False``. It must also keep the ``period`` value the
    original was built with.
    """
    periods = 42
    callback = lgb.log_evaluation(period=periods)
    callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
    assert callback_from_disk.order == 10
    assert callback_from_disk.before_iteration is False
    assert callback.period == callback_from_disk.period
    assert callback.period == periods


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
    """Check that an ``lgb.record_evaluation`` callback survives a serialization round trip.

    The reloaded callback must report ``order == 20`` and
    ``before_iteration is False``. Its ``eval_result`` must compare equal to
    the original's, and the original must still hold the exact dict it was
    given.
    """
    eval_history = {}
    original = lgb.record_evaluation(eval_result=eval_history)
    restored = pickle_and_unpickle_object(obj=original, serializer=serializer)

    assert restored.order == 20
    assert restored.before_iteration is False
    # The reloaded result container is compared by value...
    assert original.eval_result == restored.eval_result
    # ...while the original callback must reference the caller's own dict.
    assert original.eval_result is eval_history