{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp models.timesnet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# TimesNet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The TimesNet univariate model tackles the challenge of modeling multiple intraperiod and interperiod temporal variations.\n", "\n", "The architecture has the following distinctive features:\n", "- An embedding layer that maps the input sequence into a latent space.\n", "- Transformation of 1D time seires into 2D tensors, based on periods found by FFT.\n", "- A convolutional Inception block that captures temporal variations at different scales and between periods.\n", "\n", "**References**
\n", "- [Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long. TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis](https://openreview.net/pdf?id=ju_Uqw384Oq)\n", "- Based on the implementation in https://github.com/thuml/Time-Series-Library (license: https://github.com/thuml/Time-Series-Library/blob/main/LICENSE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Figure 1. TimesNet Architecture.](imgs_models/timesnet.png)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "from typing import Optional\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.fft\n", "\n", "from neuralforecast.common._modules import DataEmbedding\n", "from neuralforecast.common._base_windows import BaseWindows\n", "\n", "from neuralforecast.losses.pytorch import MAE" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "from fastcore.test import test_eq\n", "from nbdev.showdoc import show_doc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Auxiliary Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "class Inception_Block_V1(nn.Module):\n", " def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):\n", " super(Inception_Block_V1, self).__init__()\n", " self.in_channels = in_channels\n", " self.out_channels = out_channels\n", " self.num_kernels = num_kernels\n", " kernels = []\n", " for i in range(self.num_kernels):\n", " kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i))\n", " self.kernels = nn.ModuleList(kernels)\n", " if init_weight:\n", " self._initialize_weights()\n", "\n", " def _initialize_weights(self):\n", " for m in self.modules():\n", " if isinstance(m, nn.Conv2d):\n", " nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " if m.bias is not None:\n", " nn.init.constant_(m.bias, 0)\n", "\n", " def forward(self, x):\n", " res_list = []\n", " for i in range(self.num_kernels):\n", " res_list.append(self.kernels[i](x))\n", " res = torch.stack(res_list, dim=-1).mean(-1)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "def FFT_for_Period(x, k=2):\n", " # [B, T, C]\n", " xf = torch.fft.rfft(x, dim=1)\n", " # find period by amplitudes\n", " frequency_list = abs(xf).mean(0).mean(-1)\n", " frequency_list[0] = 0\n", " _, top_list = torch.topk(frequency_list, k)\n", " top_list = top_list.detach().cpu().numpy()\n", " period = x.shape[1] // top_list\n", " return period, abs(xf).mean(-1)[:, top_list]\n", "\n", "class TimesBlock(nn.Module):\n", " def __init__(self, input_size, h, k, hidden_size, conv_hidden_size, num_kernels):\n", " super(TimesBlock, self).__init__()\n", " self.input_size = input_size\n", " self.h = h\n", " self.k = k\n", " # parameter-efficient design\n", " self.conv = nn.Sequential(\n", " Inception_Block_V1(hidden_size, conv_hidden_size,\n", " num_kernels=num_kernels),\n", " nn.GELU(),\n", " Inception_Block_V1(conv_hidden_size, hidden_size,\n", " num_kernels=num_kernels)\n", " )\n", "\n", " def forward(self, x):\n", " B, T, N = x.size()\n", " period_list, period_weight = FFT_for_Period(x, self.k)\n", "\n", " res = []\n", " for i in range(self.k):\n", " period = period_list[i]\n", " # padding\n", " if (self.input_size 
+ self.h) % period != 0:\n", " length = (\n", " ((self.input_size + self.h) // period) + 1) * period\n", " padding = torch.zeros([x.shape[0], (length - (self.input_size + self.h)), x.shape[2]], device=x.device)\n", " out = torch.cat([x, padding], dim=1)\n", " else:\n", " length = (self.input_size + self.h)\n", " out = x\n", " # reshape\n", " out = out.reshape(B, length // period, period,\n", " N).permute(0, 3, 1, 2).contiguous()\n", " # 2D conv: from 1d Variation to 2d Variation\n", " out = self.conv(out)\n", " # reshape back\n", " out = out.permute(0, 2, 3, 1).reshape(B, -1, N)\n", " res.append(out[:, :(self.input_size + self.h), :])\n", " res = torch.stack(res, dim=-1)\n", " # adaptive aggregation\n", " period_weight = F.softmax(period_weight, dim=1)\n", " period_weight = period_weight.unsqueeze(\n", " 1).unsqueeze(1).repeat(1, T, N, 1)\n", " res = torch.sum(res * period_weight, -1)\n", " # residual connection\n", " res = res + x\n", " return res" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. TimesNet" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| export\n", "class TimesNet(BaseWindows):\n", " \"\"\" TimesNet\n", "\n", " The TimesNet univariate model tackles the challenge of modeling multiple intraperiod and interperiod temporal variations.\n", " \n", " Parameters\n", " ----------\n", " h : int\n", " Forecast horizon.\n", " input_size : int\n", " Length of input window (lags).\n", " futr_exog_list : list of str, optional (default=None)\n", " Future exogenous columns.\n", " hist_exog_list : list of str, optional (default=None)\n", " Historic exogenous columns.\n", " stat_exog_list : list of str, optional (default=None)\n", " Static exogenous columns.\n", " exclude_insample_y : bool (default=False)\n", " The model skips the autoregressive features y[t-input_size:t] if True\n", " hidden_size : int (default=64)\n", " Size of embedding for embedding and encoders.\n", " dropout : float between [0, 1) (default=0.1)\n", " Dropout for embeddings.\n", "\tconv_hidden_size: int (default=64)\n", " Channels of the Inception block.\n", " top_k: int (default=5)\n", " Number of periods.\n", " num_kernels: int (default=6)\n", " Number of kernels for the Inception block.\n", " encoder_layers : int, (default=2)\n", " Number of encoder layers.\n", " loss: PyTorch module (default=MAE())\n", " Instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).\n", " valid_loss: PyTorch module (default=None, uses loss)\n", " Instantiated validation loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).\n", " max_steps: int (default=1000)\n", " Maximum number of training steps.\n", " learning_rate : float (default=1e-4)\n", " Learning rate.\n", " num_lr_decays`: int (default=-1)\n", " Number of learning rate decays, evenly distributed across max_steps. If -1, no learning rate decay is performed.\n", " early_stop_patience_steps : int (default=-1)\n", " Number of validation iterations before early stopping. 
If -1, no early stopping is performed.\n", " val_check_steps : int (default=100)\n", " Number of training steps between every validation loss check.\n", " batch_size : int (default=32)\n", " Number of different series in each batch.\n", " valid_batch_size : int (default=None)\n", " Number of different series in each validation and test batch, if None uses batch_size.\n", " windows_batch_size : int (default=64)\n", " Number of windows to sample in each training batch.\n", " inference_windows_batch_size : int (default=256)\n", " Number of windows to sample in each inference batch.\n", " start_padding_enabled : bool (default=False)\n", " If True, the model will pad the time series with zeros at the beginning by input size.\n", " scaler_type : str (default='standard')\n", " Type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n", " random_seed : int (default=1)\n", " Random_seed for pytorch initializer and numpy generators.\n", " num_workers_loader : int (default=0)\n", " Workers to be used by `TimeSeriesDataLoader`.\n", " drop_last_loader : bool (default=False)\n", " If True `TimeSeriesDataLoader` drops last non-full batch.\n", " `optimizer`: Subclass of 'torch.optim.Optimizer', optional (default=None)\n", " User specified optimizer instead of the default choice (Adam).\n", " `optimizer_kwargs`: dict, optional (defualt=None)\n", " List of parameters used by the user specified `optimizer`.\n", " **trainer_kwargs\n", " Keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer)\n", "\n", "\tReferences\n", "\t----------\n", " Haixu Wu and Tengge Hu and Yong Liu and Hang Zhou and Jianmin Wang and Mingsheng Long. TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis. https://openreview.net/pdf?id=ju_Uqw384Oq\n", " \"\"\"\n", " # Class attributes\n", " SAMPLING_TYPE = 'windows'\n", " \n", " def __init__(self,\n", " h: int, \n", " input_size: int,\n", " stat_exog_list = None,\n", " hist_exog_list = None,\n", " futr_exog_list = None,\n", " exclude_insample_y = False,\n", " hidden_size: int = 64, \n", " dropout: float = 0.1,\n", " conv_hidden_size: int = 64,\n", " top_k: int = 5,\n", " num_kernels: int = 6,\n", " encoder_layers: int = 2,\n", " loss = MAE(),\n", " valid_loss = None,\n", " max_steps: int = 1000,\n", " learning_rate: float = 1e-4,\n", " num_lr_decays: int = -1,\n", " early_stop_patience_steps: int =-1,\n", " val_check_steps: int = 100,\n", " batch_size: int = 32,\n", " valid_batch_size: Optional[int] = None,\n", " windows_batch_size = 64,\n", " inference_windows_batch_size = 256,\n", " start_padding_enabled = False,\n", " step_size: int = 1,\n", " scaler_type: str = 'standard',\n", " random_seed: int = 1,\n", " num_workers_loader: int = 0,\n", " drop_last_loader: bool = False,\n", " optimizer = None,\n", " optimizer_kwargs = None,\n", " **trainer_kwargs):\n", " super(TimesNet, self).__init__(h=h,\n", " input_size=input_size,\n", " hist_exog_list=hist_exog_list,\n", " stat_exog_list=stat_exog_list,\n", " futr_exog_list = futr_exog_list,\n", " exclude_insample_y = exclude_insample_y,\n", " loss=loss,\n", " valid_loss=valid_loss,\n", " max_steps=max_steps,\n", " learning_rate=learning_rate,\n", " num_lr_decays=num_lr_decays,\n", " early_stop_patience_steps=early_stop_patience_steps,\n", " val_check_steps=val_check_steps,\n", " batch_size=batch_size,\n", " windows_batch_size=windows_batch_size,\n", " valid_batch_size=valid_batch_size,\n", " inference_windows_batch_size=inference_windows_batch_size,\n", " start_padding_enabled = start_padding_enabled,\n", " step_size=step_size,\n", " scaler_type=scaler_type,\n", " num_workers_loader=num_workers_loader,\n", " drop_last_loader=drop_last_loader,\n", " random_seed=random_seed,\n", " optimizer=optimizer,\n", " optimizer_kwargs=optimizer_kwargs,\n", " **trainer_kwargs)\n", "\n", " # Architecture\n", " self.futr_input_size = len(self.futr_exog_list)\n", " self.hist_input_size = len(self.hist_exog_list)\n", " self.stat_input_size = len(self.stat_exog_list)\n", "\n", " if self.stat_input_size > 0:\n", " raise Exception('TimesNet does not support static variables yet')\n", " if self.hist_input_size > 0:\n", " raise Exception('TimesNet does not support historical variables yet')\n", "\n", " self.c_out = 
 "        self.c_out = self.loss.outputsize_multiplier\n", "        self.enc_in = 1\n", "        self.dec_in = 1\n", "\n", "        self.model = nn.ModuleList([TimesBlock(input_size=input_size,\n", "                                               h=h,\n", "                                               k=top_k,\n", "                                               hidden_size=hidden_size,\n", "                                               conv_hidden_size=conv_hidden_size,\n", "                                               num_kernels=num_kernels)\n", "                                    for _ in range(encoder_layers)])\n", "\n", "        self.enc_embedding = DataEmbedding(c_in=self.enc_in,\n", "                                           exog_input_size=self.futr_input_size,\n", "                                           hidden_size=hidden_size,\n", "                                           pos_embedding=True,  # Original implementation uses True\n", "                                           dropout=dropout)\n", "        self.encoder_layers = encoder_layers\n", "        self.layer_norm = nn.LayerNorm(hidden_size)\n", "        self.predict_linear = nn.Linear(self.input_size, self.h + self.input_size)\n", "        self.projection = nn.Linear(hidden_size, self.c_out, bias=True)\n", "\n", "    def forward(self, windows_batch):\n", "\n", "        # Parse windows_batch\n", "        insample_y = windows_batch['insample_y']\n", "        #insample_mask = windows_batch['insample_mask']\n", "        #hist_exog = windows_batch['hist_exog']\n", "        #stat_exog = windows_batch['stat_exog']\n", "        futr_exog = windows_batch['futr_exog']\n", "\n", "        # Parse inputs\n", "        insample_y = insample_y.unsqueeze(-1) # [Ws,L,1]\n", "        if self.futr_input_size > 0:\n", "            x_mark_enc = futr_exog[:,:self.input_size,:]\n", "        else:\n", "            x_mark_enc = None\n", "\n", "        # embedding\n", "        enc_out = self.enc_embedding(insample_y, x_mark_enc)\n", "        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(0, 2, 1) # align temporal dimension\n", "        # TimesNet\n", "        for i in range(self.encoder_layers):\n", "            enc_out = self.layer_norm(self.model[i](enc_out))\n", "        # project back\n", "        dec_out = self.projection(enc_out)\n", "\n", "        forecast = self.loss.domain_map(dec_out[:, -self.h:])\n", "        return forecast" ] },
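 { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal forward-pass sanity check on dummy windows (a sketch; it assumes the default `MAE` loss maps the network output to shape `[batch, h]`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "model = TimesNet(h=12, input_size=24)\n", "windows_batch = {'insample_y': torch.randn(4, 24),  # [Ws, L]\n", "                 'futr_exog': None}                 # no future exogenous features\n", "y_hat = model(windows_batch)\n", "test_eq(y_hat.shape, (4, 12))" ] },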
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(TimesNet)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(TimesNet.fit, name='TimesNet.fit')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(TimesNet.predict, name='TimesNet.predict')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage Example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| eval: false\n", "import numpy as np\n", "import pandas as pd\n", "import pytorch_lightning as pl\n", "import matplotlib.pyplot as plt\n", "\n", "from neuralforecast import NeuralForecast\n", "from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n", "from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n", "\n", "AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n", "\n", "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n", "Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n", "\n", "model = TimesNet(h=12,\n", "                 input_size=24,\n", "                 hidden_size=16,\n", "                 conv_hidden_size=32,\n", "                 #loss=MAE(),\n", "                 #loss=MQLoss(quantiles=[0.2, 0.5, 0.8]),\n", "                 loss=DistributionLoss(distribution='Normal', level=[80, 90]),\n", "                 futr_exog_list=calendar_cols,\n", "                 scaler_type='standard',\n", "                 learning_rate=1e-3,\n", "                 max_steps=5,\n", "                 val_check_steps=50,\n", "                 early_stop_patience_steps=2)\n", "\n", "nf = NeuralForecast(\n", "    models=[model],\n", "    freq='M'\n", ")\n", "nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n", "forecasts = nf.predict(futr_df=Y_test_df)\n", "\n", "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n", "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n", "plot_df = pd.concat([Y_train_df, plot_df])\n", "\n", "if model.loss.is_distribution_output:\n", "    plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", "    plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", "    plt.plot(plot_df['ds'], plot_df['TimesNet-median'], c='blue', label='median')\n", "    plt.fill_between(x=plot_df['ds'][-12:],\n", "                     y1=plot_df['TimesNet-lo-90'][-12:].values,\n", "                     y2=plot_df['TimesNet-hi-90'][-12:].values,\n", "                     alpha=0.4, label='level 90')\n", "    plt.grid()\n", "    plt.legend()\n", "    plt.plot()\n", "else:\n", "    plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", "    plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", "    plt.plot(plot_df['ds'], plot_df['TimesNet'], c='blue', label='Forecast')\n", "    plt.legend()\n", "    plt.grid()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "python3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }