statsforecast_sn.py 1.02 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
from pathlib import Path
from time import time

import fire
from statsforecast import StatsForecast
from statsforecast.models import SeasonalNaive

from src.utils.data_handler import ExperimentDataset, ForecastDataset


def sn_forecast(dataset_path: str, results_dir: str = "./results"):
    """Run a Seasonal Naive forecast for one dataset and save the results.

    Loads the dataset with ``ExperimentDataset.from_parquet`` and builds a
    StatsForecast ``SeasonalNaive`` model from the dataset's ``seasonality``
    and ``pandas_frequency``. It then forecasts ``dataset.horizon`` steps from
    ``dataset.Y_df_train``. The forecast and the time spent in
    ``StatsForecast.forecast`` are saved via ``ForecastDataset.save_to_dir``
    into ``<results_dir>/statsforecast_sn/<experiment_name>``.

    Args:
        dataset_path: Path to the experiment parquet file. The experiment name
            is the file name up to its first ``"."``; for example,
            ``"data/m4_hourly.parquet"`` gives ``"m4_hourly"``.
        results_dir: Root directory under which the results are written.
    """
    # Local import: perf_counter is monotonic, so the measured duration is
    # not distorted by system clock adjustments (unlike time.time()).
    from time import perf_counter

    # NOTE(review): presumably makes StatsForecast return ``unique_id`` as a
    # regular column instead of the index — confirm against statsforecast docs.
    # This mutates the process-wide environment and is not restored afterwards.
    os.environ["NIXTLA_ID_AS_COL"] = "true"

    dataset = ExperimentDataset.from_parquet(parquet_path=dataset_path)
    sf = StatsForecast(
        models=[SeasonalNaive(season_length=dataset.seasonality)],
        freq=dataset.pandas_frequency,
    )

    # Time only the forecasting call, not loading or saving.
    start = perf_counter()
    forecast_df = sf.forecast(
        df=dataset.Y_df_train,
        h=dataset.horizon,
    )
    total_time = perf_counter() - start

    forecast_dataset = ForecastDataset(forecast_df=forecast_df, total_time=total_time)

    # Path(...).name handles the platform's separators and also accepts
    # PathLike inputs; keep everything before the first "." as the experiment
    # name, matching the previous naming scheme for result directories.
    experiment_name = Path(dataset_path).name.split(".")[0]
    results_path = Path(results_dir) / "statsforecast_sn" / experiment_name
    forecast_dataset.save_to_dir(results_path)


if __name__ == "__main__":
    # Expose sn_forecast as a command-line interface: Fire maps CLI arguments
    # to the function's parameters (dataset_path, results_dir).
    fire.Fire(component=sn_forecast)