# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.stemgnn.ipynb.
# %% auto 0
__all__ = ['GLU', 'StockBlockLayer', 'StemGNN']
# %% ../../nbs/models.stemgnn.ipynb 6
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..losses.pytorch import MAE
from ..common._base_multivariate import BaseMultivariate
# %% ../../nbs/models.stemgnn.ipynb 7
class GLU(nn.Module):
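    """Gated Linear Unit: elementwise product of a linear projection with a
    sigmoid-gated linear projection, GLU(x) = linear_left(x) * sigmoid(linear_right(x))."""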
def __init__(self, input_channel, output_channel):
super(GLU, self).__init__()
self.linear_left = nn.Linear(input_channel, output_channel)
self.linear_right = nn.Linear(input_channel, output_channel)
def forward(self, x):
return torch.mul(self.linear_left(x), torch.sigmoid(self.linear_right(x)))
# %% ../../nbs/models.stemgnn.ipynb 8
class StockBlockLayer(nn.Module):
def __init__(self, time_step, unit, multi_layer, stack_cnt=0):
super(StockBlockLayer, self).__init__()
self.time_step = time_step
self.unit = unit
self.stack_cnt = stack_cnt
self.multi = multi_layer
self.weight = nn.Parameter(
torch.Tensor(
1, 3 + 1, 1, self.time_step * self.multi, self.multi * self.time_step
)
        )  # [1, K+1, 1, in_c, out_c] with K=3 (Chebyshev order)
nn.init.xavier_normal_(self.weight)
self.forecast = nn.Linear(
self.time_step * self.multi, self.time_step * self.multi
)
self.forecast_result = nn.Linear(self.time_step * self.multi, self.time_step)
if self.stack_cnt == 0:
self.backcast = nn.Linear(self.time_step * self.multi, self.time_step)
self.backcast_short_cut = nn.Linear(self.time_step, self.time_step)
self.relu = nn.ReLU()
self.GLUs = nn.ModuleList()
self.output_channel = 4 * self.multi
        # Three GLU stages; at each stage one GLU processes the real part and
        # one the imaginary part of the spectral representation.
        for i in range(3):
            if i == 0:
                in_features = self.time_step * 4  # K+1 = 4 spectral channels
            else:
                in_features = self.time_step * self.output_channel
            self.GLUs.append(GLU(in_features, self.time_step * self.output_channel))
            self.GLUs.append(GLU(in_features, self.time_step * self.output_channel))
def spe_seq_cell(self, input):
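        """Spectral sequential cell: FFT over the stacked Chebyshev/channel
        axis, three GLU stages on the real and imaginary parts, inverse FFT."""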
batch_size, k, input_channel, node_cnt, time_step = input.size()
input = input.view(batch_size, -1, node_cnt, time_step)
ffted = torch.view_as_real(torch.fft.fft(input, dim=1))
real = (
ffted[..., 0]
.permute(0, 2, 1, 3)
.contiguous()
.reshape(batch_size, node_cnt, -1)
)
img = (
ffted[..., 1]
.permute(0, 2, 1, 3)
.contiguous()
.reshape(batch_size, node_cnt, -1)
)
        for i in range(3):
            real = self.GLUs[2 * i](real)  # even-indexed GLUs: real part
            img = self.GLUs[2 * i + 1](img)  # odd-indexed GLUs: imaginary part
real = (
real.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous()
)
img = img.reshape(batch_size, node_cnt, 4, -1).permute(0, 2, 1, 3).contiguous()
time_step_as_inner = torch.cat([real.unsqueeze(-1), img.unsqueeze(-1)], dim=-1)
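        # inverse FFT back from the spectral domain (axis handling follows the
        # original StemGNN reference implementation)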
iffted = torch.fft.irfft(
torch.view_as_complex(time_step_as_inner),
n=time_step_as_inner.shape[1],
dim=1,
)
return iffted
def forward(self, x, mul_L):
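        """Spectral graph convolution: project the input with the multi-order
        Chebyshev Laplacian, run the spectral cell, and return a forecast branch
        plus (on the first stack only) a backcast residual."""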
mul_L = mul_L.unsqueeze(1)
x = x.unsqueeze(1)
gfted = torch.matmul(mul_L, x)
gconv_input = self.spe_seq_cell(gfted).unsqueeze(2)
igfted = torch.matmul(gconv_input, self.weight)
igfted = torch.sum(igfted, dim=1)
forecast_source = torch.sigmoid(self.forecast(igfted).squeeze(1))
forecast = self.forecast_result(forecast_source)
if self.stack_cnt == 0:
backcast_short = self.backcast_short_cut(x).squeeze(1)
backcast_source = torch.sigmoid(self.backcast(igfted) - backcast_short)
else:
backcast_source = None
return forecast, backcast_source
# %% ../../nbs/models.stemgnn.ipynb 9
class StemGNN(BaseMultivariate):
"""StemGNN
    The Spectral Temporal Graph Neural Network (`StemGNN`) is a graph-based multivariate
    time-series forecasting model. `StemGNN` jointly learns temporal dependencies and
    inter-series correlations in the spectral domain by combining the Graph Fourier
    Transform (GFT) with the Discrete Fourier Transform (DFT).
**Parameters:**
`h`: int, Forecast horizon.
    `input_size`: int, autoregressive input size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[3,4].
`n_series`: int, number of time-series.
`stat_exog_list`: str list, static exogenous columns.
`hist_exog_list`: str list, historic exogenous columns.
`futr_exog_list`: str list, future exogenous columns.
`n_stacks`: int=2, number of stacks in the model.
`multi_layer`: int=5, multiplier for FC hidden size on StemGNN blocks.
`dropout_rate`: float=0.5, dropout rate.
`leaky_rate`: float=0.2, alpha for LeakyReLU layer on Latent Correlation layer.
`loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
`max_steps`: int=1000, maximum number of training steps.
`learning_rate`: float=1e-3, Learning rate between (0, 1).
    `num_lr_decays`: int=3, Number of learning rate decays, evenly distributed across max_steps.
`early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
`val_check_steps`: int=100, Number of training steps between every validation loss check.
    `batch_size`: int=32, number of windows in each batch.
`step_size`: int=1, step size between each window of temporal data.
`scaler_type`: str='robust', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.
`drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
`alias`: str, optional, Custom name of the model.
    `optimizer`: Subclass of `torch.optim.Optimizer`, optional, user specified optimizer instead of the default choice (Adam).
    `optimizer_kwargs`: dict, optional, dictionary of parameters used by the user specified `optimizer`.
    `**trainer_kwargs`: keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
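    **References**
    - [Defu Cao et al. "Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting". NeurIPS 2020.](https://arxiv.org/abs/2103.07719)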
"""
# Class attributes
SAMPLING_TYPE = "multivariate"
def __init__(
self,
h,
input_size,
n_series,
futr_exog_list=None,
hist_exog_list=None,
stat_exog_list=None,
n_stacks=2,
multi_layer: int = 5,
dropout_rate: float = 0.5,
leaky_rate: float = 0.2,
loss=MAE(),
valid_loss=None,
max_steps: int = 1000,
learning_rate: float = 1e-3,
num_lr_decays: int = 3,
early_stop_patience_steps: int = -1,
val_check_steps: int = 100,
batch_size: int = 32,
step_size: int = 1,
scaler_type: str = "robust",
random_seed: int = 1,
num_workers_loader=0,
drop_last_loader=False,
optimizer=None,
optimizer_kwargs=None,
**trainer_kwargs
):
# Inherit BaseMultivariate class
super(StemGNN, self).__init__(
h=h,
input_size=input_size,
n_series=n_series,
futr_exog_list=futr_exog_list,
hist_exog_list=hist_exog_list,
stat_exog_list=stat_exog_list,
loss=loss,
valid_loss=valid_loss,
max_steps=max_steps,
learning_rate=learning_rate,
num_lr_decays=num_lr_decays,
early_stop_patience_steps=early_stop_patience_steps,
val_check_steps=val_check_steps,
batch_size=batch_size,
step_size=step_size,
scaler_type=scaler_type,
num_workers_loader=num_workers_loader,
drop_last_loader=drop_last_loader,
random_seed=random_seed,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
**trainer_kwargs
)
        # The forward pass hard-codes the sum of two stack outputs
        # (result[0] + result[1]), so other stack counts are rejected.
        if n_stacks != 2:
            raise ValueError("StemGNN currently only supports n_stacks=2.")
        # Exogenous variables (sizes are stored, but the forward pass currently
        # uses only the target series)
self.futr_input_size = len(self.futr_exog_list)
self.hist_input_size = len(self.hist_exog_list)
self.stat_input_size = len(self.stat_exog_list)
self.unit = n_series
self.stack_cnt = n_stacks
self.alpha = leaky_rate
self.time_step = input_size
self.horizon = h
self.h = h
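        # Self-attention parameters for the latent correlation layer, which
        # learns a graph over the n_series series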
self.weight_key = nn.Parameter(torch.zeros(size=(self.unit, 1)))
nn.init.xavier_uniform_(self.weight_key.data, gain=1.414)
self.weight_query = nn.Parameter(torch.zeros(size=(self.unit, 1)))
nn.init.xavier_uniform_(self.weight_query.data, gain=1.414)
self.GRU = nn.GRU(self.time_step, self.unit)
self.multi_layer = multi_layer
self.stock_block = nn.ModuleList()
self.stock_block.extend(
[
StockBlockLayer(
self.time_step, self.unit, self.multi_layer, stack_cnt=i
)
for i in range(self.stack_cnt)
]
)
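        # Output head: maps each series' input_size-length representation to
        # h * outputsize_multiplier forecast values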
self.fc = nn.Sequential(
nn.Linear(int(self.time_step), int(self.time_step)),
nn.LeakyReLU(),
nn.Linear(
int(self.time_step), self.horizon * self.loss.outputsize_multiplier
),
)
self.leakyrelu = nn.LeakyReLU(self.alpha)
self.dropout = nn.Dropout(p=dropout_rate)
def get_laplacian(self, graph, normalize):
"""
return the laplacian of the graph.
:param graph: the graph structure without self loop, [N, N].
:param normalize: whether to used the normalized laplacian.
:return: graph laplacian.
"""
if normalize:
D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
L = torch.eye(
graph.size(0), device=graph.device, dtype=graph.dtype
) - torch.mm(torch.mm(D, graph), D)
else:
D = torch.diag(torch.sum(graph, dim=-1))
L = D - graph
return L
def cheb_polynomial(self, laplacian):
"""
Compute the Chebyshev Polynomial, according to the graph laplacian.
:param laplacian: the graph laplacian, [N, N].
:return: the multi order Chebyshev laplacian, [K, N, N].
"""
N = laplacian.size(0) # [N, N]
laplacian = laplacian.unsqueeze(0)
        # Note: mathematically T_0(L) is the identity matrix, but the reference
        # StemGNN implementation initializes the zeroth order to zeros; this is
        # kept as-is for fidelity with the original code.
        first_laplacian = torch.zeros(
            [1, N, N], device=laplacian.device, dtype=torch.float
        )
        second_laplacian = laplacian
        third_laplacian = (
            2 * torch.matmul(laplacian, second_laplacian)
        ) - first_laplacian
        fourth_laplacian = (
            2 * torch.matmul(laplacian, third_laplacian) - second_laplacian
        )
        multi_order_laplacian = torch.cat(
            [first_laplacian, second_laplacian, third_laplacian, fourth_laplacian],
            dim=0,
        )
return multi_order_laplacian
def latent_correlation_layer(self, x):
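        """Learn a latent correlation graph between series from GRU encodings
        and self-attention, and return its multi-order Chebyshev Laplacian
        together with the attention (adjacency) matrix."""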
input, _ = self.GRU(x.permute(2, 0, 1).contiguous())
input = input.permute(1, 0, 2).contiguous()
attention = self.self_graph_attention(input)
attention = torch.mean(attention, dim=0)
degree = torch.sum(attention, dim=1)
        # symmetrize the attention so the resulting Laplacian is symmetric
        attention = 0.5 * (attention + attention.T)
degree_l = torch.diag(degree)
diagonal_degree_hat = torch.diag(1 / (torch.sqrt(degree) + 1e-7))
laplacian = torch.matmul(
diagonal_degree_hat, torch.matmul(degree_l - attention, diagonal_degree_hat)
)
mul_L = self.cheb_polynomial(laplacian)
return mul_L, attention
def self_graph_attention(self, input):
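        """Additive self-attention over the N series: score[i, j] =
        leaky_relu(key_i + query_j), softmax-normalized over j."""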
input = input.permute(0, 2, 1).contiguous()
bat, N, fea = input.size()
key = torch.matmul(input, self.weight_key)
query = torch.matmul(input, self.weight_query)
data = key.repeat(1, 1, N).view(bat, N * N, 1) + query.repeat(1, N, 1)
data = data.squeeze(2)
data = data.view(bat, N, -1)
data = self.leakyrelu(data)
attention = F.softmax(data, dim=2)
attention = self.dropout(attention)
return attention
    def graph_fft(self, input, eigenvectors):
        # Graph Fourier Transform helper (unused here; kept from the reference
        # implementation)
        return torch.matmul(eigenvectors, input)
    def forward(self, windows_batch):
        # Parse batch
        x = windows_batch["insample_y"]  # [batch_size, input_size, n_series]
        batch_size = x.shape[0]
        mul_L, attention = self.latent_correlation_layer(x)
        # reshape to [batch_size, 1, n_series, input_size] for the stock blocks
        X = x.unsqueeze(1).permute(0, 1, 3, 2).contiguous()
result = []
for stack_i in range(self.stack_cnt):
forecast, X = self.stock_block[stack_i](X, mul_L)
result.append(forecast)
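        # sum the forecasts from both stacks (relies on n_stacks == 2)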
forecast = result[0] + result[1]
forecast = self.fc(forecast)
forecast = forecast.permute(0, 2, 1).contiguous()
forecast = forecast.reshape(
batch_size, self.h, self.loss.outputsize_multiplier * self.n_series
)
forecast = self.loss.domain_map(forecast)
# domain_map might have squeezed the last dimension in case n_series == 1
# Note that this fails in case of a tuple loss, but Multivariate does not support tuple losses yet.
if forecast.ndim == 2:
return forecast.unsqueeze(-1)
else:
return forecast