pipeline.py 1.23 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
from time import time
from typing import List, Tuple

import fire
import pandas as pd


from ..utils import ExperimentHandler
from .forecaster import AmazonChronos


def run_amazon_chronos(
    train_df: pd.DataFrame,
    model_name: str,
    horizon: int,
    freq: str,
    quantiles: List[float],
) -> Tuple[pd.DataFrame, float, str]:
    """Forecast ``train_df`` with an ``AmazonChronos`` model and time the call.

    Only the ``forecast`` call is timed; the ``AmazonChronos`` object is
    constructed before the clock starts, so its construction cost is not
    included in the reported duration.

    Args:
        train_df: Training data, passed to ``AmazonChronos.forecast`` as ``df``.
        model_name: Identifier passed to the ``AmazonChronos`` constructor and
            echoed back unchanged in the return value.
        horizon: Forwarded to ``forecast`` as ``h``.
        freq: Forwarded to ``forecast`` as ``freq``.
        quantiles: Forwarded to ``forecast`` as ``quantiles``.

    Returns:
        A tuple ``(fcsts_df, total_time, model_name)`` where ``fcsts_df`` is
        whatever ``AmazonChronos.forecast`` returned, ``total_time`` is the
        elapsed time of that call in fractional seconds, and ``model_name``
        is the input argument.
    """
    # Imported locally so this function does not depend on a change to the
    # module's import block. perf_counter is monotonic and high resolution,
    # unlike wall-clock time(), which can jump when the system clock is
    # adjusted and so yield a skewed or even negative duration.
    from time import perf_counter

    ac = AmazonChronos(model_name)
    init_time = perf_counter()
    fcsts_df = ac.forecast(
        df=train_df,
        h=horizon,
        freq=freq,
        batch_size=8,
        quantiles=quantiles,
        # parameters as in https://github.com/amazon-science/chronos-forecasting/blob/73be25042f5f587823d46106d372ba133152fb00/README.md?plain=1#L62-L65
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
    total_time = perf_counter() - init_time
    return fcsts_df, total_time, model_name


def main(dataset: str, model_name: str):
    """Run the Chronos experiment for one dataset and save its results.

    Builds an ``ExperimentHandler`` for ``dataset``, forecasts with
    ``run_amazon_chronos`` using the handler's ``train_df``, ``horizon``,
    ``freq`` and ``quantiles``, and passes the forecasts, the elapsed time
    and the model name to the handler's ``save_results``.

    Args:
        dataset: Passed to the ``ExperimentHandler`` constructor.
        model_name: Passed to ``run_amazon_chronos`` as ``model_name``.
    """
    experiment = ExperimentHandler(dataset)

    # Gather the forecasting inputs from the experiment handler in one place.
    forecast_kwargs = dict(
        train_df=experiment.train_df,
        model_name=model_name,
        horizon=experiment.horizon,
        freq=experiment.freq,
        quantiles=experiment.quantiles,
    )
    outcome = run_amazon_chronos(**forecast_kwargs)

    forecasts, elapsed, reported_model = outcome
    experiment.save_results(forecasts, elapsed, reported_model)


# Script entry point: python-fire turns `main` into a command-line interface,
# so `dataset` and `model_name` are supplied as CLI arguments.
if __name__ == "__main__":
    fire.Fire(main)