{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp models.informer"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Informer"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The Informer model tackles the vanilla Transformer computational complexity challenges for long-horizon forecasting.\n",
"\n",
"The architecture has three distinctive features:\n",
"- A ProbSparse self-attention mechanism with $O(L \\log L)$ time and memory complexity.\n",
"- A self-attention distilling process that prioritizes dominant attention and efficiently handles long input sequences.\n",
"- A generative-style multi-step decoder that predicts long time-series sequences in a single forward operation rather than step-by-step.\n",
"\n",
"The Informer model utilizes a three-component approach to define its embedding:\n",
"- It employs encoded autoregressive features obtained from a convolution network.\n",
"- It uses window-relative positional embeddings derived from harmonic functions.\n",
"- Absolute positional embeddings obtained from calendar features are utilized."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"**References**<br>\n",
"- [Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, Wancai Zhang. \"Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting\"](https://arxiv.org/abs/2012.07436)<br>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"import math\n",
"import numpy as np\n",
"from typing import Optional\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"from neuralforecast.common._modules import (\n",
" TransEncoderLayer, TransEncoder,\n",
" TransDecoderLayer, TransDecoder,\n",
" DataEmbedding, AttentionLayer,\n",
")\n",
"from neuralforecast.common._base_windows import BaseWindows\n",
"\n",
"from neuralforecast.losses.pytorch import MAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from fastcore.test import test_eq\n",
"from nbdev.showdoc import show_doc"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Auxiliary Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class ConvLayer(nn.Module):\n",
" def __init__(self, c_in):\n",
" super(ConvLayer, self).__init__()\n",
" self.downConv = nn.Conv1d(in_channels=c_in,\n",
" out_channels=c_in,\n",
" kernel_size=3,\n",
" padding=2,\n",
" padding_mode='circular')\n",
" self.norm = nn.BatchNorm1d(c_in)\n",
" self.activation = nn.ELU()\n",
" self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)\n",
"\n",
" def forward(self, x):\n",
" x = self.downConv(x.permute(0, 2, 1))\n",
" x = self.norm(x)\n",
" x = self.activation(x)\n",
" x = self.maxPool(x)\n",
" x = x.transpose(1, 2)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class ProbMask():\n",
" def __init__(self, B, H, L, index, scores, device=\"cpu\"):\n",
" _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool, device=device).triu(1)\n",
" _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])\n",
" indicator = _mask_ex[torch.arange(B)[:, None, None],\n",
" torch.arange(H)[None, :, None],\n",
" index, :].to(device)\n",
" self._mask = indicator.view(scores.shape).to(device)\n",
"\n",
" @property\n",
" def mask(self):\n",
" return self._mask\n",
"\n",
"\n",
"class ProbAttention(nn.Module):\n",
" def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n",
" super(ProbAttention, self).__init__()\n",
" self.factor = factor\n",
" self.scale = scale\n",
" self.mask_flag = mask_flag\n",
" self.output_attention = output_attention\n",
" self.dropout = nn.Dropout(attention_dropout)\n",
"\n",
" def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)\n",
" # Q [B, H, L, D]\n",
" B, H, L_K, E = K.shape\n",
" _, _, L_Q, _ = Q.shape\n",
"\n",
" # calculate the sampled Q_K\n",
" K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)\n",
"\n",
" index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q\n",
" K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]\n",
" Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()\n",
"\n",
" # find the Top_k query with sparisty measurement\n",
" M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)\n",
" M_top = M.topk(n_top, sorted=False)[1]\n",
"\n",
" # use the reduced Q to calculate Q_K\n",
" Q_reduce = Q[torch.arange(B)[:, None, None],\n",
" torch.arange(H)[None, :, None],\n",
" M_top, :] # factor*ln(L_q)\n",
" Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k\n",
"\n",
" return Q_K, M_top\n",
"\n",
" def _get_initial_context(self, V, L_Q):\n",
" B, H, L_V, D = V.shape\n",
" if not self.mask_flag:\n",
" # V_sum = V.sum(dim=-2)\n",
" V_sum = V.mean(dim=-2)\n",
" contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()\n",
" else: # use mask\n",
" assert (L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only\n",
" contex = V.cumsum(dim=-2)\n",
" return contex\n",
"\n",
" def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):\n",
" B, H, L_V, D = V.shape\n",
"\n",
" if self.mask_flag:\n",
" attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)\n",
" scores.masked_fill_(attn_mask.mask, -np.inf)\n",
"\n",
" attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)\n",
"\n",
" context_in[torch.arange(B)[:, None, None],\n",
" torch.arange(H)[None, :, None],\n",
" index, :] = torch.matmul(attn, V).type_as(context_in)\n",
" if self.output_attention:\n",
" attns = (torch.ones([B, H, L_V, L_V], device=attn.device) / L_V).type_as(attn)\n",
" attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn\n",
" return (context_in, attns)\n",
" else:\n",
" return (context_in, None)\n",
"\n",
" def forward(self, queries, keys, values, attn_mask):\n",
" B, L_Q, H, D = queries.shape\n",
" _, L_K, _, _ = keys.shape\n",
"\n",
" queries = queries.transpose(2, 1)\n",
" keys = keys.transpose(2, 1)\n",
" values = values.transpose(2, 1)\n",
"\n",
" U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)\n",
" u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q)\n",
"\n",
" U_part = U_part if U_part < L_K else L_K\n",
" u = u if u < L_Q else L_Q\n",
"\n",
" scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)\n",
"\n",
" # add scale factor\n",
" scale = self.scale or 1. / math.sqrt(D)\n",
" if scale is not None:\n",
" scores_top = scores_top * scale\n",
" # get the context\n",
" context = self._get_initial_context(values, L_Q)\n",
" # update the context with selected top_k queries\n",
" context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)\n",
"\n",
" return context.contiguous(), attn"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Informer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class Informer(BaseWindows):\n",
"    \"\"\" Informer\n",
"\n",
"    The Informer model tackles the vanilla Transformer computational complexity challenges for long-horizon forecasting.\n",
"    The architecture has three distinctive features:\n",
"        1) A ProbSparse self-attention mechanism with O(L log L) time and memory complexity.\n",
"        2) A self-attention distilling process that prioritizes dominant attention and efficiently handles long input sequences.\n",
"        3) A generative-style decoder that predicts long time-series sequences in a single forward operation rather than step-by-step.\n",
"\n",
"    The Informer model utilizes a three-component approach to define its embedding:\n",
"        1) It employs encoded autoregressive features obtained from a convolution network.\n",
"        2) It uses window-relative positional embeddings derived from harmonic functions.\n",
"        3) Absolute positional embeddings obtained from calendar features (`futr_exog_list`) are utilized.\n",
"\n",
"    *Parameters:*<br>\n",
"    `h`: int, forecast horizon.<br>\n",
"    `input_size`: int, maximum sequence length for truncated train backpropagation (length of the encoder input window).<br>\n",
"    `futr_exog_list`: str list, future exogenous columns (e.g. calendar features) used by the temporal embeddings.<br>\n",
"    `hist_exog_list`: str list, historic exogenous columns; not supported yet, must be empty.<br>\n",
"    `stat_exog_list`: str list, static exogenous columns; not supported yet, must be empty.<br>\n",
"    `exclude_insample_y`: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.<br>\n",
"    `decoder_input_size_multiplier`: float=0.5, fraction of `input_size` fed to the decoder as start tokens; `label_len = ceil(input_size * decoder_input_size_multiplier)` must satisfy 0 < label_len < input_size.<br>\n",
"    `hidden_size`: int=128, units of embeddings and encoders.<br>\n",
"    `dropout`: float=0.05, dropout throughout Informer architecture, in (0, 1).<br>\n",
"    `factor`: int=3, ProbSparse attention factor.<br>\n",
"    `n_head`: int=4, controls number of multi-head's attention.<br>\n",
"    `conv_hidden_size`: int=32, channels of the convolutional sub-layers of each encoder and decoder layer.<br>\n",
"    `activation`: str='gelu', activation of the encoder and decoder layers, one of ['relu', 'gelu'].<br>\n",
"    `encoder_layers`: int=2, number of layers for the Transformer encoder.<br>\n",
"    `decoder_layers`: int=1, number of layers for the Transformer decoder.<br>\n",
"    `distil`: bool=True, whether the encoder applies convolutional distilling layers between its attention layers.<br>\n",
"    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
"    `valid_loss`: PyTorch module, optional, instantiated validation loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>\n",
"    `max_steps`: int=5000, maximum number of training steps.<br>\n",
"    `learning_rate`: float=1e-4, Learning rate between (0, 1).<br>\n",
"    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>\n",
"    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>\n",
"    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>\n",
"    `batch_size`: int=32, number of different series in each batch.<br>\n",
"    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.<br>\n",
"    `windows_batch_size`: int=1024, number of windows to sample in each training batch.<br>\n",
"    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.<br>\n",
"    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.<br>\n",
"    `step_size`: int=1, step size between each window of temporal data.<br>\n",
"    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>\n",
"    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>\n",
"    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>\n",
"    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>\n",
"    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>\n",
"    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>\n",
"    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>\n",
"\n",
"    *References*<br>\n",
"    - [Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, Wancai Zhang. \"Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting\"](https://arxiv.org/abs/2012.07436)<br>\n",
"    \"\"\"\n",
"    # Class attributes\n",
"    SAMPLING_TYPE = 'windows'\n",
"\n",
"    def __init__(self,\n",
"                 h: int,\n",
"                 input_size: int,\n",
"                 stat_exog_list = None,\n",
"                 hist_exog_list = None,\n",
"                 futr_exog_list = None,\n",
"                 exclude_insample_y = False,\n",
"                 decoder_input_size_multiplier: float = 0.5,\n",
"                 hidden_size: int = 128,\n",
"                 dropout: float = 0.05,\n",
"                 factor: int = 3,\n",
"                 n_head: int = 4,\n",
"                 conv_hidden_size: int = 32,\n",
"                 activation: str = 'gelu',\n",
"                 encoder_layers: int = 2,\n",
"                 decoder_layers: int = 1,\n",
"                 distil: bool = True,\n",
"                 loss = MAE(),\n",
"                 valid_loss = None,\n",
"                 max_steps: int = 5000,\n",
"                 learning_rate: float = 1e-4,\n",
"                 num_lr_decays: int = -1,\n",
"                 early_stop_patience_steps: int =-1,\n",
"                 val_check_steps: int = 100,\n",
"                 batch_size: int = 32,\n",
"                 valid_batch_size: Optional[int] = None,\n",
"                 windows_batch_size = 1024,\n",
"                 inference_windows_batch_size = 1024,\n",
"                 start_padding_enabled = False,\n",
"                 step_size: int = 1,\n",
"                 scaler_type: str = 'identity',\n",
"                 random_seed: int = 1,\n",
"                 num_workers_loader: int = 0,\n",
"                 drop_last_loader: bool = False,\n",
"                 optimizer=None,\n",
"                 optimizer_kwargs=None,\n",
"                 **trainer_kwargs):\n",
"        super(Informer, self).__init__(h=h,\n",
"                                       input_size=input_size,\n",
"                                       hist_exog_list=hist_exog_list,\n",
"                                       stat_exog_list=stat_exog_list,\n",
"                                       futr_exog_list = futr_exog_list,\n",
"                                       exclude_insample_y = exclude_insample_y,\n",
"                                       loss=loss,\n",
"                                       valid_loss=valid_loss,\n",
"                                       max_steps=max_steps,\n",
"                                       learning_rate=learning_rate,\n",
"                                       num_lr_decays=num_lr_decays,\n",
"                                       early_stop_patience_steps=early_stop_patience_steps,\n",
"                                       val_check_steps=val_check_steps,\n",
"                                       batch_size=batch_size,\n",
"                                       valid_batch_size=valid_batch_size,\n",
"                                       windows_batch_size=windows_batch_size,\n",
"                                       inference_windows_batch_size = inference_windows_batch_size,\n",
"                                       start_padding_enabled=start_padding_enabled,\n",
"                                       step_size=step_size,\n",
"                                       scaler_type=scaler_type,\n",
"                                       num_workers_loader=num_workers_loader,\n",
"                                       drop_last_loader=drop_last_loader,\n",
"                                       random_seed=random_seed,\n",
"                                       optimizer=optimizer,\n",
"                                       optimizer_kwargs=optimizer_kwargs,\n",
"                                       **trainer_kwargs)\n",
"\n",
"        # Architecture\n",
"        self.futr_input_size = len(self.futr_exog_list)\n",
"        self.hist_input_size = len(self.hist_exog_list)\n",
"        self.stat_input_size = len(self.stat_exog_list)\n",
"\n",
"        if self.stat_input_size > 0:\n",
"            raise Exception('Informer does not support static variables yet')\n",
"\n",
"        if self.hist_input_size > 0:\n",
"            raise Exception('Informer does not support historical variables yet')\n",
"\n",
"        # The decoder is fed the last `label_len` observed values as start tokens.\n",
"        self.label_len = int(np.ceil(input_size * decoder_input_size_multiplier))\n",
"        if (self.label_len >= input_size) or (self.label_len <= 0):\n",
"            raise Exception(f'Check decoder_input_size_multiplier={decoder_input_size_multiplier}, range (0,1)')\n",
"\n",
"        if activation not in ['relu', 'gelu']:\n",
"            raise Exception(f'Check activation={activation}')\n",
"\n",
"        self.c_out = self.loss.outputsize_multiplier\n",
"        self.output_attention = False\n",
"        self.enc_in = 1\n",
"        self.dec_in = 1\n",
"\n",
"        # Embedding\n",
"        # forward() passes the future exogenous (calendar) features as `x_mark`,\n",
"        # so the exogenous part of both embeddings is sized by futr_input_size.\n",
"        self.enc_embedding = DataEmbedding(c_in=self.enc_in,\n",
"                                           exog_input_size=self.futr_input_size,\n",
"                                           hidden_size=hidden_size,\n",
"                                           pos_embedding=True,\n",
"                                           dropout=dropout)\n",
"        self.dec_embedding = DataEmbedding(self.dec_in,\n",
"                                           exog_input_size=self.futr_input_size,\n",
"                                           hidden_size=hidden_size,\n",
"                                           pos_embedding=True,\n",
"                                           dropout=dropout)\n",
"\n",
"        # Encoder: ProbSparse self-attention layers, optionally interleaved with\n",
"        # ConvLayer distilling layers (one fewer than the attention layers).\n",
"        self.encoder = TransEncoder(\n",
"            [\n",
"                TransEncoderLayer(\n",
"                    AttentionLayer(\n",
"                        ProbAttention(False, factor,\n",
"                                      attention_dropout=dropout,\n",
"                                      output_attention=self.output_attention),\n",
"                        hidden_size, n_head),\n",
"                    hidden_size,\n",
"                    conv_hidden_size,\n",
"                    dropout=dropout,\n",
"                    activation=activation\n",
"                ) for _ in range(encoder_layers)\n",
"            ],\n",
"            [\n",
"                ConvLayer(\n",
"                    hidden_size\n",
"                ) for _ in range(encoder_layers - 1)\n",
"            ] if distil else None,\n",
"            norm_layer=torch.nn.LayerNorm(hidden_size)\n",
"        )\n",
"        # Decoder: causal ProbSparse self-attention + ProbSparse cross-attention.\n",
"        self.decoder = TransDecoder(\n",
"            [\n",
"                TransDecoderLayer(\n",
"                    AttentionLayer(\n",
"                        ProbAttention(True, factor, attention_dropout=dropout, output_attention=False),\n",
"                        hidden_size, n_head),\n",
"                    AttentionLayer(\n",
"                        ProbAttention(False, factor, attention_dropout=dropout, output_attention=False),\n",
"                        hidden_size, n_head),\n",
"                    hidden_size,\n",
"                    conv_hidden_size,\n",
"                    dropout=dropout,\n",
"                    activation=activation,\n",
"                )\n",
"                for _ in range(decoder_layers)\n",
"            ],\n",
"            norm_layer=torch.nn.LayerNorm(hidden_size),\n",
"            projection=nn.Linear(hidden_size, self.c_out, bias=True)\n",
"        )\n",
"\n",
"    def forward(self, windows_batch):\n",
"        \"\"\"Forecast `h` steps for a batch of windows.\n",
"\n",
"        Reads `insample_y` and `futr_exog` from `windows_batch`; historic and\n",
"        static exogenous inputs are not used (rejected in `__init__`).\n",
"        Returns `self.loss.domain_map` applied to the last `h` decoder outputs.\n",
"        \"\"\"\n",
"        # Parse windows_batch\n",
"        insample_y = windows_batch['insample_y']\n",
"        futr_exog = windows_batch['futr_exog']\n",
"\n",
"        insample_y = insample_y.unsqueeze(-1)  # [Ws,L,1]\n",
"\n",
"        if self.futr_input_size > 0:\n",
"            # Assumes futr_exog spans the input window followed by the horizon:\n",
"            # the encoder gets the first `input_size` steps, the decoder the\n",
"            # last `label_len + h` steps (matching x_dec below).\n",
"            x_mark_enc = futr_exog[:, :self.input_size, :]\n",
"            x_mark_dec = futr_exog[:, -(self.label_len + self.h):, :]\n",
"        else:\n",
"            x_mark_enc = None\n",
"            x_mark_dec = None\n",
"\n",
"        # Decoder input: last `label_len` observations followed by `h` zero placeholders.\n",
"        x_dec = insample_y.new_zeros((len(insample_y), self.h, 1))\n",
"        x_dec = torch.cat([insample_y[:, -self.label_len:, :], x_dec], dim=1)\n",
"\n",
"        enc_out = self.enc_embedding(insample_y, x_mark_enc)\n",
"        enc_out, _ = self.encoder(enc_out, attn_mask=None)  # attns visualization\n",
"\n",
"        dec_out = self.dec_embedding(x_dec, x_mark_dec)\n",
"        dec_out = self.decoder(dec_out, enc_out, x_mask=None,\n",
"                               cross_mask=None)\n",
"\n",
"        forecast = self.loss.domain_map(dec_out[:, -self.h:])\n",
"        return forecast"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(Informer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(Informer.fit, name='Informer.fit')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"show_doc(Informer.predict, name='Informer.predict')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pytorch_lightning as pl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from neuralforecast import NeuralForecast\n",
"from neuralforecast.models import MLP\n",
"from neuralforecast.losses.pytorch import MQLoss, DistributionLoss\n",
"from neuralforecast.tsdataset import TimeSeriesDataset\n",
"from neuralforecast.utils import AirPassengers, AirPassengersPanel, AirPassengersStatic, augment_calendar_df\n",
"\n",
"AirPassengersPanel, calendar_cols = augment_calendar_df(df=AirPassengersPanel, freq='M')\n",
"\n",
"# Hold out the last 12 months of every series as the test set.\n",
"test_start = AirPassengersPanel['ds'].values[-12]\n",
"Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<test_start].reset_index(drop=True) # 132 train\n",
"Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=test_start].reset_index(drop=True) # 12 test\n",
"\n",
"model = Informer(h=12,\n",
"                 input_size=24,\n",
"                 hidden_size=16,\n",
"                 conv_hidden_size=32,\n",
"                 n_head=2,\n",
"                 #loss=DistributionLoss(distribution='StudentT', level=[80, 90]),\n",
"                 loss=MAE(),\n",
"                 futr_exog_list=calendar_cols,\n",
"                 scaler_type='robust',\n",
"                 learning_rate=1e-3,\n",
"                 max_steps=5,\n",
"                 val_check_steps=50,\n",
"                 early_stop_patience_steps=2)\n",
"\n",
"nf = NeuralForecast(\n",
"    models=[model],\n",
"    freq='M'\n",
")\n",
"nf.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
"forecasts = nf.predict(futr_df=Y_test_df)\n",
"\n",
"Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
"plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
"plot_df = pd.concat([Y_train_df, plot_df])\n",
"\n",
"airline1_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
"plt.plot(airline1_df['ds'], airline1_df['y'], c='black', label='True')\n",
"if model.loss.is_distribution_output:\n",
"    plt.plot(airline1_df['ds'], airline1_df['Informer-median'], c='blue', label='median')\n",
"    plt.fill_between(x=airline1_df['ds'][-12:],\n",
"                     y1=airline1_df['Informer-lo-90'][-12:].values,\n",
"                     y2=airline1_df['Informer-hi-90'][-12:].values,\n",
"                     alpha=0.4, label='level 90')\n",
"else:\n",
"    plt.plot(airline1_df['ds'], airline1_df['Informer'], c='blue', label='Forecast')\n",
"plt.legend()\n",
"plt.grid()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}