{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "4a93f115",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp common._scalers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c704dc1",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "83d112c7-18f8-4f20-acad-34e6de54cebf",
"metadata": {},
"source": [
"# TemporalNorm\n",
"\n",
"> Temporal normalization has proven to be essential in neural forecasting tasks, as it enables network's non-linearities to express themselves. Forecasting scaling methods take particular interest in the temporal dimension where most of the variance dwells, contrary to other deep learning techniques like `BatchNorm` that normalizes across batch and temporal dimensions, and `LayerNorm` that normalizes across the feature dimension. Currently we support the following techniques: `std`, `median`, `norm`, `norm1`, `invariant`, `revin`."
]
},
{
"cell_type": "markdown",
"id": "fee5e60b-f53b-44ff-9ace-1f5def7b601d",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"id": "f9211dd2-99a4-4d67-90cb-bb1f7851685e",
"metadata": {},
"source": [
"* [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)\n",
"* [Taesung Kim and Jinhee Kim and Yunwon Tae and Cheonbok Park and Jang-Ho Choi and Jaegul Choo. \"Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift\". ICLR 2022.](https://openreview.net/pdf?id=cGDAkQo1C0p)\n",
"* [David Salinas, Valentin Flunkert, Jan Gasthaus, Tim Januschowski (2020). \"DeepAR: Probabilistic forecasting with autoregressive recurrent networks\". International Journal of Forecasting.](https://www.sciencedirect.com/science/article/pii/S0169207019301888)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df2cc55a",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"import torch\n",
"import torch.nn as nn"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f08562b-88d8-4e92-aeeb-bc9bc4c61ab7",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from nbdev.showdoc import show_doc\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5201e067-f7c0-4ca3-89a7-d879001b1908",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"plt.rcParams[\"axes.grid\"]=True\n",
"plt.rcParams['font.family'] = 'serif'\n",
"plt.rcParams[\"figure.figsize\"] = (4,2)"
]
},
{
"cell_type": "markdown",
"id": "ef461e9c",
"metadata": {},
"source": [
"# 1. Auxiliary Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12a249a3",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def masked_median(x, mask, dim=-1, keepdim=True):\n",
" \"\"\" Masked Median\n",
"\n",
" Compute the median of tensor `x` along dim, ignoring values where \n",
" `mask` is False. `x` and `mask` need to be broadcastable.\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor to compute median of along `dim` dimension.
\n",
" `mask`: torch Tensor bool with same shape as `x`, where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `dim` (int, optional): Dimension to take median of. Defaults to -1.
\n",
" `keepdim` (bool, optional): Keep dimension of `x` or not. Defaults to True.
\n",
"\n",
" **Returns:**
\n",
" `x_median`: torch.Tensor with normalized values.\n",
" \"\"\"\n",
" x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
" x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)\n",
" x_median = torch.nan_to_num(x_median, nan=0.0)\n",
" return x_median\n",
"\n",
"def masked_mean(x, mask, dim=-1, keepdim=True):\n",
" \"\"\" Masked Mean\n",
"\n",
" Compute the mean of tensor `x` along dimension, ignoring values where \n",
" `mask` is False. `x` and `mask` need to be broadcastable.\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor to compute mean of along `dim` dimension.
\n",
" `mask`: torch Tensor bool with same shape as `x`, where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `dim` (int, optional): Dimension to take mean of. Defaults to -1.
\n",
" `keepdim` (bool, optional): Keep dimension of `x` or not. Defaults to True.
\n",
"\n",
" **Returns:**
\n",
" `x_mean`: torch.Tensor with normalized values.\n",
" \"\"\"\n",
" x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
" x_mean = x_nan.nanmean(dim=dim, keepdim=keepdim)\n",
" x_mean = torch.nan_to_num(x_mean, nan=0.0)\n",
" return x_mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49d2e338",
"metadata": {},
"outputs": [],
"source": [
"show_doc(masked_median, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "300e1b4c",
"metadata": {},
"outputs": [],
"source": [
"show_doc(masked_mean, title_level=3)"
]
},
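{
"cell_type": "markdown",
"id": "b7e1d2a0",
"metadata": {},
"source": [
"As a quick illustration, the cell below is a minimal sketch with made-up values: the last two entries of a series are masked out, and the masked mean and median are computed from the valid entries only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8f2e3b1",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: masked statistics ignore the masked tail of the series.\n",
"x = torch.tensor([[1., 2., 3., 100., 100.]])\n",
"mask = torch.tensor([[1., 1., 1., 0., 0.]])  # last two entries are masked out\n",
"\n",
"print(masked_mean(x=x, mask=mask, dim=-1))    # mean of 1,2,3 -> tensor([[2.]])\n",
"print(masked_median(x=x, mask=mask, dim=-1))  # median of 1,2,3 -> tensor([[2.]])"
]
},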
{
"cell_type": "markdown",
"id": "a7a486a2",
"metadata": {},
"source": [
"# 2. Scalers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42c76dab",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def minmax_statistics(x, mask, eps=1e-6, dim=-1):\n",
" \"\"\" MinMax Scaler\n",
"\n",
" Standardizes temporal features by ensuring its range dweels between\n",
" [0,1] range. This transformation is often used as an alternative \n",
" to the standard scaler. The scaled features are obtained as:\n",
"\n",
" $$\n",
" \\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\mathrm{min}({\\mathbf{x}})_{[B,1,C]})/\n",
" (\\mathrm{max}({\\mathbf{x}})_{[B,1,C]}- \\mathrm{min}({\\mathbf{x}})_{[B,1,C]})\n",
" $$\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor input tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute min and max. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" mask = mask.clone()\n",
" mask[mask==0] = torch.inf\n",
" mask[mask==1] = 0\n",
" x_max = torch.max(torch.nan_to_num(x-mask,nan=-torch.inf), dim=dim, keepdim=True)[0]\n",
" x_min = torch.min(torch.nan_to_num(x+mask,nan=torch.inf), dim=dim, keepdim=True)[0]\n",
" x_max = x_max.type(x.dtype)\n",
" x_min = x_min.type(x.dtype)\n",
"\n",
" # x_range and prevent division by zero\n",
" x_range = x_max - x_min\n",
" x_range[x_range==0] = 1.0\n",
" x_range = x_range + eps\n",
" return x_min, x_range"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39fa429b",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def minmax_scaler(x, x_min, x_range):\n",
" return (x - x_min) / x_range\n",
"\n",
"def inv_minmax_scaler(z, x_min, x_range):\n",
" return z * x_range + x_min"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99ea1aa9",
"metadata": {},
"outputs": [],
"source": [
"show_doc(minmax_statistics, title_level=3)"
]
},
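{
"cell_type": "markdown",
"id": "d9a3f4c2",
"metadata": {},
"source": [
"The cell below is a small, illustrative round trip with synthetic values: compute the min-max statistics on a fully valid mask, scale to [0,1], and invert the scaling. The tolerance in the assertion is an arbitrary choice for the example."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0b4a5d3",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative round trip for the min-max scaler (synthetic values).\n",
"x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1)\n",
"mask = torch.ones_like(x)\n",
"\n",
"x_min, x_range = minmax_statistics(x=x, mask=mask, dim=1)\n",
"z = minmax_scaler(x, x_min, x_range)          # z dwells in the [0,1] range\n",
"x_rec = inv_minmax_scaler(z, x_min, x_range)  # recover the original scale\n",
"\n",
"assert z.min() >= 0 and z.max() <= 1\n",
"assert torch.allclose(x, x_rec, atol=1e-4)"
]
},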
{
"cell_type": "code",
"execution_count": null,
"id": "334b3d18",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def minmax1_statistics(x, mask, eps=1e-6, dim=-1):\n",
" \"\"\" MinMax1 Scaler\n",
"\n",
" Standardizes temporal features by ensuring its range dweels between\n",
" [-1,1] range. This transformation is often used as an alternative \n",
" to the standard scaler or classic Min Max Scaler. \n",
" The scaled features are obtained as:\n",
"\n",
" $$\\mathbf{z} = 2 (\\mathbf{x}_{[B,T,C]}-\\mathrm{min}({\\mathbf{x}})_{[B,1,C]})/ (\\mathrm{max}({\\mathbf{x}})_{[B,1,C]}- \\mathrm{min}({\\mathbf{x}})_{[B,1,C]})-1$$\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor input tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute min and max. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" # Mask values (set masked to -inf or +inf)\n",
" mask = mask.clone()\n",
" mask[mask==0] = torch.inf\n",
" mask[mask==1] = 0\n",
" x_max = torch.max(torch.nan_to_num(x-mask,nan=-torch.inf), dim=dim, keepdim=True)[0]\n",
" x_min = torch.min(torch.nan_to_num(x+mask,nan=torch.inf), dim=dim, keepdim=True)[0]\n",
" x_max = x_max.type(x.dtype)\n",
" x_min = x_min.type(x.dtype)\n",
" \n",
" # x_range and prevent division by zero\n",
" x_range = x_max - x_min\n",
" x_range[x_range==0] = 1.0\n",
" x_range = x_range + eps\n",
" return x_min, x_range"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a19ed5a8",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def minmax1_scaler(x, x_min, x_range):\n",
" x = (x - x_min) / x_range\n",
" z = x * (2) - 1\n",
" return z\n",
"\n",
"def inv_minmax1_scaler(z, x_min, x_range):\n",
" z = (z + 1) / 2\n",
" return z * x_range + x_min"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88ccb77b",
"metadata": {},
"outputs": [],
"source": [
"show_doc(minmax1_statistics, title_level=3)"
]
},
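{
"cell_type": "markdown",
"id": "f1c5b6e4",
"metadata": {},
"source": [
"Same sketch for the [-1,1] variant, again with synthetic values: the statistics are the min-max ones, and only the pointwise map changes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2d6c7f5",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative round trip for the [-1,1] min-max variant.\n",
"x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1)\n",
"mask = torch.ones_like(x)\n",
"\n",
"x_min, x_range = minmax1_statistics(x=x, mask=mask, dim=1)\n",
"z = minmax1_scaler(x, x_min, x_range)          # z dwells in the [-1,1] range\n",
"x_rec = inv_minmax1_scaler(z, x_min, x_range)\n",
"\n",
"assert z.min() >= -1 and z.max() <= 1\n",
"assert torch.allclose(x, x_rec, atol=1e-4)"
]
},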
{
"cell_type": "code",
"execution_count": null,
"id": "0c187a8f",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def std_statistics(x, mask, dim=-1, eps=1e-6):\n",
" \"\"\" Standard Scaler\n",
"\n",
" Standardizes features by removing the mean and scaling\n",
" to unit variance along the `dim` dimension. \n",
"\n",
" For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
"\n",
" $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\bar{\\mathbf{x}}_{[B,1,C]})/\\hat{\\sigma}_{[B,1,C]}$$\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute mean and std. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
" x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim))\n",
"\n",
" # Protect against division by zero\n",
" x_stds[x_stds==0] = 1.0\n",
" x_stds = x_stds + eps\n",
" return x_means, x_stds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17f90821",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def std_scaler(x, x_means, x_stds):\n",
" return (x - x_means) / x_stds\n",
"\n",
"def inv_std_scaler(z, x_mean, x_std):\n",
" return (z * x_std) + x_mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e077730c",
"metadata": {},
"outputs": [],
"source": [
"show_doc(std_statistics, title_level=3)"
]
},
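{
"cell_type": "markdown",
"id": "b3e7d8a6",
"metadata": {},
"source": [
"A short sketch of the masking behavior with the standard scaler: one entry is corrupted and masked, so the statistics come from the valid entries only, and the inverse transform still recovers the full tensor. Values and tolerances are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4f8e9b7",
"metadata": {},
"outputs": [],
"source": [
"# Masked entries do not affect the statistics of the standard scaler.\n",
"x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1)\n",
"mask = torch.ones_like(x)\n",
"x[0, -1, 0] = 1e6   # corrupt one entry ...\n",
"mask[0, -1, 0] = 0  # ... and mask it out\n",
"\n",
"x_means, x_stds = std_statistics(x=x, mask=mask, dim=1)\n",
"z = std_scaler(x, x_means, x_stds)\n",
"x_rec = inv_std_scaler(z, x_means, x_stds)\n",
"\n",
"assert torch.allclose(x_means, torch.tensor(5.0))  # mean of 0..10\n",
"assert torch.allclose(x, x_rec, atol=1e-2)"
]
},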
{
"cell_type": "code",
"execution_count": null,
"id": "2c22a041",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def robust_statistics(x, mask, dim=-1, eps=1e-6):\n",
" \"\"\" Robust Median Scaler\n",
"\n",
" Standardizes features by removing the median and scaling\n",
" with the mean absolute deviation (mad) a robust estimator of variance.\n",
" This scaler is particularly useful with noisy data where outliers can \n",
" heavily influence the sample mean / variance in a negative way.\n",
" In these scenarios the median and amd give better results.\n",
" \n",
" For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
"\n",
" $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\textrm{median}(\\mathbf{x})_{[B,1,C]})/\\\\textrm{mad}(\\mathbf{x})_{[B,1,C]}$$\n",
" \n",
" $$\\\\textrm{mad}(\\mathbf{x}) = \\\\frac{1}{N} \\sum_{}|\\mathbf{x} - \\mathrm{median}(x)|$$\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor input tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute median and mad. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" x_median = masked_median(x=x, mask=mask, dim=dim)\n",
" x_mad = masked_median(x=torch.abs(x-x_median), mask=mask, dim=dim)\n",
"\n",
" # Protect x_mad=0 values\n",
" # Assuming normality and relationship between mad and std\n",
" x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
" x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim)) \n",
" x_mad_aux = x_stds * 0.6744897501960817\n",
" x_mad = x_mad * (x_mad>0) + x_mad_aux * (x_mad==0)\n",
" \n",
" # Protect against division by zero\n",
" x_mad[x_mad==0] = 1.0\n",
" x_mad = x_mad + eps\n",
" return x_median, x_mad"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33f3cf28",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def robust_scaler(x, x_median, x_mad):\n",
" return (x - x_median) / x_mad\n",
"\n",
"def inv_robust_scaler(z, x_median, x_mad):\n",
" return z * x_mad + x_median"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7355a5f9",
"metadata": {},
"outputs": [],
"source": [
"show_doc(robust_statistics, title_level=3)"
]
},
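{
"cell_type": "markdown",
"id": "d5a9f0c8",
"metadata": {},
"source": [
"A small sketch of why the median / mad pair helps with outliers (synthetic values): a single unmasked extreme value drags the mean and inflates the standard deviation, while the median and mad barely move."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6b0a1d9",
"metadata": {},
"outputs": [],
"source": [
"# Synthetic series with one unmasked outlier: robust statistics barely move.\n",
"x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1)\n",
"x[0, -1, 0] = 1e3  # outlier left unmasked on purpose\n",
"mask = torch.ones_like(x)\n",
"\n",
"x_median, x_mad = robust_statistics(x=x, mask=mask, dim=1)\n",
"x_means, x_stds = std_statistics(x=x, mask=mask, dim=1)\n",
"print(f'median={x_median.item():.1f}, mad={x_mad.item():.1f}')  # close to the clean series\n",
"print(f'mean={x_means.item():.1f}, std={x_stds.item():.1f}')    # dragged by the outlier\n",
"\n",
"z = robust_scaler(x, x_median, x_mad)\n",
"x_rec = inv_robust_scaler(z, x_median, x_mad)\n",
"assert torch.allclose(x, x_rec, atol=1e-2)"
]
},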
{
"cell_type": "code",
"execution_count": null,
"id": "8879b00b",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def invariant_statistics(x, mask, dim=-1, eps=1e-6):\n",
" \"\"\" Invariant Median Scaler\n",
"\n",
" Standardizes features by removing the median and scaling\n",
" with the mean absolute deviation (mad) a robust estimator of variance.\n",
" Aditionally it complements the transformation with the arcsinh transformation.\n",
"\n",
" For example, for `base_windows` models, the scaled features are obtained as (with dim=1):\n",
"\n",
" $$\\mathbf{z} = (\\mathbf{x}_{[B,T,C]}-\\\\textrm{median}(\\mathbf{x})_{[B,1,C]})/\\\\textrm{mad}(\\mathbf{x})_{[B,1,C]}$$\n",
"\n",
" $$\\mathbf{z} = \\\\textrm{arcsinh}(\\mathbf{z})$$\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor input tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute median and mad. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" x_median = masked_median(x=x, mask=mask, dim=dim)\n",
" x_mad = masked_median(x=torch.abs(x-x_median), mask=mask, dim=dim)\n",
"\n",
" # Protect x_mad=0 values\n",
" # Assuming normality and relationship between mad and std\n",
" x_means = masked_mean(x=x, mask=mask, dim=dim)\n",
" x_stds = torch.sqrt(masked_mean(x=(x-x_means)**2, mask=mask, dim=dim)) \n",
" x_mad_aux = x_stds * 0.6744897501960817\n",
" x_mad = x_mad * (x_mad>0) + x_mad_aux * (x_mad==0)\n",
"\n",
" # Protect against division by zero\n",
" x_mad[x_mad==0] = 1.0\n",
" x_mad = x_mad + eps\n",
" return x_median, x_mad"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24cca2bf",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def invariant_scaler(x, x_median, x_mad):\n",
" return torch.arcsinh((x - x_median) / x_mad)\n",
"\n",
"def inv_invariant_scaler(z, x_median, x_mad):\n",
" return torch.sinh(z) * x_mad + x_median"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4b1b313",
"metadata": {},
"outputs": [],
"source": [
"show_doc(invariant_statistics, title_level=3)"
]
},
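{
"cell_type": "markdown",
"id": "f7c1b2e0",
"metadata": {},
"source": [
"The invariant scaler composes the robust standardization with arcsinh, which is close to the identity near zero and grows like a signed logarithm for large values, so outliers are compressed. A minimal round-trip sketch with synthetic values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d2c3f1",
"metadata": {},
"outputs": [],
"source": [
"# arcsinh compresses the outlier; sinh in the inverse undoes the compression.\n",
"x = torch.arange(12, dtype=torch.float32).reshape(1, 12, 1)\n",
"x[0, -1, 0] = 1e3\n",
"mask = torch.ones_like(x)\n",
"\n",
"x_median, x_mad = invariant_statistics(x=x, mask=mask, dim=1)\n",
"z = invariant_scaler(x, x_median, x_mad)\n",
"x_rec = inv_invariant_scaler(z, x_median, x_mad)\n",
"\n",
"print(z[0, :, 0])  # the outlier is strongly compressed\n",
"assert torch.allclose(x, x_rec, rtol=1e-3)"
]
},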
{
"cell_type": "code",
"execution_count": null,
"id": "50ba1916",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def identity_statistics(x, mask, dim=-1, eps=1e-6):\n",
" \"\"\" Identity Scaler\n",
"\n",
" A placeholder identity scaler, that is argument insensitive.\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor input tensor.
\n",
" `mask`: torch Tensor bool, same dimension as `x`, indicates where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `dim` (int, optional): Dimension over to compute median and mad. Defaults to -1.
\n",
"\n",
" **Returns:**
\n",
" `x`: original torch.Tensor `x`.\n",
" \"\"\"\n",
" # Collapse dim dimension\n",
" shape = list(x.shape)\n",
" shape[dim] = 1\n",
"\n",
" x_shift = torch.zeros(shape)\n",
" x_scale = torch.ones(shape)\n",
"\n",
" return x_shift, x_scale"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d7b313e",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def identity_scaler(x, x_shift, x_scale):\n",
" return x\n",
"\n",
"def inv_identity_scaler(z, x_shift, x_scale):\n",
" return z"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e56ae8f7",
"metadata": {},
"outputs": [],
"source": [
"show_doc(identity_statistics, title_level=3)"
]
},
{
"cell_type": "markdown",
"id": "e87e828c",
"metadata": {},
"source": [
"# 3. TemporalNorm Module"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb48423b",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class TemporalNorm(nn.Module):\n",
" \"\"\" Temporal Normalization\n",
"\n",
" Standardization of the features is a common requirement for many \n",
" machine learning estimators, and it is commonly achieved by removing \n",
" the level and scaling its variance. The `TemporalNorm` module applies \n",
" temporal normalization over the batch of inputs as defined by the type of scaler.\n",
"\n",
" $$\\mathbf{z}_{[B,T,C]} = \\\\textrm{Scaler}(\\mathbf{x}_{[B,T,C]})$$\n",
"\n",
" If `scaler_type` is `revin` learnable normalization parameters are added on top of\n",
" the usual normalization technique, the parameters are learned through scale decouple\n",
" global skip connections. The technique is available for point and probabilistic outputs.\n",
"\n",
" $$\\mathbf{\\hat{z}}_{[B,T,C]} = \\\\boldsymbol{\\hat{\\\\gamma}}_{[1,1,C]} \\mathbf{z}_{[B,T,C]} +\\\\boldsymbol{\\hat{\\\\beta}}_{[1,1,C]}$$\n",
"\n",
" **Parameters:**
\n",
" `scaler_type`: str, defines the type of scaler used by TemporalNorm. Available [`identity`, `standard`, `robust`, `minmax`, `minmax1`, `invariant`, `revin`].
\n",
" `dim` (int, optional): Dimension over to compute scale and shift. Defaults to -1.
\n",
" `eps` (float, optional): Small value to avoid division by zero. Defaults to 1e-6.
\n",
" `num_features`: int=None, for RevIN-like learnable affine parameters initialization.
\n",
"\n",
" **References**
\n",
" - [Kin G. Olivares, David Luo, Cristian Challu, Stefania La Vattiata, Max Mergenthaler, Artur Dubrawski (2023). \"HINT: Hierarchical Mixture Networks For Coherent Probabilistic Forecasting\". Neural Information Processing Systems, submitted. Working Paper version available at arxiv.](https://arxiv.org/abs/2305.07089)
\n",
" \"\"\"\n",
" def __init__(self, scaler_type='robust', dim=-1, eps=1e-6, num_features=None):\n",
" super().__init__()\n",
" compute_statistics = {None: identity_statistics,\n",
" 'identity': identity_statistics,\n",
" 'standard': std_statistics,\n",
" 'revin': std_statistics,\n",
" 'robust': robust_statistics,\n",
" 'minmax': minmax_statistics,\n",
" 'minmax1': minmax1_statistics,\n",
" 'invariant': invariant_statistics,}\n",
" scalers = {None: identity_scaler,\n",
" 'identity': identity_scaler,\n",
" 'standard': std_scaler,\n",
" 'revin': std_scaler,\n",
" 'robust': robust_scaler,\n",
" 'minmax': minmax_scaler,\n",
" 'minmax1': minmax1_scaler,\n",
" 'invariant':invariant_scaler,}\n",
" inverse_scalers = {None: inv_identity_scaler,\n",
" 'identity': inv_identity_scaler,\n",
" 'standard': inv_std_scaler,\n",
" 'revin': inv_std_scaler,\n",
" 'robust': inv_robust_scaler,\n",
" 'minmax': inv_minmax_scaler,\n",
" 'minmax1': inv_minmax1_scaler,\n",
" 'invariant': inv_invariant_scaler,}\n",
" assert (scaler_type in scalers.keys()), f'{scaler_type} not defined'\n",
" if (scaler_type=='revin') and (num_features is None):\n",
" raise Exception('You must pass num_features for ReVIN scaler.')\n",
"\n",
" self.compute_statistics = compute_statistics[scaler_type]\n",
" self.scaler = scalers[scaler_type]\n",
" self.inverse_scaler = inverse_scalers[scaler_type]\n",
" self.scaler_type = scaler_type\n",
" self.dim = dim\n",
" self.eps = eps\n",
"\n",
" if (scaler_type=='revin'):\n",
" self._init_params(num_features=num_features)\n",
"\n",
" def _init_params(self, num_features):\n",
" # Initialize RevIN scaler params to broadcast:\n",
" if self.dim==1: # [B,T,C] [1,1,C]\n",
" self.revin_bias = nn.Parameter(torch.zeros(1,1,num_features))\n",
" self.revin_weight = nn.Parameter(torch.ones(1,1,num_features))\n",
" elif self.dim==-1: # [B,C,T] [1,C,1]\n",
" self.revin_bias = nn.Parameter(torch.zeros(1,num_features,1))\n",
" self.revin_weight = nn.Parameter(torch.ones(1,num_features,1))\n",
"\n",
" #@torch.no_grad()\n",
" def transform(self, x, mask):\n",
" \"\"\" Center and scale the data.\n",
"\n",
" **Parameters:**
\n",
" `x`: torch.Tensor shape [batch, time, channels].
\n",
" `mask`: torch Tensor bool, shape [batch, time] where `x` is valid and False\n",
" where `x` should be masked. Mask should not be all False in any column of\n",
" dimension dim to avoid NaNs from zero division.
\n",
"\n",
" **Returns:**
\n",
" `z`: torch.Tensor same shape as `x`, except scaled.\n",
" \"\"\"\n",
" x_shift, x_scale = self.compute_statistics(x=x, mask=mask, dim=self.dim, eps=self.eps)\n",
" self.x_shift = x_shift\n",
" self.x_scale = x_scale\n",
"\n",
" # Original Revin performs this operation\n",
" # z = self.revin_weight * z\n",
" # z = z + self.revin_bias\n",
" # However this is only valid for point forecast not for\n",
" # distribution's scale decouple technique.\n",
" if self.scaler_type=='revin':\n",
" self.x_shift = self.x_shift + self.revin_bias\n",
" self.x_scale = self.x_scale * (torch.relu(self.revin_weight) + self.eps)\n",
"\n",
" z = self.scaler(x, x_shift, x_scale)\n",
" return z\n",
"\n",
" #@torch.no_grad()\n",
" def inverse_transform(self, z, x_shift=None, x_scale=None):\n",
" \"\"\" Scale back the data to the original representation.\n",
"\n",
" **Parameters:**
\n",
" `z`: torch.Tensor shape [batch, time, channels], scaled.
\n",
"\n",
" **Returns:**
\n",
" `x`: torch.Tensor original data.\n",
" \"\"\"\n",
"\n",
" if x_shift is None:\n",
" x_shift = self.x_shift\n",
" if x_scale is None:\n",
" x_scale = self.x_scale\n",
"\n",
" # Original Revin performs this operation\n",
" # z = z - self.revin_bias\n",
" # z = (z / (self.revin_weight + self.eps))\n",
" # However this is only valid for point forecast not for\n",
" # distribution's scale decouple technique.\n",
"\n",
" x = self.inverse_scaler(z, x_shift, x_scale)\n",
" return x\n",
"\n",
" def forward(self, x):\n",
" # The gradients are optained from BaseWindows/BaseRecurrent forwards.\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91d7a892",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TemporalNorm, name='TemporalNorm', title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3490b4a6",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TemporalNorm.transform, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df49d4f5",
"metadata": {},
"outputs": [],
"source": [
"show_doc(TemporalNorm.inverse_transform, title_level=3)"
]
},
{
"cell_type": "markdown",
"id": "3e2968e0",
"metadata": {},
"source": [
"# Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99722125",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7fef46f",
"metadata": {},
"outputs": [],
"source": [
"# Declare synthetic batch to normalize\n",
"x1 = 10**0 * np.arange(36)[:, None]\n",
"x2 = 10**1 * np.arange(36)[:, None]\n",
"\n",
"np_x = np.concatenate([x1, x2], axis=1)\n",
"np_x = np.repeat(np_x[None, :,:], repeats=2, axis=0)\n",
"np_x[0,:,:] = np_x[0,:,:] + 100\n",
"\n",
"np_mask = np.ones(np_x.shape)\n",
"np_mask[:, -12:, :] = 0\n",
"\n",
"print(f'x.shape [batch, time, features]={np_x.shape}')\n",
"print(f'mask.shape [batch, time, features]={np_mask.shape}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da1f93ae",
"metadata": {},
"outputs": [],
"source": [
"# Validate scalers\n",
"x = 1.0*torch.tensor(np_x)\n",
"mask = torch.tensor(np_mask)\n",
"scaler = TemporalNorm(scaler_type='standard', dim=1)\n",
"x_scaled = scaler.transform(x=x, mask=mask)\n",
"x_recovered = scaler.inverse_transform(x_scaled)\n",
"\n",
"plt.plot(x[0,:,0], label='x1', color='#78ACA8')\n",
"plt.plot(x[0,:,1], label='x2', color='#E3A39A')\n",
"plt.title('Before TemporalNorm')\n",
"plt.xlabel('Time')\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"plt.plot(x_scaled[0,:,0], label='x1', color='#78ACA8')\n",
"plt.plot(x_scaled[0,:,1]+0.1, label='x2+0.1', color='#E3A39A')\n",
"plt.title(f'TemporalNorm \\'{scaler.scaler_type}\\' ')\n",
"plt.xlabel('Time')\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"plt.plot(x_recovered[0,:,0], label='x1', color='#78ACA8')\n",
"plt.plot(x_recovered[0,:,1], label='x2', color='#E3A39A')\n",
"plt.title('Recovered')\n",
"plt.xlabel('Time')\n",
"plt.legend()\n",
"plt.show()"
]
},
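{
"cell_type": "markdown",
"id": "b9e3d4a2",
"metadata": {},
"source": [
"The same round trip works with the `revin` scaler; `num_features` must be passed so the module can initialize its learnable affine parameters. This is a minimal sketch reusing the synthetic batch above: at initialization the RevIN weight is ones and the bias zeros, so it behaves like the standard scaler until trained."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0f4e5b3",
"metadata": {},
"outputs": [],
"source": [
"# RevIN-style scaler: requires num_features for its learnable parameters.\n",
"scaler = TemporalNorm(scaler_type='revin', dim=1, num_features=np_x.shape[-1])\n",
"x_scaled = scaler.transform(x=x, mask=mask)\n",
"x_recovered = scaler.inverse_transform(x_scaled)\n",
"\n",
"print([p.shape for p in scaler.parameters()])  # [1,1,C] bias and weight\n",
"assert torch.allclose(x, x_recovered, atol=1e-3)"
]
},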
{
"cell_type": "code",
"execution_count": null,
"id": "9aa6920e",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# Validate scalers\n",
"for scaler_type in [None, 'identity', 'standard', 'robust', 'minmax', 'minmax1', 'invariant', 'revin']:\n",
" x = 1.0*torch.tensor(np_x)\n",
" mask = torch.tensor(np_mask)\n",
" scaler = TemporalNorm(scaler_type=scaler_type, dim=1, num_features=np_x.shape[-1])\n",
" x_scaled = scaler.transform(x=x, mask=mask)\n",
" x_recovered = scaler.inverse_transform(x_scaled)\n",
" assert torch.allclose(x, x_recovered, atol=1e-3), f'Recovered data is not the same as original with {scaler_type}'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17e3dbfc-2677-4d1f-85bc-de6343196045",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"import pandas as pd\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import NHITS\n",
"from neuralforecast.utils import AirPassengersDF as Y_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28e5f23d-9a64-4d77-8a27-55fcc765d0b7",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# Unit test for masked predict filtering\n",
"model = NHITS(h=12,\n",
" input_size=12*2,\n",
" max_steps=1,\n",
" windows_batch_size=None, \n",
" n_freq_downsample=[1,1,1],\n",
" scaler_type='minmax')\n",
"\n",
"nf = NeuralForecast(models=[model], freq='M')\n",
"nf.fit(df=Y_df)\n",
"Y_hat = nf.predict(df=Y_df)\n",
"assert pd.isnull(Y_hat).sum().sum() == 0, 'Predictions should not have NaNs'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "383f05b4-e921-4fa6-b2a1-65105b5eebd0",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import NHITS, RNN\n",
"from neuralforecast.losses.pytorch import DistributionLoss, HuberLoss, GMM, MAE\n",
"from neuralforecast.tsdataset import TimeSeriesDataset\n",
"from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb2095d2-74d4-4b94-bee3-c049aac8494d",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# Unit test for ReVIN, and its compatibility with distribution's scale decouple\n",
"Y_df = AirPassengersPanel\n",
"# del Y_df['trend']\n",
"\n",
"# Instantiate BaseWindow model and test revin dynamic dimensionality with hist_exog_list\n",
"model = NHITS(h=12,\n",
" input_size=24,\n",
" loss=GMM(n_components=10, level=[90]),\n",
" hist_exog_list=['y_[lag12]'],\n",
" max_steps=1,\n",
" early_stop_patience_steps=10,\n",
" val_check_steps=50,\n",
" scaler_type='revin',\n",
" learning_rate=1e-3)\n",
"nf = NeuralForecast(models=[model], freq='MS')\n",
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
"\n",
"# Instantiate BaseWindow model and test revin dynamic dimensionality with hist_exog_list\n",
"model = NHITS(h=12,\n",
" input_size=24,\n",
" loss=HuberLoss(),\n",
" hist_exog_list=['trend', 'y_[lag12]'],\n",
" max_steps=1,\n",
" early_stop_patience_steps=10,\n",
" val_check_steps=50,\n",
" scaler_type='revin',\n",
" learning_rate=1e-3)\n",
"nf = NeuralForecast(models=[model], freq='MS')\n",
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
"\n",
"# Instantiate BaseRecurrent model and test revin dynamic dimensionality with hist_exog_list\n",
"model = RNN(h=12,\n",
" input_size=24,\n",
" loss=GMM(n_components=10, level=[90]),\n",
" hist_exog_list=['trend', 'y_[lag12]'],\n",
" max_steps=1,\n",
" early_stop_patience_steps=10,\n",
" val_check_steps=50,\n",
" scaler_type='revin',\n",
" learning_rate=1e-3)\n",
"nf = NeuralForecast(models=[model], freq='MS')\n",
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)\n",
"\n",
"# Instantiate BaseRecurrent model and test revin dynamic dimensionality with hist_exog_list\n",
"model = RNN(h=12,\n",
" input_size=24,\n",
" loss=HuberLoss(),\n",
" hist_exog_list=['trend'],\n",
" max_steps=1,\n",
" early_stop_patience_steps=10,\n",
" val_check_steps=50,\n",
" scaler_type='revin',\n",
" learning_rate=1e-3)\n",
"nf = NeuralForecast(models=[model], freq='MS')\n",
"Y_hat_df = nf.cross_validation(df=Y_df, val_size=12, n_windows=1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}