# test_client.py
# Last recorded commit: "readme" (author: bailuo)
import os
import pytest
import pandas as pd
import warnings

from nixtla_tests.helpers.client_helper import delete_env_var
from nixtla.nixtla_client import NixtlaClient

def test_custom_business_hours(business_hours_series, custom_business_hours):
    """Exercise the client end to end with a custom business-hour frequency.

    Runs anomaly detection, cross-validation and forecasting on the same
    series/frequency pair, then checks that the forecast timestamps only use
    hours 9 through 15 and that ``_model_params`` holds exactly one key,
    ``("timegpt-1", "cbh")`` (frequency compared case-insensitively).
    """
    client = NixtlaClient()
    shared_args = {"df": business_hours_series, "freq": custom_business_hours}

    client.detect_anomalies(level=90, **shared_args)
    client.cross_validation(h=7, **shared_args)
    forecast = client.forecast(h=7, **shared_args)

    forecast_hours = sorted(forecast["ds"].dt.hour.unique().tolist())
    assert forecast_hours == list(range(9, 16))

    # Lower-case the frequency so the comparison ignores alias casing.
    seen_keys = []
    for model, freq in client._model_params.keys():
        seen_keys.append((model, freq.lower()))
    assert seen_keys == [("timegpt-1", "cbh")]


def test_integer_freq(integer_freq_series):
    """Check that an integer ``freq`` of 1 is accepted by the main endpoints.

    After forecasting 7 steps ahead, the last ``ds`` of every series in the
    forecast must equal the last training ``ds`` plus 7, and
    ``_model_params`` must contain the single key ``("timegpt-1", "MS")``.
    """
    horizon = 7
    client = NixtlaClient()
    client.detect_anomalies(df=integer_freq_series, level=90, freq=1)
    client.cross_validation(df=integer_freq_series, h=horizon, freq=1)
    forecast = client.forecast(df=integer_freq_series, h=horizon, freq=1)

    def last_ds_per_series(frame):
        """Return the maximum ``ds`` for each ``unique_id`` in *frame*."""
        return frame.groupby("unique_id", observed=True)["ds"].max()

    expected_ends = last_ds_per_series(integer_freq_series) + horizon
    pd.testing.assert_series_equal(last_ds_per_series(forecast), expected_ends)

    assert list(client._model_params.keys()) == [("timegpt-1", "MS")]


def test_api_key_fail():
    """Constructing a client raises ``KeyError`` naming ``NIXTLA_API_KEY``.

    Both ``NIXTLA_API_KEY`` and ``TIMEGPT_TOKEN`` are wrapped in
    ``delete_env_var`` for the duration of the check (presumably removing them
    from the environment temporarily; see the helper for exact semantics).
    """
    with delete_env_var("NIXTLA_API_KEY"):
        with delete_env_var("TIMEGPT_TOKEN"):
            with pytest.raises(KeyError) as raised:
                NixtlaClient()
            error_text = str(raised.value)
            assert "NIXTLA_API_KEY" in error_text


def test_api_key_success():
    """A default-constructed client reports its API key as valid."""
    assert NixtlaClient().validate_api_key()


def test_custom_client_success():
    """Validate a client built from the custom base URL and API key.

    Reads ``NIXTLA_BASE_URL_CUSTOM`` and ``NIXTLA_API_KEY_CUSTOM`` from the
    environment (a missing variable raises ``KeyError``), checks that the key
    validates, and checks that the usage response has exactly the keys
    ``"minute"`` and ``"month"``.
    """
    base_url = os.environ["NIXTLA_BASE_URL_CUSTOM"]
    api_key = os.environ["NIXTLA_API_KEY_CUSTOM"]
    client = NixtlaClient(base_url=base_url, api_key=api_key)
    assert client.validate_api_key()

    # The usage endpoint must report exactly these two entries.
    usage_keys = sorted(client.usage().keys())
    assert usage_keys == ["minute", "month"]


def test_forecast_with_wrong_api_key():
    """Forecasting with an invalid API key raises an error mentioning "nixtla".

    The request asks for key validation (``validate_api_key=True``) and sends
    an empty frame, so the only thing under test is the raised exception.
    """
    with pytest.raises(Exception) as raised:
        invalid_client = NixtlaClient(api_key="transphobic")
        invalid_client.forecast(df=pd.DataFrame(), h=None, validate_api_key=True)

    error_text = str(raised.value)
    assert "nixtla" in error_text


def test_get_model_params(nixtla_test_client):
    """``_get_model_params`` returns ``(28, 7)`` for ``timegpt-1`` with ``freq="D"``."""
    params = nixtla_test_client._get_model_params(model="timegpt-1", freq="D")
    assert params == (28, 7)


def test_client_plot(nixtla_test_client, air_passengers_df):
    """Smoke test: plotting with the plotly engine completes without raising."""
    plot_options = {
        "time_col": "timestamp",
        "target_col": "value",
        "engine": "plotly",
    }
    nixtla_test_client.plot(air_passengers_df, **plot_options)


def test_finetune_cv(nixtla_test_client, air_passengers_df):
    """Cross-validation with one fine-tuning step returns a non-``None`` result."""
    cv_options = {
        "h": 12,
        "time_col": "timestamp",
        "target_col": "value",
        "n_windows": 1,
        "finetune_steps": 1,
    }
    result = nixtla_test_client.cross_validation(air_passengers_df, **cv_options)
    assert result is not None


def test_forecast_warning(nixtla_test_client, air_passengers_df, caplog):
    """A horizon of 100 logs the "exceeds the model horizon" message.

    The message is looked up in the text captured by the ``caplog`` fixture.
    """
    expected_message = 'The specified horizon "h" exceeds the model horizon'
    short_history = air_passengers_df.tail(3)

    nixtla_test_client.forecast(
        df=short_history, h=100, time_col="timestamp", target_col="value"
    )

    assert expected_message in caplog.text


@pytest.mark.parametrize(
    "kwargs",
    [{"add_history": True}],
    ids=["short horizon with add_history"],
)
def test_forecast_error(nixtla_test_client, air_passengers_df, kwargs):
    """Forecasting from only three rows raises the "too short" ``ValueError``.

    ``kwargs`` carries the extra forecast options of each parametrized case
    (currently only ``add_history=True``).
    """
    expected_message = "Some series are too short. Please make sure that each series"
    fixed_options = {"h": 12, "time_col": "timestamp", "target_col": "value"}

    with pytest.raises(ValueError, match=expected_message):
        nixtla_test_client.forecast(
            df=air_passengers_df.tail(3), **fixed_options, **kwargs
        )


def test_large_request_partition_error(nixtla_test_client, large_series):
    """Forecasting ``large_series`` fails with an error mentioning ``num_partitions``."""
    with pytest.raises(Exception) as raised:
        nixtla_test_client.forecast(
            df=large_series, h=1, freq="min", finetune_steps=2
        )

    error_text = str(raised.value)
    assert "num_partitions" in error_text


def test_forecast_exogenous_warnings(
    nixtla_test_client, two_short_series_with_time_features_train_future
):
    """Check the warnings emitted when exogenous features in ``df`` are ignored.

    Two cases are covered:

    1. ``df`` has exogenous columns but no ``X_df`` is passed.
    2. ``X_df`` is passed with only ``year``, so ``month`` is neither in
       ``X_df`` nor declared in ``hist_exog_list``.

    In both cases the expected text must appear in at least one recorded
    warning.
    """
    train, future = two_short_series_with_time_features_train_future

    # features in df but not in X_df
    missing_exogenous = train.columns.drop(["unique_id", "ds", "y"]).tolist()
    expected_warning = (
        f"`df` contains the following exogenous features: {missing_exogenous}, "
        "but `X_df` was not provided and they were not declared in `hist_exog_list`. "
        "They will be ignored."
    )
    with warnings.catch_warnings(record=True) as w:
        # Record every warning regardless of the ambient filters; without this
        # an inherited "ignore"/"once" filter could hide the warning under test.
        warnings.simplefilter("always")
        nixtla_test_client.forecast(train, h=5)
        assert any(expected_warning in str(warning.message) for warning in w)

    # features in df not set as historic nor in X_df
    expected_warning = (
        "`df` contains the following exogenous features: ['month'], "
        "but they were not found in `X_df` nor declared in `hist_exog_list`. "
        "They will be ignored."
    )
    with warnings.catch_warnings(record=True) as w:
        # Same reasoning as above: make recording independent of outer filters.
        warnings.simplefilter("always")
        nixtla_test_client.forecast(
            train, h=5, X_df=future[["unique_id", "ds", "year"]]
        )
        assert any(expected_warning in str(warning.message) for warning in w)


def test_features_not_in_df_error(
    nixtla_test_client, two_short_series_with_time_features_train_future
):
    """Passing features in ``X_df`` that are absent from ``df`` raises ``ValueError``.

    ``df`` is reduced to ``unique_id``, ``ds`` and ``y`` while ``X_df`` is the
    full ``future`` frame from the fixture.
    """
    train, future = two_short_series_with_time_features_train_future
    base_columns = ["unique_id", "ds", "y"]
    expected_message = "features are present in `X_df` but not in `df`"

    with pytest.raises(ValueError, match=expected_message):
        nixtla_test_client.forecast(df=train[base_columns], h=5, X_df=future)


def test_setting_one_as_historic_and_other_as_future(
    nixtla_test_client, two_short_series_with_time_features_train_future
):
    """Mix one future and one historic exogenous feature in a single forecast.

    ``year`` is supplied through ``X_df`` and ``month`` is declared in
    ``hist_exog_list``; afterwards ``weights_x["features"]`` must list
    ``["year", "month"]`` in that order.
    """
    train, future = two_short_series_with_time_features_train_future

    # "year" travels in X_df; "month" is declared as historic only.
    future_with_year = future[["unique_id", "ds", "year"]]
    nixtla_test_client.forecast(
        train, h=5, X_df=future_with_year, hist_exog_list=["month"]
    )

    feature_names = nixtla_test_client.weights_x["features"].tolist()
    assert feature_names == ["year", "month"]