"vscode:/vscode.git/clone" did not exist on "4adb38a40c30f58abd35e6f8488c35a15b63511d"
main.py 2.19 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import logging
import subprocess
import sys

import fire
import pandas as pd

from src.utils import ExperimentHandler

# Module-level logger; its own threshold is INFO so the progress messages
# emitted by main() are not filtered out at the logger itself.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)

# Benchmark datasets, named "<collection>_<frequency>", in processing order.
datasets = [
    f"{collection}_{frequency}"
    for collection, frequencies in (
        ("m1", ("yearly", "quarterly", "monthly")),
        ("m3", ("yearly", "quarterly", "monthly", "other")),
        ("tourism", ("yearly", "quarterly", "monthly")),
        ("m4", ("yearly", "quarterly")),
    )
    for frequency in frequencies
]

# Chronos T5 checkpoints, passed one at a time as --model_name to the Chronos
# pipeline (and evaluated alongside the statistical baselines), in run order.
amazon_chronos_models = [
    f"amazon/chronos-t5-{size}" for size in ("large", "tiny", "mini", "small", "base")
]


def _run_module(module: str, *args: str) -> None:
    """Run ``<interpreter> -m <module> <args...>`` in a child process.

    ``sys.executable`` is used instead of a bare ``"python"`` so the child runs
    under the same interpreter (and virtualenv) as this script. A non-zero exit
    status is logged but does not abort the run, so one failing pipeline does
    not prevent the remaining datasets/models from being processed.
    """
    cmd = [sys.executable, "-m", module, *args]
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        logger.error("Command %s failed with exit code %d", cmd, completed.returncode)


def _run_forecasts(mode: str) -> None:
    """Launch the forecasting pipeline selected by *mode* for every dataset."""
    for dataset in datasets:
        logger.info("Evaluating %s...", dataset)
        if mode == "fcst_statsforecast":
            logger.info("Running StatisticalEnsemble")
            _run_module("src.statsforecast_pipeline", "--dataset", dataset)
        else:  # "fcst_chronos": one child process per Chronos checkpoint
            for model in amazon_chronos_models:
                logger.info("Running Amazon Chronos %s", model)
                _run_module(
                    "src.amazon_chronos.pipeline",
                    "--dataset",
                    dataset,
                    "--model_name",
                    model,
                )


def _evaluate_all() -> None:
    """Evaluate every model on every dataset and save the concatenated table."""
    results = []
    for dataset in datasets:
        logger.info("Evaluating %s...", dataset)
        logger.info("Running dataset evaluation")
        exp = ExperimentHandler(dataset)
        try:
            eval_dataset_df = exp.evaluate_models(
                amazon_chronos_models + ["StatisticalEnsemble", "SeasonalNaive"]
            )
        except Exception:
            # Best effort: skip this dataset, but keep the traceback in the log.
            logger.exception("Evaluation failed for dataset %s", dataset)
            continue
        print(eval_dataset_df)
        results.append(eval_dataset_df)
    if not results:
        # pd.concat([]) raises ValueError; report the situation clearly instead.
        logger.error("No dataset was evaluated successfully; nothing to save")
        return
    eval_df = pd.concat(results).reset_index(drop=True)
    # Saved through the handler of the last dataset processed, as before;
    # presumably save_dataframe writes to a location shared by all
    # datasets -- TODO confirm against ExperimentHandler.
    exp.save_dataframe(eval_df, "complete-results.csv")


def main(mode: str):
    """Run one stage of the benchmark over every dataset in ``datasets``.

    Args:
        mode: ``"fcst_statsforecast"`` runs ``src.statsforecast_pipeline`` per
            dataset; ``"fcst_chronos"`` runs ``src.amazon_chronos.pipeline``
            per dataset and per model in ``amazon_chronos_models``;
            ``"evaluation"`` evaluates all models per dataset via
            ``ExperimentHandler`` and saves the combined table as
            ``complete-results.csv``.

    Raises:
        ValueError: if *mode* is not one of the values listed above.
    """
    if mode in ("fcst_statsforecast", "fcst_chronos"):
        _run_forecasts(mode)
    elif mode == "evaluation":
        _evaluate_all()
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'fcst_statsforecast', "
            "'fcst_chronos' or 'evaluation'"
        )


if __name__ == "__main__":
    # Nothing in this file attaches a handler to the module logger, and
    # Python's fallback ("last resort") handler only emits WARNING and above,
    # so without this the INFO progress messages from main() are never shown.
    # basicConfig() gives the root logger a stream handler with no level
    # threshold; our logger's INFO records propagate to it, while other
    # libraries stay at the root's default WARNING level. It is a no-op if
    # logging was already configured elsewhere.
    logging.basicConfig()
    fire.Fire(main)