lag_llama_pipeline.py 3.27 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from time import time
from typing import Iterable, List, Tuple

import fire
import pandas as pd
import torch
from gluonts.dataset import Dataset
from gluonts.model.forecast import Forecast
from gluonts.torch.model.predictor import PyTorchPredictor
from tqdm import tqdm

from lag_llama.gluon.estimator import LagLlamaEstimator
from src.utils import ExperimentHandler


def get_lag_llama_predictor(
    prediction_length: int, models_dir: str
) -> PyTorchPredictor:
    """Build a predictor from the Lag-Llama checkpoint stored in ``models_dir``.

    The checkpoint is read from ``{models_dir}/lag-llama.ckpt`` and the
    architecture settings saved under
    ``["hyper_parameters"]["model_kwargs"]`` are forwarded to
    ``LagLlamaEstimator``.

    Args:
        prediction_length: Forecast horizon handed to the estimator.
        models_dir: Directory that contains ``lag-llama.ckpt``.

    Returns:
        The predictor the estimator creates from its transformation and
        lightning module.

    Raises:
        ValueError: If CUDA is not available; running on CPU is rejected.
    """
    checkpoint_file = f"{models_dir}/lag-llama.ckpt"

    # Without CUDA the only remaining device would be the CPU, which this
    # pipeline refuses to use.
    if not torch.cuda.is_available():
        raise ValueError("cpu is not supported in lagllama (there is a bug)")
    device = torch.device("cuda:0")

    checkpoint = torch.load(checkpoint_file, map_location=device)
    model_kwargs = checkpoint["hyper_parameters"]["model_kwargs"]

    # Architecture settings are copied verbatim from the checkpoint.
    forwarded_keys = (
        "input_size",
        "n_layer",
        "n_embd_per_head",
        "n_head",
        "scaling",
        "time_feat",
    )
    architecture = {key: model_kwargs[key] for key in forwarded_keys}

    estimator = LagLlamaEstimator(
        ckpt_path=checkpoint_file,
        prediction_length=prediction_length,
        # this context length is reported in the paper
        context_length=32,
        **architecture,
    )
    module = estimator.create_lightning_module()
    transform = estimator.create_transformation()
    return estimator.create_predictor(transform, module)


def gluonts_instance_fcst_to_df(
    fcst: Forecast,
    quantiles: List[float],
    model_name: str,
) -> pd.DataFrame:
    """Convert a single forecast into a data frame with one row per step.

    Args:
        fcst: Forecast of one series; its ``mean``, ``start_date``, ``freq``,
            ``item_id`` and ``quantile`` members are read.
        quantiles: Quantile levels to add as extra columns.
        model_name: Name of the point-forecast column and prefix of the
            quantile columns (``"{model_name}-q-{q}"``).

    Returns:
        A frame with columns ``ds``, ``unique_id``, ``model_name`` and one
        column per requested quantile, in that order.
    """
    mean_path = fcst.mean
    n_steps = len(mean_path)
    first_stamp = fcst.start_date.to_timestamp()
    # One timestamp per forecast step, starting at the forecast start date.
    stamps = pd.date_range(first_stamp, freq=fcst.freq, periods=n_steps)

    base_columns = {
        "ds": stamps,
        "unique_id": fcst.item_id,
        model_name: mean_path,
    }
    frame = pd.DataFrame(base_columns)

    for level in quantiles:
        column = f"{model_name}-q-{level}"
        frame[column] = fcst.quantile(level)
    return frame


def gluonts_fcsts_to_df(
    fcsts: Iterable[Forecast],
    quantiles: List[float],
    model_name: str,
) -> pd.DataFrame:
    """Convert every forecast in ``fcsts`` and stack the frames row-wise.

    Args:
        fcsts: Forecasts to convert; iterated once, with a progress bar.
        quantiles: Quantile levels passed on to the per-forecast conversion.
        model_name: Column name/prefix passed on to the per-forecast
            conversion.

    Returns:
        The concatenation of all per-forecast frames with a fresh
        consecutive row index.
    """
    per_forecast = [
        gluonts_instance_fcst_to_df(forecast, quantiles, model_name)
        for forecast in tqdm(fcsts)
    ]
    stacked = pd.concat(per_forecast)
    # Drop the per-forecast indices so rows are numbered consecutively.
    return stacked.reset_index(drop=True)


def run_lag_llama(
    gluonts_dataset: Dataset,
    horizon: int,
    quantiles: List[float],
    models_dir: str,
    num_samples: int = 100,
) -> Tuple[pd.DataFrame, float, str]:
    """Forecast ``gluonts_dataset`` with Lag-Llama and time the whole run.

    The timed section covers building the predictor, calling ``predict`` and
    converting the forecasts to a data frame. The conversion consumes the
    forecasts, so if ``predict`` yields them lazily the inference time is
    included in the measurement as well.

    Args:
        gluonts_dataset: Dataset handed to the predictor.
        horizon: Prediction length used to build the predictor.
        quantiles: Quantile levels to include in the output frame.
        models_dir: Directory that contains the Lag-Llama checkpoint.
        num_samples: Number of sample paths requested from the predictor for
            each series. Defaults to 100, the value previously hard-coded.

    Returns:
        A tuple ``(forecast frame, elapsed seconds, model name)``; the model
        name is always ``"LagLlama"``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    # Reject a useless sample count before the (expensive) model load.
    if num_samples < 1:
        raise ValueError(
            f"num_samples must be a positive integer, got {num_samples!r}"
        )

    # Imported locally because the module only imports ``time`` from ``time``.
    # ``perf_counter`` is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments the way ``time()`` can.
    from time import perf_counter

    model_name = "LagLlama"
    started = perf_counter()
    predictor = get_lag_llama_predictor(horizon, models_dir)
    fcsts = predictor.predict(gluonts_dataset, num_samples=num_samples)
    fcsts_df = gluonts_fcsts_to_df(
        fcsts,
        quantiles=quantiles,
        model_name=model_name,
    )
    total_time = perf_counter() - started
    return fcsts_df, total_time, model_name


def main(dataset: str):
    """Run the Lag-Llama pipeline for ``dataset`` and save its results.

    An ``ExperimentHandler`` built from ``dataset`` supplies the training
    dataset, horizon, quantiles and models directory, and receives the
    forecast frame, elapsed time and model name through ``_save_results``.

    Args:
        dataset: Dataset identifier passed to ``ExperimentHandler``.
    """
    handler = ExperimentHandler(dataset)
    outcome = run_lag_llama(
        gluonts_dataset=handler.gluonts_train_dataset,
        horizon=handler.horizon,
        quantiles=handler.quantiles,
        models_dir=handler.models_dir,
    )
    forecasts, elapsed, name = outcome
    handler._save_results(forecasts, elapsed, name)


# Script entry point: only when the file is executed directly, hand ``main``
# to python-fire so it is invoked from the command line.
if __name__ == "__main__":
    fire.Fire(main)