import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.conv import ChebConv


class TemporalConvLayer(nn.Module):
    """Dilated temporal convolution followed by a ReLU.

    The convolution uses a ``(2, 1)`` kernel, stride 1 and no padding, so
    dimension 2 of the input shrinks by ``dia`` while dimension 3 is left
    untouched.

    Parameters
    ----------
    c_in : int
        The number of input channels (features).
    c_out : int
        The number of output channels (features).
    dia : int
        The dilation size.
    """

    def __init__(self, c_in, c_out, dia=1):
        super().__init__()
        self.c_out = c_out
        self.c_in = c_in
        self.conv = nn.Conv2d(
            c_in,
            c_out,
            kernel_size=(2, 1),
            stride=1,
            padding=(0, 0),
            dilation=dia,
        )

    def forward(self, x):
        """Apply the dilated convolution and a ReLU to ``x``."""
        return torch.relu(self.conv(x))


class SpatioConvLayer(nn.Module):
    """Graph convolution applied along the node dimension.

    Parameters
    ----------
    c : int
        Hidden dimension; the graph convolution maps ``c`` features to ``c``
        features.
    Lk : graph
        Graph handed unchanged to ``GraphConv`` on every forward pass.
    """

    def __init__(self, c, Lk):
        super().__init__()
        self.g = Lk
        self.gc = GraphConv(c, c, activation=F.relu)
        # Alternative left by the original author: self.gc = ChebConv(c, c, 3)

    def init(self):
        """Re-initialise the graph-convolution weight uniformly.

        Values are drawn from ``U(-stdv, stdv)`` with
        ``stdv = 1 / sqrt(weight.size(1))``.  The previous implementation
        read ``self.W``, an attribute this class never defines, and therefore
        always raised ``AttributeError``; the only learnable weight lives on
        ``self.gc``.
        """
        weight = self.gc.weight
        stdv = 1.0 / math.sqrt(weight.size(1))
        with torch.no_grad():
            weight.uniform_(-stdv, stdv)

    def forward(self, x):
        """Run the graph convolution on ``x`` and restore its axis order.

        The two transposes move dimension 3 to the front and dimension 1 to
        the back before the graph convolution, and are undone afterwards, so
        the output has the same axis order as the input.
        """
        # (d0, d1, d2, d3) -> (d3, d0, d2, d1)
        # assumes x is (batch, channel, time, node) - TODO confirm with caller
        nodes_first = x.transpose(0, 3).transpose(1, 3)
        out = self.gc(self.g, nodes_first)
        # Undo the permutation above.
        out = out.transpose(1, 3).transpose(0, 3)
        # GraphConv already applies ReLU; this second ReLU is idempotent and
        # kept only to mirror the original computation exactly.
        return torch.relu(out)


class FullyConvLayer(nn.Module):
    """1x1 convolution that reduces ``c`` channels to a single channel."""

    def __init__(self, c):
        super().__init__()
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)


class OutputLayer(nn.Module):
    """Output head: temporal conv, layer norm, 1x1 conv, channel reduction.

    Parameters
    ----------
    c : int
        Number of input channels.
    T : int
        Height of the first convolution kernel; an input whose dimension 2
        has length ``T`` is collapsed to length 1 along that dimension.
    n : int
        Size of dimension 3 of the input (normalised together with ``c``).
    """

    def __init__(self, c, T, n):
        super().__init__()
        self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, padding=(0, 0))
        self.fc = FullyConvLayer(c)

    def forward(self, x):
        """Map ``x`` to a single-channel output."""
        x_t1 = self.tconv1(x)
        # LayerNorm normalises the trailing (n, c) dims, so move channels
        # last for the norm and back to dimension 1 afterwards.
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)


class STGCN_WAVE(nn.Module):
    """Stack of temporal, spatial and norm layers driven by a control string.

    Parameters
    ----------
    c : sequence of int
        Channel sizes.  Every ``"T"`` layer consumes one step:
        ``c[k] -> c[k + 1]``.
    T : int
        Length of dimension 2 (time) of the model input.
    n : int
        Size of dimension 3 of the model input, used by the layer norms.
    Lk : graph
        Graph passed to every ``SpatioConvLayer``.
    p : float
        Unused; kept for backward compatibility.
    num_layers : int
        Unused; the number of layers is ``len(control_str)``.
    device : torch.device or str
        Device the layers and the output head are moved to.
    control_str : str
        One character per layer: ``"T"`` temporal convolution, ``"S"``
        spatial (graph) convolution, ``"N"`` layer norm.

    Raises
    ------
    ValueError
        If ``control_str`` contains a character other than ``T``, ``S``
        or ``N``.
    """

    def __init__(
        self, c, T, n, Lk, p, num_layers, device, control_str="TNTSTNTST"
    ):
        super().__init__()
        self.control_str = control_str  # model structure controller
        self.num_layers = len(control_str)
        self.layers = nn.ModuleList([])
        cnt = 0
        diapower = 0
        for position, kind in enumerate(control_str):
            if kind == "T":  # Temporal Layer
                self.layers.append(
                    TemporalConvLayer(c[cnt], c[cnt + 1], dia=2**diapower)
                )
                diapower += 1
                cnt += 1
            elif kind == "S":  # Spatio Layer
                self.layers.append(SpatioConvLayer(c[cnt], Lk))
            elif kind == "N":  # Norm Layer
                self.layers.append(nn.LayerNorm([n, c[cnt]]))
            else:
                # An unknown character used to be skipped silently, which
                # left self.layers shorter than control_str and broke
                # forward() later with a confusing error.
                raise ValueError(
                    f"Unknown layer code {kind!r} at position {position} in "
                    f"control_str {control_str!r}; expected 'T', 'S' or 'N'"
                )
        # Temporal layer k shrinks the time axis by 2**k, so after
        # `diapower` of them T - (2**diapower - 1) time steps remain.
        self.output = OutputLayer(c[cnt], T + 1 - 2**diapower, n)
        # Move the output head as well, so the whole model sits on `device`.
        for module in (*self.layers, self.output):
            module.to(device)

    def forward(self, x):
        """Run ``x`` through every configured layer and the output head."""
        for kind, layer in zip(self.control_str, self.layers):
            if kind == "N":
                # LayerNorm expects (..., n, c): channels last, then back.
                x = layer(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            else:
                x = layer(x)
        return self.output(x)