{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp models.bitcn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BiTCN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Bidirectional Temporal Convolutional Network (BiTCN) is a forecasting architecture based on two temporal convolutional networks (TCNs). The first network ('forward') encodes future covariates of the time series, whereas the second network ('backward') encodes past observations and covariates. This design preserves the temporal information of sequence data and is computationally more efficient than common RNN methods (LSTM, GRU, ...). Compared to Transformer-based methods, BiTCN has a lower space complexity, i.e. it requires orders of magnitude fewer parameters.\n",
"\n",
"This model may be a good choice if you seek a small model (few trainable parameters) with few hyperparameters to tune (only 2: `hidden_size` and `dropout`).\n",
"\n",
"**References**<br>\n",
"- [Olivier Sprangers, Sebastian Schelter, Maarten de Rijke (2023). Parameter-Efficient Deep Probabilistic Forecasting. International Journal of Forecasting 39, no. 1 (1 January 2023): 332–45.](https://doi.org/10.1016/j.ijforecast.2021.11.011)<br>\n",
"- [Shaojie Bai, Zico Kolter, Vladlen Koltun (2018). An Empirical Evaluation of Generic Convolutional and Recurrent Networks for Sequence Modeling. Computing Research Repository, abs/1803.01271.](https://arxiv.org/abs/1803.01271)<br>\n",
"- [Aaron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, Alex Graves, Nal Kalchbrenner, Andrew Senior, Koray Kavukcuoglu (2016). WaveNet: A Generative Model for Raw Audio. Computing Research Repository, abs/1609.03499.](https://arxiv.org/abs/1609.03499)<br>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from fastcore.test import test_eq\n",
"from nbdev.showdoc import show_doc"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"from typing import Optional\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"\n",
"from neuralforecast.losses.pytorch import MAE\n",
"from neuralforecast.common._base_windows import BaseWindows"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Auxiliary Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class CustomConv1d(nn.Module):\n",
"    # 1D convolution with one-sided zero padding:\n",
"    # mode='backward' pads on the left (output t only sees inputs <= t),\n",
"    # mode='forward' pads on the right (output t only sees inputs >= t).\n",
"    def __init__(self, in_channels, out_channels, kernel_size, padding=0, dilation=1, mode='backward', groups=1):\n",
"        super().__init__()\n",
"        # Uniform initialization in [-k, k] with k = sqrt(1 / (in_channels * kernel_size))\n",
"        k = np.sqrt(1 / (in_channels * kernel_size))\n",
"        weight_data = -k + 2 * k * torch.rand((out_channels, in_channels // groups, kernel_size))\n",
"        bias_data = -k + 2 * k * torch.rand((out_channels))\n",
"        self.weight = nn.Parameter(weight_data, requires_grad=True)\n",
"        self.bias = nn.Parameter(bias_data, requires_grad=True)\n",
"        self.dilation = dilation\n",
"        self.groups = groups\n",
"        if mode == 'backward':\n",
"            self.padding_left = padding\n",
"            self.padding_right = 0\n",
"        elif mode == 'forward':\n",
"            self.padding_left = 0\n",
"            self.padding_right = padding\n",
"        else:\n",
"            raise ValueError(f\"mode must be 'backward' or 'forward', got {mode!r}\")\n",
"\n",
"    def forward(self, x):\n",
"        # x: [B, C, L]; zero-pad the time axis on one side only\n",
"        xp = F.pad(x, (self.padding_left, self.padding_right))\n",
"        return F.conv1d(xp, self.weight, self.bias, dilation=self.dilation, groups=self.groups)\n",
"\n",
"class TCNCell(nn.Module):\n",
"    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation, mode, groups, dropout):\n",
"        super().__init__()\n",
"        self.conv1 = CustomConv1d(in_channels, out_channels, kernel_size, padding, dilation, mode, groups)\n",
"        # 1x1 convolution producing both the residual update and the skip output\n",
"        self.conv2 = CustomConv1d(out_channels, in_channels * 2, 1)\n",
"        self.drop = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, x):\n",
"        # x is a tuple (hidden state, accumulated skip output)\n",
"        h_prev, out_prev = x\n",
"        h = self.drop(F.gelu(self.conv1(h_prev)))\n",
"        h_next, out_next = self.conv2(h).chunk(2, 1)\n",
"        # Residual connection on the hidden state; skip outputs are summed across layers\n",
"        return (h_prev + h_next, out_prev + out_next)"
]
},
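{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two padding modes of `CustomConv1d` determine which direction each TCN looks in. The NumPy sketch below mirrors that logic for a single channel without bias; the helper `dilated_conv1d` and the toy impulse input are illustrative assumptions, not part of `neuralforecast`. With `mode='backward'` (left padding) an input at step t can only influence outputs at steps >= t, and with `mode='forward'` (right padding) only outputs at steps <= t. Because the padding equals `(kernel_size - 1) * dilation`, the output keeps the input length in both modes.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def dilated_conv1d(x, w, dilation, mode):\n",
"    # Hypothetical single-channel version of CustomConv1d (no bias, no groups)\n",
"    k = len(w)\n",
"    pad = (k - 1) * dilation  # same padding rule as the TCNCell layers in BiTCN\n",
"    if mode == 'backward':\n",
"        xp = np.concatenate([np.zeros(pad), x])  # left padding: causal\n",
"    elif mode == 'forward':\n",
"        xp = np.concatenate([x, np.zeros(pad)])  # right padding: anti-causal\n",
"    else:\n",
"        raise ValueError(f'unknown mode: {mode}')\n",
"    # Cross-correlation, as computed by F.conv1d: y[t] = sum_j w[j] * xp[t + j * dilation]\n",
"    return np.array([sum(w[j] * xp[t + j * dilation] for j in range(k))\n",
"                     for t in range(len(x))])\n",
"\n",
"\n",
"x = np.zeros(8)\n",
"x[4] = 1.0  # unit impulse at t = 4\n",
"w = np.array([1.0, 1.0])  # kernel_size = 2, as in BiTCN\n",
"\n",
"y_bwd = dilated_conv1d(x, w, dilation=2, mode='backward')\n",
"y_fwd = dilated_conv1d(x, w, dilation=2, mode='forward')\n",
"print(np.nonzero(y_bwd)[0])  # [4 6]: the impulse only reaches later steps\n",
"print(np.nonzero(y_fwd)[0])  # [2 4]: the impulse only reaches earlier steps\n",
"```"
]
},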
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. BiTCN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class BiTCN(BaseWindows):\n",
"    \"\"\" BiTCN\n",
"\n",
"    Bidirectional Temporal Convolutional Network (BiTCN) is a forecasting architecture based on two temporal convolutional networks (TCNs). The first network ('forward') encodes future covariates of the time series, whereas the second network ('backward') encodes past observations and covariates. This is a univariate model.\n",
"\n",
"    **Parameters:**<br>\n",
"    `h`: int, forecast horizon.<br>\n",
"    `input_size`: int, considered autoregressive inputs (lags), y=[1,2,3,4] input_size=2 -> lags=[1,2].<br>\n",
"    `hidden_size`: int=16, units for the TCN's hidden state size.<br>\n",
"    `dropout`: float=0.5, dropout rate used for the dropout layers throughout the architecture.<br>\n",
"    `futr_exog_list`: str list, future exogenous columns.<br>\n",
"    `hist_exog_list`: str list, historic exogenous columns.<br>\n",
"    `stat_exog_list`: str list, static exogenous columns.<br>\n",
"    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
"    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
"    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
"    `max_steps`: int=1000, maximum number of training steps.<br>\n",
"    `learning_rate`: float=1e-3, learning rate between (0, 1).<br>\n",
"    `num_lr_decays`: int=-1, number of learning rate decays, evenly distributed across max_steps.<br>\n",
"    `early_stop_patience_steps`: int=-1, number of validation iterations before early stopping.<br>\n",
"    `val_check_steps`: int=100, number of training steps between every validation loss check.<br>\n",
"    `batch_size`: int=32, number of different series in each batch.<br>\n",
"    `valid_batch_size`: int=None, number of different series in each validation and test batch; if None, uses batch_size.<br>\n",
"    `windows_batch_size`: int=1024, number of windows to sample in each training batch.<br>\n",
"    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch, -1 uses all.<br>\n",
"    `start_padding_enabled`: bool=False, if True, the model pads the time series with zeros at the beginning, by input size.<br>\n",
"    `step_size`: int=1, step size between each window of temporal data.<br>\n",
"    `scaler_type`: str='identity', type of scaler for temporal inputs normalization; see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
"    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
"    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
"    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
"    `alias`: str, optional, custom name of the model.<br>\n",
"    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
"    `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user specified `optimizer`.<br>\n",
"    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"    \"\"\"\n",
"    # Class attributes\n",
"    SAMPLING_TYPE = 'windows'\n",
"\n",
"    def __init__(self,\n",
"                 h: int,\n",
"                 input_size: int,\n",
"                 hidden_size: int = 16,\n",
"                 dropout: float = 0.5,\n",
"                 futr_exog_list = None,\n",
"                 hist_exog_list = None,\n",
"                 stat_exog_list = None,\n",
"                 exclude_insample_y = False,\n",
"                 loss = MAE(),\n",
"                 valid_loss = None,\n",
"                 max_steps: int = 1000,\n",
"                 learning_rate: float = 1e-3,\n",
"                 num_lr_decays: int = -1,\n",
"                 early_stop_patience_steps: int = -1,\n",
"                 val_check_steps: int = 100,\n",
"                 batch_size: int = 32,\n",
"                 valid_batch_size: Optional[int] = None,\n",
"                 windows_batch_size = 1024,\n",
"                 inference_windows_batch_size = 1024,\n",
"                 start_padding_enabled = False,\n",
"                 step_size: int = 1,\n",
"                 scaler_type: str = 'identity',\n",
"                 random_seed: int = 1,\n",
"                 num_workers_loader: int = 0,\n",
"                 drop_last_loader: bool = False,\n",
"                 optimizer = None,\n",
"                 optimizer_kwargs = None,\n",
"                 **trainer_kwargs):\n",
"        super(BiTCN, self).__init__(\n",
"            h=h,\n",
"            input_size=input_size,\n",
"            futr_exog_list=futr_exog_list,\n",
"            hist_exog_list=hist_exog_list,\n",
"            stat_exog_list=stat_exog_list,\n",
"            exclude_insample_y=exclude_insample_y,\n",
"            loss=loss,\n",
"            valid_loss=valid_loss,\n",
"            max_steps=max_steps,\n",
"            learning_rate=learning_rate,\n",
"            num_lr_decays=num_lr_decays,\n",
"            early_stop_patience_steps=early_stop_patience_steps,\n",
"            val_check_steps=val_check_steps,\n",
"            batch_size=batch_size,\n",
"            valid_batch_size=valid_batch_size,\n",
"            windows_batch_size=windows_batch_size,\n",
"            inference_windows_batch_size=inference_windows_batch_size,\n",
"            start_padding_enabled=start_padding_enabled,\n",
"            step_size=step_size,\n",
"            scaler_type=scaler_type,\n",
"            random_seed=random_seed,\n",
"            num_workers_loader=num_workers_loader,\n",
"            drop_last_loader=drop_last_loader,\n",
"            optimizer=optimizer,\n",
"            optimizer_kwargs=optimizer_kwargs,\n",
"            **trainer_kwargs\n",
"        )\n",
"\n",
"        #----------------------------------- Parse dimensions -----------------------------------#\n",
"        # TCN\n",
"        kernel_size = 2  # Not really necessary as parameter, so simplifying the architecture here.\n",
"        self.kernel_size = kernel_size\n",
"        self.hidden_size = hidden_size\n",
"        self.h = h\n",
"        self.input_size = input_size\n",
"        self.dropout = dropout\n",
"\n",
"        # Number of TCN layers n such that the receptive field 1 + (kernel_size - 1) * (2**n - 1) covers input_size\n",
"        self.n_layers_bwd = int(np.ceil(np.log2(((self.input_size - 1) / (self.kernel_size - 1)) + 1)))\n",
"\n",
"        self.futr_exog_size = len(self.futr_exog_list)\n",
"        self.hist_exog_size = len(self.hist_exog_list)\n",
"        self.stat_exog_size = len(self.stat_exog_list)\n",
"\n",
"        #---------------------------------- Instantiate Model -----------------------------------#\n",
"\n",
"        # Dense layers\n",
"        self.lin_hist = nn.Linear(1 + self.hist_exog_size + self.stat_exog_size + self.futr_exog_size, hidden_size)\n",
"        self.drop_hist = nn.Dropout(dropout)\n",
"\n",
"        # TCN looking back\n",
"        layers_bwd = [TCNCell(\n",
"                          hidden_size,\n",
"                          hidden_size,\n",
"                          kernel_size,\n",
"                          padding = (kernel_size - 1) * 2**i,\n",
"                          dilation = 2**i,\n",
"                          mode = 'backward',\n",
"                          groups = 1,\n",
"                          dropout = dropout) for i in range(self.n_layers_bwd)]\n",
"        self.net_bwd = nn.Sequential(*layers_bwd)\n",
"\n",
"        # TCN looking forward when future covariates exist\n",
"        output_lin_dim_multiplier = 1\n",
"        if self.futr_exog_size > 0:\n",
"            self.n_layers_fwd = int(np.ceil(np.log2(((self.h + self.input_size - 1) / (self.kernel_size - 1)) + 1)))\n",
"            self.lin_futr = nn.Linear(self.futr_exog_size, hidden_size)\n",
"            self.drop_futr = nn.Dropout(dropout)\n",
"            layers_fwd = [TCNCell(\n",
"                              hidden_size,\n",
"                              hidden_size,\n",
"                              kernel_size,\n",
"                              padding = (kernel_size - 1) * 2**i,\n",
"                              dilation = 2**i,\n",
"                              mode = 'forward',\n",
"                              groups = 1,\n",
"                              dropout = dropout) for i in range(self.n_layers_fwd)]\n",
"            self.net_fwd = nn.Sequential(*layers_fwd)\n",
"            # x_futr_L and x_futr_h each add hidden_size channels to the output layer input\n",
"            output_lin_dim_multiplier += 2\n",
"\n",
"        # Dense temporal and output layers\n",
"        self.drop_temporal = nn.Dropout(dropout)\n",
"        self.temporal_lin1 = nn.Linear(self.input_size, hidden_size)\n",
"        self.temporal_lin2 = nn.Linear(hidden_size, self.h)\n",
"        self.output_lin = nn.Linear(output_lin_dim_multiplier * hidden_size, self.loss.outputsize_multiplier)\n",
"\n",
"    def forward(self, windows_batch):\n",
"        # Parse windows_batch\n",
"        x = windows_batch['insample_y'].unsqueeze(-1)  # [B, L, 1]\n",
"        hist_exog = windows_batch['hist_exog']  # [B, L, X]\n",
"        futr_exog = windows_batch['futr_exog']  # [B, L + h, F]\n",
"        stat_exog = windows_batch['stat_exog']  # [B, S]\n",
"\n",
"        # Concatenate x with historic exogenous\n",
"        batch_size, seq_len = x.shape[:2]  # B = batch_size, L = seq_len\n",
"        if self.hist_exog_size > 0:\n",
"            x = torch.cat((x, hist_exog), dim=2)  # [B, L, 1] + [B, L, X] -> [B, L, 1 + X]\n",
"\n",
"        # Concatenate x with static exogenous\n",
"        if self.stat_exog_size > 0:\n",
"            stat_exog = stat_exog.unsqueeze(1).repeat(1, seq_len, 1)  # [B, S] -> [B, L, S]\n",
"            x = torch.cat((x, stat_exog), dim=2)  # [B, L, 1 + X] + [B, L, S] -> [B, L, 1 + X + S]\n",
"\n",
"        # Concatenate x with future exogenous & apply forward TCN to x_futr\n",
"        if self.futr_exog_size > 0:\n",
"            x = torch.cat((x, futr_exog[:, :seq_len]), dim=2)  # [B, L, 1 + X + S] + [B, L, F] -> [B, L, 1 + X + S + F]\n",
"            x_futr = self.drop_futr(self.lin_futr(futr_exog))  # [B, L + h, F] -> [B, L + h, hidden_size]\n",
"            x_futr = x_futr.permute(0, 2, 1)  # [B, L + h, hidden_size] -> [B, hidden_size, L + h]\n",
"            _, x_futr = self.net_fwd((x_futr, 0))  # [B, hidden_size, L + h] -> [B, hidden_size, L + h]\n",
"            x_futr_L = x_futr[:, :, :seq_len]  # [B, hidden_size, L + h] -> [B, hidden_size, L]\n",
"            x_futr_h = x_futr[:, :, seq_len:]  # [B, hidden_size, L + h] -> [B, hidden_size, h]\n",
"\n",
"        # Apply backward TCN to x\n",
"        x = self.drop_hist(self.lin_hist(x))  # [B, L, 1 + X + S + F] -> [B, L, hidden_size]\n",
"        x = x.permute(0, 2, 1)  # [B, L, hidden_size] -> [B, hidden_size, L]\n",
"        _, x = self.net_bwd((x, 0))  # [B, hidden_size, L] -> [B, hidden_size, L]\n",
"\n",
"        # Concatenate with future exogenous for seq_len\n",
"        if self.futr_exog_size > 0:\n",
"            x = torch.cat((x, x_futr_L), dim=1)  # [B, hidden_size, L] + [B, hidden_size, L] -> [B, 2 * hidden_size, L]\n",
"\n",
"        # Temporal dense layer to go to output horizon\n",
"        # (channel counts below assume future exogenous; without them they stay hidden_size)\n",
"        x = self.drop_temporal(F.gelu(self.temporal_lin1(x)))  # [B, 2 * hidden_size, L] -> [B, 2 * hidden_size, hidden_size]\n",
"        x = self.temporal_lin2(x)  # [B, 2 * hidden_size, hidden_size] -> [B, 2 * hidden_size, h]\n",
"\n",
"        # Concatenate with future exogenous for horizon\n",
"        if self.futr_exog_size > 0:\n",
"            x = torch.cat((x, x_futr_h), dim=1)  # [B, 2 * hidden_size, h] + [B, hidden_size, h] -> [B, 3 * hidden_size, h]\n",
"\n",
"        # Output layer to create forecasts\n",
"        x = x.permute(0, 2, 1)  # [B, 3 * hidden_size, h] -> [B, h, 3 * hidden_size]\n",
"        x = self.output_lin(x)  # [B, h, 3 * hidden_size] -> [B, h, n_outputs]\n",
"\n",
"        # Map to output domain\n",
"        forecast = self.loss.domain_map(x)\n",
"\n",
"        return forecast"
]
},
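{
"cell_type": "markdown",
"metadata": {},
"source": [
"`n_layers_bwd` and `n_layers_fwd` are chosen so that the receptive field of each TCN covers its whole input sequence. With kernel size $k$ and dilations $1, 2, \\dots, 2^{n-1}$, a stack of $n$ layers sees $1 + (k - 1)(2^n - 1)$ time steps; requiring this to be at least the sequence length $L$ gives $n = \\lceil \\log_2((L - 1)/(k - 1) + 1) \\rceil$, with $L$ = `input_size` for the backward TCN and $L$ = `input_size + h` for the forward TCN. The sketch below checks this numerically; the helper names are hypothetical and only reproduce the expressions used in `BiTCN.__init__`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def n_tcn_layers(seq_len, kernel_size=2):\n",
"    # Same expression as self.n_layers_bwd / self.n_layers_fwd in BiTCN\n",
"    return int(math.ceil(math.log2((seq_len - 1) / (kernel_size - 1) + 1)))\n",
"\n",
"\n",
"def receptive_field(n_layers, kernel_size=2):\n",
"    # Layer i has dilation 2**i and adds (kernel_size - 1) * 2**i steps of context\n",
"    return 1 + sum((kernel_size - 1) * 2**i for i in range(n_layers))\n",
"\n",
"\n",
"# input_size=24 (backward TCN) and input_size + h = 24 + 12 (forward TCN)\n",
"for seq_len in [24, 36]:\n",
"    n = n_tcn_layers(seq_len)\n",
"    print(seq_len, n, receptive_field(n))  # prints 24 5 32, then 36 6 64\n",
"```\n",
"\n",
"For the settings used in the usage examples (`input_size=24`, `h=12`, `kernel_size=2`), this gives 5 backward layers and, when future exogenous variables are present, 6 forward layers; one layer fewer would leave part of the sequence outside the receptive field."
]
},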
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(BiTCN)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(BiTCN.fit, name='BiTCN.fit')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(BiTCN.predict, name='BiTCN.predict')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast.utils import AirPassengersDF as Y_df\n",
"from neuralforecast.tsdataset import TimeSeriesDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train\n",
"Y_test_df = Y_df[Y_df.ds>'1959-12-31'].copy() # 12 test\n",
"\n",
"dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)\n",
"model = BiTCN(h=12, input_size=24, max_steps=500, scaler_type='standard')\n",
"model.fit(dataset=dataset)\n",
"y_hat = model.predict(dataset=dataset)\n",
"Y_test_df['BiTCN'] = y_hat\n",
"\n",
"# Test that predicting twice returns the same forecast\n",
"y_hat2 = model.predict(dataset=dataset)\n",
"test_eq(y_hat, y_hat2)\n",
"\n",
"pd.concat([Y_train_df, Y_test_df]).drop('unique_id', axis=1).set_index('ds').plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.losses.pytorch import GMM, DistributionLoss\n",
"from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
"\n",
"fcst = NeuralForecast(\n",
"    models=[\n",
"        BiTCN(h=12,\n",
"              input_size=24,\n",
"              loss=GMM(n_components=7, return_params=True, level=[80,90]),\n",
"              max_steps=500,\n",
"              scaler_type='standard',\n",
"              futr_exog_list=['y_[lag12]'],\n",
"              hist_exog_list=None,\n",
"              stat_exog_list=['airline1'],\n",
"              ),\n",
"    ],\n",
"    freq='M'\n",
")\n",
"fcst.fit(df=Y_train_df, static_df=AirPassengersStatic)\n",
"forecasts = fcst.predict(futr_df=Y_test_df)\n",
"\n",
"# Plot quantile predictions\n",
"Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
"plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
"plot_df = pd.concat([Y_train_df, plot_df])\n",
"\n",
"plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
"plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
"plt.plot(plot_df['ds'], plot_df['BiTCN-median'], c='blue', label='median')\n",
"plt.fill_between(x=plot_df['ds'][-12:],\n",
"                 y1=plot_df['BiTCN-lo-90'][-12:].values,\n",
"                 y2=plot_df['BiTCN-hi-90'][-12:].values,\n",
"                 alpha=0.4, label='level 90')\n",
"plt.legend()\n",
"plt.grid()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}