test_dask.py 1.72 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import pytest

from nixtla_tests.helpers.checks import check_anomalies_dataframe
from nixtla_tests.helpers.checks import check_anomalies_online_dataframe
from nixtla_tests.helpers.checks import check_anomalies_dataframe_diff_cols
from nixtla_tests.helpers.checks import check_forecast_dataframe
from nixtla_tests.helpers.checks import check_forecast_dataframe_diff_cols
from nixtla_tests.helpers.checks import check_forecast_x_dataframe
from nixtla_tests.helpers.checks import check_forecast_x_dataframe_diff_cols
from nixtla_tests.helpers.checks import check_quantiles

# Apply the ``distributed_run`` marker to every test in this module; pytest
# picks up the module-level ``pytestmark`` variable for this purpose.
pytestmark = pytest.mark.distributed_run

def test_quantiles(nixtla_test_client, dask_df):
    """Run the shared quantile checks against the ``dask_df`` fixture.

    The id and time column names are passed explicitly as ``"unique_id"``
    and ``"ds"``.
    """
    column_names = {"id_col": "unique_id", "time_col": "ds"}
    check_quantiles(nixtla_test_client, dask_df, **column_names)


def test_forecast(
    nixtla_test_client,
    dask_df,
    dask_diff_cols_df,
    distributed_n_series,
):
    """Run the shared forecast checks on both Dask fixture dataframes.

    ``dask_df`` goes through ``check_forecast_dataframe``, with the
    ``distributed_n_series`` fixture passed as ``n_series_to_check``.
    ``dask_diff_cols_df`` goes through ``check_forecast_dataframe_diff_cols``.
    """
    client = nixtla_test_client
    check_forecast_dataframe(
        client,
        dask_df,
        n_series_to_check=distributed_n_series,
    )
    check_forecast_dataframe_diff_cols(client, dask_diff_cols_df)


def test_anomalies(nixtla_test_client, dask_df, dask_diff_cols_df):
    """Run the shared anomaly checks on both Dask fixture dataframes.

    ``dask_df`` is checked first with ``check_anomalies_dataframe``, then
    ``dask_diff_cols_df`` with ``check_anomalies_dataframe_diff_cols``. A
    failure in the first check stops the test before the second runs.
    """
    for run_check, frame in (
        (check_anomalies_dataframe, dask_df),
        (check_anomalies_dataframe_diff_cols, dask_diff_cols_df),
    ):
        run_check(nixtla_test_client, frame)


def test_anomalies_online(nixtla_test_client, dask_df):
    """Run ``check_anomalies_online_dataframe`` against the ``dask_df`` fixture."""
    client, frame = nixtla_test_client, dask_df
    check_anomalies_online_dataframe(client, frame)


def test_forecast_x_dataframe(
    nixtla_test_client,
    dask_df_x,
    dask_future_ex_vars_df,
    dask_df_x_diff_cols,
    dask_future_ex_vars_df_diff_cols,
):
    """Run the shared ``forecast_x`` checks on two pairs of Dask fixtures.

    ``check_forecast_x_dataframe`` receives ``dask_df_x`` together with
    ``dask_future_ex_vars_df``. ``check_forecast_x_dataframe_diff_cols`` then
    receives the ``*_diff_cols`` counterparts of both fixtures.
    """
    cases = (
        (check_forecast_x_dataframe, dask_df_x, dask_future_ex_vars_df),
        (
            check_forecast_x_dataframe_diff_cols,
            dask_df_x_diff_cols,
            dask_future_ex_vars_df_diff_cols,
        ),
    )
    for run_check, frame, future_frame in cases:
        run_check(nixtla_test_client, frame, future_frame)