"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "623ac048cdea55290f3134981558b8a541ca03ff"
test_callback.py 1.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import pickle_obj, unpickle_obj


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
    """Check that an ``early_stopping`` callback survives a save/load round trip.

    The callback is written to a file under ``tmp_path`` with ``pickle_obj``
    and read back with ``unpickle_obj``, once for each parametrized
    ``serializer`` name. Its ``stopping_rounds`` attribute must be the same
    on the reloaded object and must still equal the value it was built with.
    """
    rounds = 5
    callback = lgb.early_stopping(stopping_rounds=rounds)
    tmp_file = tmp_path / "early_stopping.pkl"
    pickle_obj(
        obj=callback,
        filepath=tmp_file,
        serializer=serializer
    )
    callback_from_disk = unpickle_obj(
        filepath=tmp_file,
        serializer=serializer
    )
    # The reloaded copy must agree with the in-memory callback ...
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
    # ... and the in-memory callback must still hold the configured value.
    assert callback.stopping_rounds == rounds


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
    """Check that a ``log_evaluation`` callback survives a save/load round trip.

    For each parametrized ``serializer`` name the callback is dumped to a
    file under ``tmp_path`` via ``pickle_obj`` and loaded again via
    ``unpickle_obj``; the ``period`` attribute is then compared on both
    objects and against the value passed at construction.
    """
    expected_period = 42
    original = lgb.log_evaluation(period=expected_period)

    # Serialize to disk, then load a second copy from the same file.
    dump_path = tmp_path / "log_evaluation.pkl"
    pickle_obj(obj=original, filepath=dump_path, serializer=serializer)
    restored = unpickle_obj(filepath=dump_path, serializer=serializer)

    assert original.period == restored.period
    assert original.period == expected_period