# tsmixer.py (v1.0, committed by chenzk)
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.tsmixer.ipynb.

# %% auto 0
__all__ = ['TSMixer']

# %% ../../nbs/models.tsmixer.ipynb 5
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..losses.pytorch import MAE
from ..common._base_multivariate import BaseMultivariate

# %% ../../nbs/models.tsmixer.ipynb 8
class TemporalMixing(nn.Module):
    """Residual block that mixes information along the time axis.

    The input is laid out as [B, L, N] (batch, input_size, n_series). It is
    batch-normalised over the flattened N * L features, passed through a
    linear layer acting on the L axis followed by a ReLU, regularised with
    dropout and finally added back to the untouched input.
    """

    def __init__(self, n_series, input_size, dropout):
        super().__init__()
        flattened_features = n_series * input_size
        self.temporal_norm = nn.BatchNorm1d(
            num_features=flattened_features, eps=0.001, momentum=0.01
        )
        self.temporal_lin = nn.Linear(input_size, input_size)
        self.temporal_drop = nn.Dropout(dropout)

    def forward(self, input):
        # Sizes are read from the tensor itself: [B, L, N]
        n_batch = input.size(0)
        n_steps = input.size(1)
        n_channels = input.size(2)

        # Put time last and flatten so BatchNorm1d sees one feature per
        # (series, step) pair: [B, L, N] -> [B, N, L] -> [B, N * L]
        flat = torch.permute(input, (0, 2, 1)).reshape(n_batch, -1)
        normed = self.temporal_norm(flat)

        # Restore [B, N, L] and mix along the time axis
        per_series = normed.reshape(n_batch, n_channels, n_steps)
        mixed = torch.relu(self.temporal_lin(per_series))

        # Back to [B, L, N] before dropout and the residual connection
        mixed = torch.permute(mixed, (0, 2, 1))
        mixed = self.temporal_drop(mixed)

        return mixed + input


class FeatureMixing(nn.Module):
    """Residual block that mixes information across the series axis.

    The input is laid out as [B, L, N] (batch, input_size, n_series). It is
    batch-normalised over the flattened L * N features and then passed
    through a two-layer MLP on the N axis (N -> ff_dim -> N) with a ReLU
    after the first layer and dropout after each layer. The result is added
    back to the untouched input.
    """

    def __init__(self, n_series, input_size, dropout, ff_dim):
        super().__init__()
        flattened_features = n_series * input_size
        self.feature_norm = nn.BatchNorm1d(
            num_features=flattened_features, eps=0.001, momentum=0.01
        )
        self.feature_lin_1 = nn.Linear(n_series, ff_dim)
        self.feature_lin_2 = nn.Linear(ff_dim, n_series)
        self.feature_drop_1 = nn.Dropout(dropout)
        self.feature_drop_2 = nn.Dropout(dropout)

    def forward(self, input):
        # Sizes are read from the tensor itself: [B, L, N]
        n_batch = input.size(0)
        n_steps = input.size(1)
        n_channels = input.size(2)

        # Flatten so BatchNorm1d sees one feature per (step, series) pair,
        # then restore the layout: [B, L, N] -> [B, L * N] -> [B, L, N]
        normed = self.feature_norm(input.reshape(n_batch, -1))
        normed = normed.reshape(n_batch, n_steps, n_channels)

        # Expand to the hidden width: [B, L, N] -> [B, L, ff_dim]
        hidden = self.feature_drop_1(torch.relu(self.feature_lin_1(normed)))

        # Project back to the series dimension: [B, L, ff_dim] -> [B, L, N]
        mixed = self.feature_drop_2(self.feature_lin_2(hidden))

        return mixed + input


class MixingLayer(nn.Module):
    """One mixing layer: a temporal mixer followed by a feature mixer."""

    def __init__(self, n_series, input_size, dropout, ff_dim):
        super().__init__()
        # The temporal mixer is built first, then the feature mixer, which
        # is also the order in which they are applied in `forward`.
        self.temporal_mixer = TemporalMixing(
            n_series=n_series, input_size=input_size, dropout=dropout
        )
        self.feature_mixer = FeatureMixing(
            n_series=n_series, input_size=input_size, dropout=dropout, ff_dim=ff_dim
        )

    def forward(self, input):
        # Mix along time, then across series
        return self.feature_mixer(self.temporal_mixer(input))

# %% ../../nbs/models.tsmixer.ipynb 10
class ReversibleInstanceNorm1d(nn.Module):
    """Instance normalisation over dim 1 with a learnable affine and an inverse.

    `forward` standardises `x` using the mean and standard deviation taken
    along dim 1, then applies a learnable scale and shift of shape
    [1, 1, n_series]. The statistics are stored (detached) on the module so
    that `reverse` can undo the transformation later.
    """

    def __init__(self, n_series, eps=1e-5):
        super().__init__()
        affine_shape = (1, 1, n_series)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

        self.eps = eps

    def forward(self, x):
        # Statistics along dim 1; detached so no gradient flows through them
        mean = x.mean(dim=1, keepdim=True)
        variance = x.var(dim=1, keepdim=True, unbiased=False)
        self.batch_mean = mean.detach()
        self.batch_std = (variance + self.eps).sqrt().detach()

        # Standardise, then apply the learnable scale and shift
        standardised = (x - self.batch_mean) / self.batch_std
        return standardised * self.weight + self.bias

    def reverse(self, x):
        # Undo the affine first, then restore the stored scale and location.
        # Relies on the statistics saved by the most recent `forward` call.
        standardised = (x - self.bias) / self.weight
        return standardised * self.batch_std + self.batch_mean

# %% ../../nbs/models.tsmixer.ipynb 12
class TSMixer(BaseMultivariate):
    """TSMixer

    Time-Series Mixer (`TSMixer`) is an MLP-based multivariate time-series forecasting model. It learns temporal and cross-sectional representations jointly by passing the input through a stack of mixing layers, each of which applies a time-mixing Multi Layer Perceptron (`MLP`) followed by a feature-mixing `MLP`.

    **Parameters:**<br>
    `h`: int, forecast horizon.<br>
    `input_size`: int, considered autoregressive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>
    `n_series`: int, number of time-series.<br>
    `futr_exog_list`: str list, future exogenous columns. Not supported by `TSMixer`; must be None.<br>
    `hist_exog_list`: str list, historic exogenous columns. Not supported by `TSMixer`; must be None.<br>
    `stat_exog_list`: str list, static exogenous columns. Not supported by `TSMixer`; must be None.<br>
    `n_block`: int=2, number of mixing layers in the model.<br>
    `ff_dim`: int=64, number of hidden units of the feed-forward layers in the feature MLP.<br>
    `dropout`: float=0.9, dropout rate between (0, 1).<br>
    `revin`: bool=True, if True uses Reversible Instance Normalization to process inputs and outputs.<br>
    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
    `max_steps`: int=1000, maximum number of training steps.<br>
    `learning_rate`: float=1e-3, learning rate between (0, 1).<br>
    `num_lr_decays`: int=-1, number of learning rate decays, evenly distributed across max_steps.<br>
    `early_stop_patience_steps`: int=-1, number of validation iterations before early stopping.<br>
    `val_check_steps`: int=100, number of training steps between every validation loss check.<br>
    `batch_size`: int=32, number of different series in each batch.<br>
    `step_size`: int=1, step size between each window of temporal data.<br>
    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>
    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
    `alias`: str, optional, custom name of the model.<br>
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>
    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>
    `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>

    **References:**<br>
    - [Chen, Si-An, Chun-Liang Li, Nate Yoder, Sercan O. Arik, and Tomas Pfister (2023). "TSMixer: An All-MLP Architecture for Time Series Forecasting."](http://arxiv.org/abs/2303.06053)

    """

    # Class attributes
    SAMPLING_TYPE = "multivariate"

    def __init__(
        self,
        h,
        input_size,
        n_series,
        futr_exog_list=None,
        hist_exog_list=None,
        stat_exog_list=None,
        n_block=2,
        ff_dim=64,
        dropout=0.9,
        revin=True,
        loss=MAE(),
        valid_loss=None,
        max_steps: int = 1000,
        learning_rate: float = 1e-3,
        num_lr_decays: int = -1,
        early_stop_patience_steps: int = -1,
        val_check_steps: int = 100,
        batch_size: int = 32,
        step_size: int = 1,
        scaler_type: str = "identity",
        random_seed: int = 1,
        num_workers_loader: int = 0,
        drop_last_loader: bool = False,
        optimizer=None,
        optimizer_kwargs=None,
        **trainer_kwargs
    ):

        # Let the multivariate base class handle the shared configuration
        super().__init__(
            h=h,
            input_size=input_size,
            n_series=n_series,
            futr_exog_list=futr_exog_list,
            hist_exog_list=hist_exog_list,
            stat_exog_list=stat_exog_list,
            loss=loss,
            valid_loss=valid_loss,
            max_steps=max_steps,
            learning_rate=learning_rate,
            num_lr_decays=num_lr_decays,
            early_stop_patience_steps=early_stop_patience_steps,
            val_check_steps=val_check_steps,
            batch_size=batch_size,
            step_size=step_size,
            scaler_type=scaler_type,
            random_seed=random_seed,
            num_workers_loader=num_workers_loader,
            drop_last_loader=drop_last_loader,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            **trainer_kwargs
        )

        # Exogenous inputs are rejected; checked in static, future,
        # historical order so the first offending kind is the one reported.
        unsupported_exog = (
            ("static", stat_exog_list),
            ("future", futr_exog_list),
            ("historical", hist_exog_list),
        )
        for kind, exog_list in unsupported_exog:
            if exog_list is not None:
                raise Exception(
                    f"TSMixer does not support {kind} exogenous variables. "
                    f"Use TSMixerx if you want to use {kind} exogenous variables."
                )

        # Optional reversible instance normalization of inputs and outputs
        self.revin = revin
        if self.revin:
            self.norm = ReversibleInstanceNorm1d(n_series=n_series)

        # Stack of `n_block` identical mixing layers
        self.mixing_layers = nn.Sequential(
            *(
                MixingLayer(
                    n_series=n_series,
                    input_size=input_size,
                    dropout=dropout,
                    ff_dim=ff_dim,
                )
                for _ in range(n_block)
            )
        )

        # Output projection along time; its width depends on the loss
        self.out = nn.Linear(
            in_features=input_size, out_features=h * self.loss.outputsize_multiplier
        )

    def forward(self, windows_batch):
        # insample_y: [batch_size, input_size, n_series]
        insample_y = windows_batch["insample_y"]
        batch_size = insample_y.shape[0]

        # InstanceNorm (optional) + Mixing layers
        hidden = self.norm(insample_y) if self.revin else insample_y
        hidden = self.mixing_layers(hidden)

        # The dense output layer acts on the time axis, so move it last,
        # project, and move it back to dim 1.
        projected = self.out(hidden.permute(0, 2, 1)).permute(0, 2, 1)

        # ReverseInstanceNorm (optional)
        if self.revin:
            projected = self.norm.reverse(projected)

        projected = projected.reshape(
            batch_size, self.h, self.loss.outputsize_multiplier * self.n_series
        )
        forecast = self.loss.domain_map(projected)

        # domain_map might have squeezed the last dimension in case n_series == 1
        # Note that this fails in case of a tuple loss, but Multivariate does not support tuple losses yet.
        return forecast.unsqueeze(-1) if forecast.ndim == 2 else forecast