moirai.py 1006 Bytes
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from gluonts.torch.model.predictor import PyTorchPredictor
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule

from ..utils.gluonts_forecaster import GluonTSForecaster


class Moirai(GluonTSForecaster):
    """Forecaster that exposes a pretrained Moirai model as a GluonTS predictor.

    The constructor forwards ``repo_id``, ``filename`` and ``alias`` to
    ``GluonTSForecaster``; ``get_predictor`` then loads the Moirai weights
    identified by ``self.repo_id`` and wraps them in a ``MoiraiForecast``.

    NOTE(review): ``filename`` is only passed to the base class; nothing in
    this class reads it. ``self.repo_id`` is presumably set by the base
    class constructor -- confirm against ``GluonTSForecaster``.

    Args:
        repo_id: Identifier handed to ``MoiraiModule.from_pretrained``.
        filename: Checkpoint file name forwarded to the base class.
        alias: Name forwarded to the base class.
        context_length: Value passed as ``context_length`` to
            ``MoiraiForecast``. Defaults to 200, the value that was
            previously hard-coded.
        num_samples: Value passed as ``num_samples`` to ``MoiraiForecast``.
            Defaults to 100, the value that was previously hard-coded.
        batch_size: Batch size passed to ``create_predictor``. Defaults to
            32, the value that was previously hard-coded.
    """

    def __init__(
        self,
        repo_id: str = "Salesforce/moirai-1.0-R-large",
        filename: str = "model.ckpt",
        alias: str = "Moirai",
        *,
        context_length: int = 200,
        num_samples: int = 100,
        batch_size: int = 32,
    ) -> None:
        super().__init__(
            repo_id=repo_id,
            filename=filename,
            alias=alias,
        )
        # Keyword-only tuning knobs: existing positional callers are
        # unaffected, and the defaults reproduce the former constants.
        self.context_length = context_length
        self.num_samples = num_samples
        self.batch_size = batch_size

    def get_predictor(self, prediction_length: int) -> PyTorchPredictor:
        """Build a predictor for the requested forecast horizon.

        Args:
            prediction_length: Value passed as ``prediction_length`` to
                ``MoiraiForecast``.

        Returns:
            The object returned by ``MoiraiForecast.create_predictor``
            using the configured batch size.
        """
        model = MoiraiForecast(
            module=MoiraiModule.from_pretrained(self.repo_id),
            prediction_length=prediction_length,
            context_length=self.context_length,
            patch_size="auto",
            num_samples=self.num_samples,
            # Fixed configuration: a target dimension of 1 and no dynamic
            # real-valued features (current or past).
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )
        return model.create_predictor(batch_size=self.batch_size)