lagllama.py 1.47 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from gluonts.torch.model.predictor import PyTorchPredictor
from lag_llama.gluon.estimator import LagLlamaEstimator

from ..utils.gluonts_forecaster import GluonTSForecaster


class LagLlama(GluonTSForecaster):
    """Forecaster that builds a GluonTS predictor from a Lag-Llama checkpoint.

    The checkpoint location (``repo_id`` / ``filename``) and the ``alias`` are
    forwarded unchanged to ``GluonTSForecaster``; this class only adds the
    logic that turns the loaded checkpoint into a ``PyTorchPredictor``.

    Args:
        repo_id: Repository identifier of the checkpoint, passed to the base
            class (presumably a Hugging Face Hub repo id -- confirm against
            ``GluonTSForecaster``).
        filename: Name of the checkpoint file, passed to the base class.
        alias: Name given to this forecaster, passed to the base class.
        context_length: Number of past time steps handed to the estimator as
            its context window. Defaults to 32, the value reported in the
            Lag-Llama paper.

    Raises:
        ValueError: If ``context_length`` is not a positive integer.
    """

    def __init__(
        self,
        repo_id: str = "time-series-foundation-models/Lag-Llama",
        filename: str = "lag-llama.ckpt",
        alias: str = "LagLlama",
        context_length: int = 32,
    ):
        # Fail early, before the base class does any work for the checkpoint.
        if context_length <= 0:
            raise ValueError(
                f"context_length must be a positive integer, got {context_length!r}"
            )
        super().__init__(
            repo_id=repo_id,
            filename=filename,
            alias=alias,
        )
        self.context_length = context_length

    def get_predictor(self, prediction_length: int) -> PyTorchPredictor:
        """Create a predictor for the given forecast horizon.

        Args:
            prediction_length: Number of future time steps to forecast.

        Returns:
            The predictor produced by ``LagLlamaEstimator.create_predictor``.
        """
        ckpt = self.load()
        # Architecture hyper-parameters are read from the checkpoint so the
        # estimator is configured with the values stored alongside the weights.
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        estimator = LagLlamaEstimator(
            ckpt_path=self.checkpoint_path,
            prediction_length=prediction_length,
            # configurable via the constructor; the default of 32 is the
            # context length reported in the paper
            context_length=self.context_length,
            # estimator args
            input_size=estimator_args["input_size"],
            n_layer=estimator_args["n_layer"],
            n_embd_per_head=estimator_args["n_embd_per_head"],
            n_head=estimator_args["n_head"],
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
        )
        lightning_module = estimator.create_lightning_module()
        transformation = estimator.create_transformation()
        predictor = estimator.create_predictor(transformation, lightning_module)
        return predictor