statsforecast_pipeline.py 1.2 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
from time import time
from typing import List, Tuple

import fire
import pandas as pd
from statsforecast import StatsForecast
from statsforecast.models import SeasonalNaive

from src.utils import ExperimentHandler


def run_statsforecast(
    train_df: pd.DataFrame,
    horizon: int,
    freq: str,
    seasonality: int,
    level: List[int],
) -> Tuple[pd.DataFrame, float, str]:
    """Forecast ``horizon`` steps with a StatsForecast SeasonalNaive model.

    Args:
        train_df: Training data passed unchanged to ``StatsForecast.forecast``.
            Presumably Nixtla long format (``unique_id``, ``ds``, ``y``) —
            confirm against ``ExperimentHandler.train_df``.
        horizon: Number of future steps to forecast.
        freq: Frequency string forwarded to ``StatsForecast``.
        seasonality: ``season_length`` for ``SeasonalNaive``.
        level: Prediction-interval levels forwarded to ``forecast``.

    Returns:
        A tuple ``(forecasts, elapsed_seconds, model_name)``. ``elapsed_seconds``
        covers building the ``StatsForecast`` object plus the ``forecast`` call.
        ``model_name`` is ``repr`` of the single model.
    """
    # perf_counter is monotonic and high-resolution. time.time() is wall-clock
    # time and can jump when the system clock is adjusted, which corrupts
    # duration measurements.
    from time import perf_counter

    # NOTE(review): presumably opts StatsForecast into returning ``unique_id``
    # as a column rather than as the index — confirm against the installed
    # statsforecast version. This mutates the process-wide environment.
    os.environ["NIXTLA_ID_AS_COL"] = "true"
    models = [SeasonalNaive(season_length=seasonality)]
    init_time = perf_counter()
    sf = StatsForecast(
        models=models,
        freq=freq,
        n_jobs=-1,  # use all available cores
    )
    fcsts_df = sf.forecast(df=train_df, h=horizon, level=level)
    total_time = perf_counter() - init_time
    model_name = repr(models[0])
    return fcsts_df, total_time, model_name


def main(dataset: str):
    """Run the SeasonalNaive pipeline for ``dataset`` and persist its results.

    Loads the experiment through ``ExperimentHandler`` and forecasts with
    ``run_statsforecast``. The level-based forecasts are converted to
    quantiles by the handler, which then saves them with the timing and the
    model name.
    """
    handler = ExperimentHandler(dataset)
    forecasts, elapsed, name = run_statsforecast(
        train_df=handler.train_df,
        horizon=handler.horizon,
        freq=handler.freq,
        seasonality=handler.seasonality,
        level=handler.level,
    )
    # Both helpers are private methods on the handler; they are kept as-is
    # because no public equivalent is visible here.
    quantile_forecasts = handler._fcst_from_level_to_quantiles(forecasts, name)
    handler._save_results(quantile_forecasts, elapsed, name)


if __name__ == "__main__":
    # python-fire maps main's ``dataset`` parameter to a command-line argument.
    fire.Fire(component=main)