prophet.py 1.44 KB
Newer Older
bailuo's avatar
readme  
bailuo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from copy import deepcopy
from typing import List
from threadpoolctl import threadpool_limits

import pandas as pd
from prophet import Prophet

from ..utils.parallel_forecaster import ParallelForecaster
from ..utils.forecaster import Forecaster


class NixtlaProphet(Prophet, ParallelForecaster, Forecaster):
    """Prophet model adapted to the project's ``Forecaster`` interface.

    The instance acts as an unfitted template: each call to
    ``_local_forecast`` fits a deep copy of it on one series and returns the
    point forecast in a column named after ``alias``.
    """

    def __init__(
        self,
        alias: str = "Prophet",
        *args,
        **kwargs,
    ):
        """Create the template model.

        Args:
            alias: Name of the forecast column in the frames returned by
                ``_local_forecast``.
            *args: Forwarded unchanged to ``Prophet.__init__``.
            **kwargs: Forwarded unchanged to ``Prophet.__init__``.

        Note:
            ``alias`` is the first positional parameter, so a leading
            positional argument is taken as the alias, not as a Prophet
            option. Pass Prophet options by keyword, e.g.
            ``NixtlaProphet(growth="flat")``.
        """
        super().__init__(*args, **kwargs)
        self.alias = alias

    def __local_forecast(
        self,
        df: pd.DataFrame,
        h: int,
        freq: str,
        quantiles: List[float] | None = None,
    ) -> pd.DataFrame:
        """Fit a copy of the model on ``df`` and forecast ``h`` periods ahead.

        Args:
            df: Training data passed straight to ``Prophet.fit`` (Prophet
                expects ``ds`` and ``y`` columns).
            h: Number of future periods to forecast.
            freq: Pandas frequency string used to generate the future dates.
            quantiles: Not supported; must be ``None``.

        Returns:
            DataFrame with columns ``ds`` and ``self.alias`` (Prophet's
            ``yhat`` point forecast), one row per future period.

        Raises:
            NotImplementedError: If ``quantiles`` is not ``None``.
        """
        if quantiles is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support quantile forecasts; "
                "pass quantiles=None."
            )
        # Fit a deep copy so the template itself is never fitted and can be
        # reused for every call (a Prophet object can only be fit once).
        model = deepcopy(self)
        model.fit(df=df)
        # include_history=False: only the h future dates, no in-sample rows.
        future_df = model.make_future_dataframe(
            periods=h,
            include_history=False,
            freq=freq,
        )
        fcst_df = model.predict(future_df)
        # Select before renaming: if ``alias`` coincides with another Prophet
        # output column (e.g. "trend"), renaming the full frame first would
        # create duplicate column labels and the selection would return both.
        return fcst_df[["ds", "yhat"]].rename(columns={"yhat": self.alias})

    def _local_forecast(
        self,
        df: pd.DataFrame,
        h: int,
        freq: str,
        quantiles: List[float] | None = None,
    ) -> pd.DataFrame:
        """Forecast one series with native thread pools capped at one thread.

        Delegates to ``__local_forecast`` inside ``threadpool_limits(limits=1)``,
        which limits native (BLAS/OpenMP) thread pools for the duration of
        the call. This presumably avoids CPU oversubscription when
        ``ParallelForecaster`` runs several series concurrently; TODO: confirm.

        Args:
            df: Training data for a single series (see ``__local_forecast``).
            h: Number of future periods to forecast.
            freq: Pandas frequency string for the future dates.
            quantiles: Not supported; must be ``None``.

        Returns:
            DataFrame with columns ``ds`` and ``self.alias``.

        Raises:
            NotImplementedError: If ``quantiles`` is not ``None``.
        """
        with threadpool_limits(limits=1):
            return self.__local_forecast(
                df=df,
                h=h,
                freq=freq,
                quantiles=quantiles,
            )