# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.tft.ipynb.
# %% auto 0
__all__ = ['TFT']
# %% ../../nbs/models.tft.ipynb 4
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import LayerNorm
from ..losses.pytorch import MAE
from ..common._base_windows import BaseWindows
# %% ../../nbs/models.tft.ipynb 10
class MaybeLayerNorm(nn.Module):
def __init__(self, output_size, hidden_size, eps):
super().__init__()
        if output_size == 1:
self.ln = nn.Identity()
else:
self.ln = LayerNorm(output_size if output_size else hidden_size, eps=eps)
def forward(self, x):
return self.ln(x)
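# Rationale: LayerNorm over a single feature always returns zero, so when
# output_size == 1 an identity is used instead. Illustrative (hypothetical sizes):
#   MaybeLayerNorm(output_size=1, hidden_size=16, eps=1e-3)     # -> nn.Identity
#   MaybeLayerNorm(output_size=None, hidden_size=16, eps=1e-3)  # -> LayerNorm(16)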
class GLU(nn.Module):
def __init__(self, hidden_size, output_size):
super().__init__()
self.lin = nn.Linear(hidden_size, output_size * 2)
def forward(self, x: Tensor) -> Tensor:
x = self.lin(x)
x = F.glu(x)
return x
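# GLU usage sketch (illustrative shapes): the linear layer doubles the feature
# dimension and F.glu halves it back with a sigmoid gate, yielding `output_size`:
#   glu = GLU(hidden_size=8, output_size=4)
#   glu(torch.randn(2, 3, 8)).shape  # -> torch.Size([2, 3, 4])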
class GRN(nn.Module):
def __init__(
self,
input_size,
hidden_size,
output_size=None,
context_hidden_size=None,
dropout=0,
):
super().__init__()
self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)
self.lin_a = nn.Linear(input_size, hidden_size)
if context_hidden_size is not None:
self.lin_c = nn.Linear(context_hidden_size, hidden_size, bias=False)
self.lin_i = nn.Linear(hidden_size, hidden_size)
self.glu = GLU(hidden_size, output_size if output_size else hidden_size)
self.dropout = nn.Dropout(dropout)
self.out_proj = nn.Linear(input_size, output_size) if output_size else None
def forward(self, a: Tensor, c: Optional[Tensor] = None):
x = self.lin_a(a)
if c is not None:
x = x + self.lin_c(c).unsqueeze(1)
x = F.elu(x)
x = self.lin_i(x)
x = self.dropout(x)
x = self.glu(x)
        y = a if self.out_proj is None else self.out_proj(a)
x = x + y
x = self.layer_norm(x)
return x
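# GRN usage sketch (illustrative shapes): the optional context `c` is projected
# and broadcast over time, and the residual path projects `a` whenever
# output_size differs from input_size:
#   grn = GRN(input_size=16, hidden_size=16, output_size=8, context_hidden_size=16)
#   a = torch.randn(4, 10, 16)  # [batch, time, features]
#   c = torch.randn(4, 16)      # static context
#   grn(a, c).shape             # -> torch.Size([4, 10, 8])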
# %% ../../nbs/models.tft.ipynb 13
class TFTEmbedding(nn.Module):
def __init__(
self, hidden_size, stat_input_size, futr_input_size, hist_input_size, tgt_size
):
super().__init__()
# There are 4 types of input:
# 1. Static continuous
# 2. Temporal known a priori continuous
# 3. Temporal observed continuous
        # 4. Temporal observed targets (time series observed so far)
self.hidden_size = hidden_size
self.stat_input_size = stat_input_size
self.futr_input_size = futr_input_size
self.hist_input_size = hist_input_size
self.tgt_size = tgt_size
# Instantiate Continuous Embeddings if size is not None
for attr, size in [
("stat_exog_embedding", stat_input_size),
("futr_exog_embedding", futr_input_size),
("hist_exog_embedding", hist_input_size),
("tgt_embedding", tgt_size),
]:
if size:
vectors = nn.Parameter(torch.Tensor(size, hidden_size))
bias = nn.Parameter(torch.zeros(size, hidden_size))
torch.nn.init.xavier_normal_(vectors)
setattr(self, attr + "_vectors", vectors)
setattr(self, attr + "_bias", bias)
else:
setattr(self, attr + "_vectors", None)
setattr(self, attr + "_bias", None)
def _apply_embedding(
self,
cont: Optional[Tensor],
cont_emb: Tensor,
cont_bias: Tensor,
):
if cont is not None:
            # The line below is equivalent to the following einsums:
            # e_cont = torch.einsum('btf,fh->btfh', cont, cont_emb)  # temporal
            # e_cont = torch.einsum('bf,fh->bfh', cont, cont_emb)    # static
e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)
e_cont = e_cont + cont_bias
return e_cont
return None
def forward(self, target_inp, stat_exog=None, futr_exog=None, hist_exog=None):
        # Temporal/static continuous known/observed inputs;
        # embeddings whose corresponding input is None return None.
        # Static inputs are expected to be equal for all timesteps;
        # for memory efficiency there is no assert statement.
stat_exog = stat_exog[:, :] if stat_exog is not None else None
s_inp = self._apply_embedding(
cont=stat_exog,
cont_emb=self.stat_exog_embedding_vectors,
cont_bias=self.stat_exog_embedding_bias,
)
k_inp = self._apply_embedding(
cont=futr_exog,
cont_emb=self.futr_exog_embedding_vectors,
cont_bias=self.futr_exog_embedding_bias,
)
o_inp = self._apply_embedding(
cont=hist_exog,
cont_emb=self.hist_exog_embedding_vectors,
cont_bias=self.hist_exog_embedding_bias,
)
# Temporal observed targets
# t_observed_tgt = torch.einsum('btf,fh->btfh',
# target_inp, self.tgt_embedding_vectors)
target_inp = torch.matmul(
target_inp.unsqueeze(3).unsqueeze(4),
self.tgt_embedding_vectors.unsqueeze(1),
).squeeze(3)
target_inp = target_inp + self.tgt_embedding_bias
return s_inp, k_inp, o_inp, target_inp
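# TFTEmbedding sketch (illustrative shapes): each continuous feature f is mapped
# to a learned hidden_size-dim embedding vector scaled by the feature's value:
#   emb = TFTEmbedding(hidden_size=16, stat_input_size=3, futr_input_size=2,
#                      hist_input_size=4, tgt_size=1)
#   s, k, o, t = emb(target_inp=torch.randn(8, 36, 1),
#                    stat_exog=torch.randn(8, 3),
#                    futr_exog=torch.randn(8, 36, 2),
#                    hist_exog=torch.randn(8, 24, 4))
#   # s: [8,3,16], k: [8,36,2,16], o: [8,24,4,16], t: [8,36,1,16]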
class VariableSelectionNetwork(nn.Module):
def __init__(self, hidden_size, num_inputs, dropout):
super().__init__()
self.joint_grn = GRN(
input_size=hidden_size * num_inputs,
hidden_size=hidden_size,
output_size=num_inputs,
context_hidden_size=hidden_size,
)
self.var_grns = nn.ModuleList(
[
GRN(input_size=hidden_size, hidden_size=hidden_size, dropout=dropout)
for _ in range(num_inputs)
]
)
def forward(self, x: Tensor, context: Optional[Tensor] = None):
Xi = x.reshape(*x.shape[:-2], -1)
grn_outputs = self.joint_grn(Xi, c=context)
sparse_weights = F.softmax(grn_outputs, dim=-1)
transformed_embed_list = [m(x[..., i, :]) for i, m in enumerate(self.var_grns)]
transformed_embed = torch.stack(transformed_embed_list, dim=-1)
# the line below performs batched matrix vector multiplication
# for temporal features it's bthf,btf->bth
# for static features it's bhf,bf->bh
variable_ctx = torch.matmul(
transformed_embed, sparse_weights.unsqueeze(-1)
).squeeze(-1)
return variable_ctx, sparse_weights
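# VariableSelectionNetwork sketch (illustrative shapes): the joint GRN scores
# all variables at once (softmax weights), each variable is transformed by its
# own GRN, and the weighted sum collapses the variable axis:
#   vsn = VariableSelectionNetwork(hidden_size=16, num_inputs=3, dropout=0.1)
#   x = torch.randn(8, 24, 3, 16)  # [batch, time, num_inputs, hidden]
#   ctx, weights = vsn(x, context=torch.zeros(8, 16))
#   # ctx: [8, 24, 16], weights: [8, 24, 3]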
# %% ../../nbs/models.tft.ipynb 15
class InterpretableMultiHeadAttention(nn.Module):
def __init__(self, n_head, hidden_size, example_length, attn_dropout, dropout):
super().__init__()
self.n_head = n_head
assert hidden_size % n_head == 0
self.d_head = hidden_size // n_head
self.qkv_linears = nn.Linear(
hidden_size, (2 * self.n_head + 1) * self.d_head, bias=False
)
self.out_proj = nn.Linear(self.d_head, hidden_size, bias=False)
self.attn_dropout = nn.Dropout(attn_dropout)
self.out_dropout = nn.Dropout(dropout)
self.scale = self.d_head**-0.5
self.register_buffer(
"_mask",
torch.triu(
torch.full((example_length, example_length), float("-inf")), 1
).unsqueeze(0),
)
def forward(
self, x: Tensor, mask_future_timesteps: bool = True
) -> Tuple[Tensor, Tensor]:
# [Batch,Time,MultiHead,AttDim] := [N,T,M,AD]
bs, t, h_size = x.shape
qkv = self.qkv_linears(x)
q, k, v = qkv.split(
(self.n_head * self.d_head, self.n_head * self.d_head, self.d_head), dim=-1
)
q = q.view(bs, t, self.n_head, self.d_head)
k = k.view(bs, t, self.n_head, self.d_head)
v = v.view(bs, t, self.d_head)
# [N,T1,M,Ad] x [N,T2,M,Ad] -> [N,M,T1,T2]
# attn_score = torch.einsum('bind,bjnd->bnij', q, k)
attn_score = torch.matmul(q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1)))
attn_score.mul_(self.scale)
if mask_future_timesteps:
attn_score = attn_score + self._mask
attn_prob = F.softmax(attn_score, dim=3)
attn_prob = self.attn_dropout(attn_prob)
        # [N,M,T1,T2] x [N,1,T2,Ad] -> [N,M,T1,Ad]
# attn_vec = torch.einsum('bnij,bjd->bnid', attn_prob, v)
attn_vec = torch.matmul(attn_prob, v.unsqueeze(1))
m_attn_vec = torch.mean(attn_vec, dim=1)
out = self.out_proj(m_attn_vec)
out = self.out_dropout(out)
return out, attn_vec
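# Design note: unlike standard multi-head attention, all heads share a single
# value projection of width d_head, and the head outputs are averaged before one
# shared out_proj, keeping per-head attention weights directly comparable
# (the "interpretable" variant of Lim et al.). Illustrative (hypothetical sizes):
#   attn = InterpretableMultiHeadAttention(n_head=4, hidden_size=16,
#                                          example_length=12,
#                                          attn_dropout=0.0, dropout=0.0)
#   out, attn_vec = attn(torch.randn(2, 12, 16))
#   # out: [2, 12, 16], attn_vec: [2, 4, 12, 4]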
# %% ../../nbs/models.tft.ipynb 18
class StaticCovariateEncoder(nn.Module):
def __init__(self, hidden_size, num_static_vars, dropout):
super().__init__()
self.vsn = VariableSelectionNetwork(
hidden_size=hidden_size, num_inputs=num_static_vars, dropout=dropout
)
self.context_grns = nn.ModuleList(
[
GRN(input_size=hidden_size, hidden_size=hidden_size, dropout=dropout)
for _ in range(4)
]
)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
variable_ctx, sparse_weights = self.vsn(x)
        # Context vectors:
        #   cs: variable selection context
        #   ce: enrichment context
        #   ch: LSTM initial hidden state context
        #   cc: LSTM initial cell state context
cs, ce, ch, cc = tuple(m(variable_ctx) for m in self.context_grns)
return cs, ce, ch, cc
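# StaticCovariateEncoder sketch (illustrative shapes): embedded static features
# are fused by a VSN, then four GRNs derive the downstream context vectors:
#   enc = StaticCovariateEncoder(hidden_size=16, num_static_vars=3, dropout=0.1)
#   cs, ce, ch, cc = enc(torch.randn(8, 3, 16))  # each: [8, 16]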
# %% ../../nbs/models.tft.ipynb 20
class TemporalCovariateEncoder(nn.Module):
def __init__(self, hidden_size, num_historic_vars, num_future_vars, dropout):
super(TemporalCovariateEncoder, self).__init__()
self.history_vsn = VariableSelectionNetwork(
hidden_size=hidden_size, num_inputs=num_historic_vars, dropout=dropout
)
self.history_encoder = nn.LSTM(
input_size=hidden_size, hidden_size=hidden_size, batch_first=True
)
self.future_vsn = VariableSelectionNetwork(
hidden_size=hidden_size, num_inputs=num_future_vars, dropout=dropout
)
self.future_encoder = nn.LSTM(
input_size=hidden_size, hidden_size=hidden_size, batch_first=True
)
# Shared Gated-Skip Connection
self.input_gate = GLU(hidden_size, hidden_size)
self.input_gate_ln = LayerNorm(hidden_size, eps=1e-3)
def forward(self, historical_inputs, future_inputs, cs, ch, cc):
        # [N,L,num_vars,hidden_size] -> [N,L,hidden_size]
historical_features, _ = self.history_vsn(historical_inputs, cs)
history, state = self.history_encoder(historical_features, (ch, cc))
future_features, _ = self.future_vsn(future_inputs, cs)
future, _ = self.future_encoder(future_features, state)
        # torch.cuda.synchronize()  # this call gives a perf boost for unknown reasons
input_embedding = torch.cat([historical_features, future_features], dim=1)
temporal_features = torch.cat([history, future], dim=1)
temporal_features = self.input_gate(temporal_features)
temporal_features = temporal_features + input_embedding
temporal_features = self.input_gate_ln(temporal_features)
return temporal_features
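# TemporalCovariateEncoder sketch (illustrative shapes; B=batch, L=24 encoder
# steps, H=12 future steps): the future LSTM is initialized with the final
# state of the history LSTM, and a shared gated skip rejoins the VSN outputs:
#   enc = TemporalCovariateEncoder(hidden_size=16, num_historic_vars=3,
#                                  num_future_vars=1, dropout=0.1)
#   hist = torch.randn(8, 24, 3, 16)  # [B, L, num_historic_vars, hidden]
#   futr = torch.randn(8, 12, 1, 16)  # [B, H, num_future_vars, hidden]
#   cs, ch, cc = torch.zeros(8, 16), torch.zeros(1, 8, 16), torch.zeros(1, 8, 16)
#   enc(hist, futr, cs, ch, cc).shape  # -> torch.Size([8, 36, 16])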
# %% ../../nbs/models.tft.ipynb 22
class TemporalFusionDecoder(nn.Module):
def __init__(
self, n_head, hidden_size, example_length, encoder_length, attn_dropout, dropout
):
super(TemporalFusionDecoder, self).__init__()
self.encoder_length = encoder_length
# ------------- Encoder-Decoder Attention --------------#
self.enrichment_grn = GRN(
input_size=hidden_size,
hidden_size=hidden_size,
context_hidden_size=hidden_size,
dropout=dropout,
)
self.attention = InterpretableMultiHeadAttention(
n_head=n_head,
hidden_size=hidden_size,
example_length=example_length,
attn_dropout=attn_dropout,
dropout=dropout,
)
self.attention_gate = GLU(hidden_size, hidden_size)
self.attention_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)
self.positionwise_grn = GRN(
input_size=hidden_size, hidden_size=hidden_size, dropout=dropout
)
# ---------------------- Decoder -----------------------#
self.decoder_gate = GLU(hidden_size, hidden_size)
self.decoder_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)
def forward(self, temporal_features, ce):
# ------------- Encoder-Decoder Attention --------------#
# Static enrichment
enriched = self.enrichment_grn(temporal_features, c=ce)
# Temporal self attention
x, _ = self.attention(enriched, mask_future_timesteps=True)
# Don't compute historical quantiles
x = x[:, self.encoder_length :, :]
temporal_features = temporal_features[:, self.encoder_length :, :]
enriched = enriched[:, self.encoder_length :, :]
x = self.attention_gate(x)
x = x + enriched
x = self.attention_ln(x)
# Position-wise feed-forward
x = self.positionwise_grn(x)
# ---------------------- Decoder ----------------------#
# Final skip connection
x = self.decoder_gate(x)
x = x + temporal_features
x = self.decoder_ln(x)
return x
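# TemporalFusionDecoder sketch (illustrative shapes): attention runs over the
# full example_length, then only the horizon slice is decoded:
#   dec = TemporalFusionDecoder(n_head=4, hidden_size=16, example_length=36,
#                               encoder_length=24, attn_dropout=0.0, dropout=0.1)
#   dec(torch.randn(8, 36, 16), ce=torch.zeros(8, 16)).shape
#   # -> torch.Size([8, 12, 16])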
# %% ../../nbs/models.tft.ipynb 24
class TFT(BaseWindows):
"""TFT
The Temporal Fusion Transformer architecture (TFT) is an Sequence-to-Sequence
model that combines static, historic and future available data to predict an
univariate target. The method combines gating layers, an LSTM recurrent encoder,
with and interpretable multi-head attention layer and a multi-step forecasting
strategy decoder.
**Parameters:**
`h`: int, Forecast horizon.
    `input_size`: int, autoregressive input size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
    `tgt_size`: int=1, dimension of the target variable (univariate by default).
`stat_exog_list`: str list, static continuous columns.
`hist_exog_list`: str list, historic continuous columns.
`futr_exog_list`: str list, future continuous columns.
    `hidden_size`: int=128, units of embeddings and encoders.
    `dropout`: float (0, 1)=0.1, dropout of inputs VSNs.
    `n_head`: int=4, number of attention heads in temporal fusion decoder.
    `attn_dropout`: float (0, 1)=0.0, dropout of fusion decoder's attention layer.
`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`max_steps`: int=1000, maximum number of training steps.
`learning_rate`: float=1e-3, Learning rate between (0, 1).
`num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
`early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
`val_check_steps`: int=100, Number of training steps between every validation loss check.
    `batch_size`: int=32, number of different series in each batch.
    `windows_batch_size`: int=1024, number of windows sampled from rolled data per training batch.
    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch, -1 uses all.
`start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.
`valid_batch_size`: int=None, number of different series in each validation and test batch.
`step_size`: int=1, step size between each window of temporal data.
`scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
`random_seed`: int, random seed initialization for replicability.
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
`optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
`optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
**References:**
- [Bryan Lim, Sercan O. Arik, Nicolas Loeff, Tomas Pfister,
"Temporal Fusion Transformers for interpretable multi-horizon time series forecasting"](https://www.sciencedirect.com/science/article/pii/S0169207021000637)
"""
# Class attributes
SAMPLING_TYPE = "windows"
def __init__(
self,
h,
input_size,
tgt_size: int = 1,
stat_exog_list=None,
hist_exog_list=None,
futr_exog_list=None,
hidden_size: int = 128,
n_head: int = 4,
attn_dropout: float = 0.0,
dropout: float = 0.1,
loss=MAE(),
valid_loss=None,
max_steps: int = 1000,
learning_rate: float = 1e-3,
num_lr_decays: int = -1,
early_stop_patience_steps: int = -1,
val_check_steps: int = 100,
batch_size: int = 32,
valid_batch_size: Optional[int] = None,
windows_batch_size: int = 1024,
inference_windows_batch_size: int = 1024,
start_padding_enabled=False,
step_size: int = 1,
scaler_type: str = "robust",
num_workers_loader=0,
drop_last_loader=False,
random_seed: int = 1,
optimizer=None,
optimizer_kwargs=None,
**trainer_kwargs
):
# Inherit BaseWindows class
super(TFT, self).__init__(
h=h,
input_size=input_size,
stat_exog_list=stat_exog_list,
hist_exog_list=hist_exog_list,
futr_exog_list=futr_exog_list,
loss=loss,
valid_loss=valid_loss,
max_steps=max_steps,
learning_rate=learning_rate,
num_lr_decays=num_lr_decays,
early_stop_patience_steps=early_stop_patience_steps,
val_check_steps=val_check_steps,
batch_size=batch_size,
valid_batch_size=valid_batch_size,
windows_batch_size=windows_batch_size,
inference_windows_batch_size=inference_windows_batch_size,
start_padding_enabled=start_padding_enabled,
step_size=step_size,
scaler_type=scaler_type,
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
self.example_length = input_size + h
stat_input_size = len(self.stat_exog_list)
futr_input_size = max(len(self.futr_exog_list), 1)
hist_input_size = len(self.hist_exog_list)
num_historic_vars = futr_input_size + hist_input_size + tgt_size
# ------------------------------- Encoders -----------------------------#
self.embedding = TFTEmbedding(
hidden_size=hidden_size,
stat_input_size=stat_input_size,
futr_input_size=futr_input_size,
hist_input_size=hist_input_size,
tgt_size=tgt_size,
)
self.static_encoder = StaticCovariateEncoder(
hidden_size=hidden_size, num_static_vars=stat_input_size, dropout=dropout
)
self.temporal_encoder = TemporalCovariateEncoder(
hidden_size=hidden_size,
num_historic_vars=num_historic_vars,
num_future_vars=futr_input_size,
dropout=dropout,
)
# ------------------------------ Decoders -----------------------------#
self.temporal_fusion_decoder = TemporalFusionDecoder(
n_head=n_head,
hidden_size=hidden_size,
example_length=self.example_length,
encoder_length=self.input_size,
attn_dropout=attn_dropout,
dropout=dropout,
)
# Adapter with Loss dependent dimensions
self.output_adapter = nn.Linear(
in_features=hidden_size, out_features=self.loss.outputsize_multiplier
)
def forward(self, windows_batch):
        # Parse windows_batch
y_insample = windows_batch["insample_y"][:, :, None] # <- [B,T,1]
futr_exog = windows_batch["futr_exog"]
hist_exog = windows_batch["hist_exog"]
stat_exog = windows_batch["stat_exog"]
        # If there are no future exogenous variables, repeat the last observed
        # target value as a dummy future input
        if futr_exog is None:
            futr_exog = y_insample[:, [-1]]
            futr_exog = futr_exog.repeat(1, self.example_length, 1)
s_inp, k_inp, o_inp, t_observed_tgt = self.embedding(
target_inp=y_insample,
hist_exog=hist_exog,
futr_exog=futr_exog,
stat_exog=stat_exog,
)
# -------------------------------- Inputs ------------------------------#
# Static context
if s_inp is not None:
cs, ce, ch, cc = self.static_encoder(s_inp)
ch, cc = ch.unsqueeze(0), cc.unsqueeze(0) # LSTM initial states
else:
# If None add zeros
batch_size, example_length, target_size, hidden_size = t_observed_tgt.shape
cs = torch.zeros(size=(batch_size, hidden_size), device=y_insample.device)
ce = torch.zeros(size=(batch_size, hidden_size), device=y_insample.device)
ch = torch.zeros(
size=(1, batch_size, hidden_size), device=y_insample.device
)
cc = torch.zeros(
size=(1, batch_size, hidden_size), device=y_insample.device
)
# Historical inputs
_historical_inputs = [
k_inp[:, : self.input_size, :],
t_observed_tgt[:, : self.input_size, :],
]
if o_inp is not None:
_historical_inputs.insert(0, o_inp[:, : self.input_size, :])
historical_inputs = torch.cat(_historical_inputs, dim=-2)
# Future inputs
future_inputs = k_inp[:, self.input_size :]
# ---------------------------- Encode/Decode ---------------------------#
# Embeddings + VSN + LSTM encoders
temporal_features = self.temporal_encoder(
historical_inputs=historical_inputs,
future_inputs=future_inputs,
cs=cs,
ch=ch,
cc=cc,
)
# Static enrichment, Attention and decoders
temporal_features = self.temporal_fusion_decoder(
temporal_features=temporal_features, ce=ce
)
# Adapt output to loss
y_hat = self.output_adapter(temporal_features)
y_hat = self.loss.domain_map(y_hat)
return y_hat
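# Usage sketch (illustrative; follows the standard neuralforecast workflow,
# where the input DataFrame uses the `unique_id`, `ds`, and `y` columns):
#   from neuralforecast import NeuralForecast
#   from neuralforecast.utils import AirPassengersDF
#   model = TFT(h=12, input_size=24, hidden_size=64, max_steps=100)
#   nf = NeuralForecast(models=[model], freq='M')
#   nf.fit(df=AirPassengersDF)
#   forecasts = nf.predict()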