Commit 3c387988, authored by Chen Sirui, committed by GitHub

[Example] DCRNN and GaAN (#2858)



* Ready for PR

* Refactor code

Co-authored-by: Ubuntu <ubuntu@ip-172-31-45-47.ap-northeast-1.compute.internal>
Co-authored-by: Tianjun Xiao <xiaotj1990327@gmail.com>
parent 3075b277
...@@ -91,7 +91,8 @@ The folder contains example implementations of selected research papers related
| [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |
| [Interaction Networks for Learning about Objects, Relations and Physics](#graphsim) | | | :heavy_check_mark: | | |
| [Representation Learning on Graphs with Jumping Knowledge Networks](#jknet) | :heavy_check_mark: | | | | |
| [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](#dcrnn) | | | :heavy_check_mark: | | |
| [GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](#gaan) | | | :heavy_check_mark: | | |
## 2021
- <a name="bgnn"></a> Ivanov et al. Boost then Convolve: Gradient Boosting Meets Graph Neural Networks. [Paper link](https://openreview.net/forum?id=ebS5NUfoMKL).
...@@ -268,6 +269,9 @@ The folder contains example implementations of selected research papers related
  - Example code: [pytorch](../examples/pytorch/jknet)
  - Tags: message passing, neighborhood
- <a name="gaan"></a> Zhang et al. GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs. [Paper link](https://arxiv.org/abs/1803.07294).
  - Example code: [pytorch](../examples/pytorch/dtgrnn)
  - Tags: static discrete temporal graph, traffic forecasting
## 2017
...@@ -323,6 +327,10 @@ The folder contains example implementations of selected research papers related
  - Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy)
  - Tags: molecules, quantum chemistry
- <a name="dcrnn"></a> Li et al. Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting. [Paper link](https://arxiv.org/abs/1707.01926).
  - Example code: [PyTorch](../examples/pytorch/dtgrnn)
  - Tags: static discrete temporal graph, traffic forecasting
## 2016
- <a name="ggnn"></a> Li et al. Gated Graph Sequence Neural Networks. [Paper link](https://arxiv.org/abs/1511.05493).
...
# Discrete Temporal Dynamic Graph with recurrent structure
## DGL Implementation of the DCRNN and GaAN Papers
This DGL example implements the GNN models proposed in the papers [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) and [GaAN: Gated Attention Networks for Learning on Large and Spatiotemporal Graphs](https://arxiv.org/pdf/1803.07294).
Model implementor
----------------------
This example was implemented by [Ericcsr](https://github.com/Ericcsr) during his internship at the AWS Shanghai AI Lab.
The graph datasets used in this example
---------------------------------------
METR-LA dataset. Dataset summary:
- NumNodes: 207
- NumEdges: 1722
- NumFeats: 2
- TrainingSamples: 70%
- ValidationSamples: 20%
- TestSamples: 10%
PEMS-BAY dataset. Dataset Summary:
- NumNodes: 325
- NumEdges: 2694
- NumFeats: 2
- TrainingSamples: 70%
- ValidationSamples: 20%
- TestSamples: 10%
How to run example files
--------------------------------
In the dtgrnn folder, run the commands below.

**Please use `train.py`**
Train the DCRNN model on the METR-LA dataset:
```bash
python train.py --dataset LA --model dcrnn
```
If you want to use a GPU, run:
```bash
python train.py --gpu 0 --dataset LA --model dcrnn
```
If you want to use the PEMS-BAY dataset:
```bash
python train.py --gpu 0 --dataset BAY --model dcrnn
```
Train the GaAN model:
```bash
python train.py --gpu 0 --model gaan --dataset <LA/BAY>
```
Performance on METR-LA
-------------------------
| Models/Datasets | Test MAE |
| :-------------- | --------:|
| DCRNN in DGL | 2.91 |
| DCRNN paper | 3.17 |
| GaAN in DGL | 3.20 |
| GaAN paper | 3.16 |
Note that any graph convolution module can be plugged into the recurrent discrete temporal dynamic graph template to test its performance; simply replace DiffConv or GatedGAT.
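To illustrate that plug-in interface, here is a minimal sketch of a hypothetical `MeanConv` (not part of this example): any module constructed as `net(in_feats, out_feats)` whose `forward(graph, x)` returns node features could stand in for DiffConv or GatedGAT. A dense numpy adjacency stands in for a DGL graph.

```python
import numpy as np

# Hypothetical stand-in for a pluggable graph convolution. The recurrent
# template only requires the constructor signature net(in_feats, out_feats)
# and a forward(graph, x) returning per-node features.
class MeanConv:
    def __init__(self, in_feats, out_feats, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((in_feats, out_feats)) * 0.1

    def forward(self, adj, x):
        # Row-normalize the adjacency so each node averages its neighbors,
        # then apply a linear projection.
        deg = adj.sum(axis=1, keepdims=True).clip(min=1)
        return (adj / deg) @ x @ self.weight

adj = np.array([[0., 1., 1.],
                [1., 0., 0.],
                [1., 0., 0.]])
conv = MeanConv(in_feats=2, out_feats=4)
out = conv.forward(adj, np.ones((3, 2)))
print(out.shape)  # (3, 4)
```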
import os
import ssl
from six.moves import urllib
import torch
import numpy as np
import dgl
from torch.utils.data import Dataset, DataLoader
def download_file(dataset):
print("Start Downloading data: {}".format(dataset))
url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
with open("./data/{}".format(dataset), "wb") as handle:
handle.write(data.read())
class SnapShotDataset(Dataset):
def __init__(self, path, npz_file):
if not os.path.exists(path+'/'+npz_file):
if not os.path.exists(path):
os.mkdir(path)
download_file(npz_file)
zipfile = np.load(path+'/'+npz_file)
self.x = zipfile['x']
self.y = zipfile['y']
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
return self.x[idx, ...], self.y[idx, ...]
def METR_LAGraphDataset():
if not os.path.exists('data/graph_la.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_la.bin')
g, _ = dgl.load_graphs('data/graph_la.bin')
return g[0]
class METR_LATrainDataset(SnapShotDataset):
def __init__(self):
super(METR_LATrainDataset, self).__init__('data', 'metr_la_train.npz')
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()
class METR_LATestDataset(SnapShotDataset):
def __init__(self):
super(METR_LATestDataset, self).__init__('data', 'metr_la_test.npz')
class METR_LAValidDataset(SnapShotDataset):
def __init__(self):
super(METR_LAValidDataset, self).__init__('data', 'metr_la_valid.npz')
def PEMS_BAYGraphDataset():
if not os.path.exists('data/graph_bay.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_bay.bin')
g, _ = dgl.load_graphs('data/graph_bay.bin')
return g[0]
class PEMS_BAYTrainDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTrainDataset, self).__init__(
'data', 'pems_bay_train.npz')
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()
class PEMS_BAYTestDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTestDataset, self).__init__('data', 'pems_bay_test.npz')
class PEMS_BAYValidDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYValidDataset, self).__init__(
'data', 'pems_bay_valid.npz')
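The `SnapShotDataset` class above is essentially index-based access into paired `'x'`/`'y'` arrays of an `.npz` archive. A self-contained sketch of that access pattern, using an in-memory archive instead of the S3 download:

```python
import io
import numpy as np

# Build a small .npz archive in memory (stands in for metr_la_train.npz etc.).
buf = io.BytesIO()
np.savez(buf, x=np.arange(24).reshape(4, 3, 2), y=np.arange(4))
buf.seek(0)
archive = np.load(buf)

# SnapShotDataset.__getitem__ returns one (x, y) snapshot pair per index.
x, y = archive['x'], archive['y']
print(len(x))      # 4 samples
print(x[1].shape)  # one snapshot: (3, 2)
```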
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
from dgl.base import DGLError
import dgl.function as fn
class DiffConv(nn.Module):
    '''DiffConv is the implementation of diffusion convolution from the DCRNN paper.
    It computes multiple diffusion matrices and performs a diffusion convolution
    on each of them; this layer can be used for traffic prediction and epidemic modeling.
    Parameter
    ==========
    in_feats : int
        number of input features
    out_feats : int
        number of output features
    k : int
        number of diffusion steps
    dir : str [both/in/out]
        direction of diffusion convolution
        (both directions by default, as in the paper)
    '''
def __init__(self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir='both'):
super(DiffConv, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.k = k
self.dir = dir
self.num_graphs = self.k-1 if self.dir == 'both' else 2*self.k-2
self.project_fcs = nn.ModuleList()
for i in range(self.num_graphs):
self.project_fcs.append(
nn.Linear(self.in_feats, self.out_feats, bias=False))
self.merger = nn.Parameter(torch.randn(self.num_graphs+1))
self.in_graph_list = in_graph_list
self.out_graph_list = out_graph_list
@staticmethod
def attach_graph(g, k):
device = g.device
out_graph_list = []
in_graph_list = []
wadj, ind, outd = DiffConv.get_weight_matrix(g)
adj = sparse.coo_matrix(wadj/outd.cpu().numpy())
outg = dgl.from_scipy(adj, eweight_name='weight').to(device)
outg.edata['weight'] = outg.edata['weight'].float().to(device)
out_graph_list.append(outg)
for i in range(k-1):
out_graph_list.append(DiffConv.diffuse(
out_graph_list[-1], wadj, outd))
adj = sparse.coo_matrix(wadj.T/ind.cpu().numpy())
ing = dgl.from_scipy(adj, eweight_name='weight').to(device)
ing.edata['weight'] = ing.edata['weight'].float().to(device)
in_graph_list.append(ing)
for i in range(k-1):
in_graph_list.append(DiffConv.diffuse(
in_graph_list[-1], wadj.T, ind))
return out_graph_list, in_graph_list
@staticmethod
def get_weight_matrix(g):
adj = g.adj(scipy_fmt='coo')
ind = g.in_degrees()
outd = g.out_degrees()
weight = g.edata['weight']
adj.data = weight.cpu().numpy()
return adj, ind, outd
@staticmethod
def diffuse(progress_g, weighted_adj, degree):
device = progress_g.device
progress_adj = progress_g.adj(scipy_fmt='coo')
progress_adj.data = progress_g.edata['weight'].cpu().numpy()
ret_adj = sparse.coo_matrix(progress_adj@(
weighted_adj/degree.cpu().numpy()))
ret_graph = dgl.from_scipy(ret_adj, eweight_name='weight').to(device)
ret_graph.edata['weight'] = ret_graph.edata['weight'].float().to(
device)
return ret_graph
def forward(self, g, x):
feat_list = []
if self.dir == 'both':
graph_list = self.in_graph_list+self.out_graph_list
elif self.dir == 'in':
graph_list = self.in_graph_list
elif self.dir == 'out':
graph_list = self.out_graph_list
for i in range(self.num_graphs):
g = graph_list[i]
with g.local_scope():
g.ndata['n'] = self.project_fcs[i](x)
g.update_all(fn.u_mul_e('n', 'weight', 'e'),
fn.sum('e', 'feat'))
feat_list.append(g.ndata['feat'])
# Each feat has shape [N,q_feats]
feat_list.append(self.project_fcs[-1](x))
feat_list = torch.cat(feat_list).view(
len(feat_list), -1, self.out_feats)
ret = (self.merger*feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
return ret
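`attach_graph` above precomputes the k-step diffusion graphs by repeatedly multiplying with the degree-normalized transition matrix. A minimal numpy sketch of that recurrence (weighted out-degree normalization assumed, dense matrices in place of DGL graphs):

```python
import numpy as np

# Out-direction transition matrix P = W / out_degree, then successive
# diffusion steps P, P @ P, ... as in DiffConv.attach_graph.
W = np.array([[0., 2., 0.],
              [1., 0., 1.],
              [0., 3., 0.]])
out_deg = W.sum(axis=1, keepdims=True)
P = W / out_deg
steps = [P]
for _ in range(2):  # k = 3 diffusion matrices in total
    steps.append(steps[-1] @ P)

# P is row-stochastic, so every diffusion matrix keeps row sums of 1.
print(np.allclose(steps[2].sum(axis=1), 1.0))  # True
```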
import numpy as np
import torch
import torch.nn as nn
import dgl
import dgl.nn as dglnn
from dgl.base import DGLError
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class WeightedGATConv(dglnn.GATConv):
    '''
    This module inherits from DGL's GATConv for the traffic prediction task;
    it adds edge weights when aggregating the node features.
    '''
def forward(self, graph, feat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
feat_src = self.fc(
h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc(
h_dst).view(-1, self._num_heads, self._out_feats)
else:
feat_src = self.fc_src(
h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(
h_dst).view(-1, self._num_heads, self._out_feats)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
-1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
# We decompose the weight vector a mentioned in the paper into
# [a_l || a_r], then
# a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # Our implementation is more efficient because we do not need to
            # save [Wh_i || Wh_j] on edges, which is not memory-efficient. Plus,
# addition could be optimized with DGL's built-in function u_add_v,
# which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
# compute weighted attention
graph.edata['a'] = (graph.edata['a'].permute(
1, 2, 0)*graph.edata['weight']).permute(2, 0, 1)
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
else:
return rst
class GatedGAT(nn.Module):
    '''Gated graph attention module: a general-purpose
    graph attention module proposed in the GaAN paper, where
    it is used for the traffic prediction task.
    Parameter
    ==========
    in_feats : int
        number of input features
    out_feats : int
        number of output features
    map_feats : int
        intermediate feature size for gate computation
    num_heads : int
        number of heads for multi-head attention
    '''
def __init__(self, in_feats, out_feats, map_feats, num_heads):
super(GatedGAT, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.map_feats = map_feats
self.num_heads = num_heads
self.gatlayer = WeightedGATConv(self.in_feats,
self.out_feats,
self.num_heads)
self.gate_fn = nn.Linear(
2*self.in_feats+self.map_feats, self.num_heads)
self.gate_m = nn.Linear(self.in_feats, self.map_feats)
self.merger_layer = nn.Linear(
self.in_feats+self.out_feats, self.out_feats)
def forward(self, g, x):
with g.local_scope():
g.ndata['x'] = x
g.ndata['z'] = self.gate_m(x)
g.update_all(fn.copy_u('x', 'x'), fn.mean('x', 'mean_z'))
g.update_all(fn.copy_u('z', 'z'), fn.max('z', 'max_z'))
nft = torch.cat([g.ndata['x'], g.ndata['max_z'],
g.ndata['mean_z']], dim=1)
gate = self.gate_fn(nft).sigmoid()
attn_out = self.gatlayer(g, x)
node_num = g.num_nodes()
gated_out = ((gate.view(-1)*attn_out.view(-1, self.out_feats).T).T).view(
node_num, self.num_heads, self.out_feats)
gated_out = gated_out.mean(1)
merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
return merge
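The gating in `GatedGAT.forward` assigns one scalar gate per node per head, scales that head's attention output, and then averages the heads. A numpy sketch of just that gating step (shapes are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

num_nodes, num_heads, out_feats = 3, 2, 4
rng = np.random.default_rng(0)
attn_out = rng.standard_normal((num_nodes, num_heads, out_feats))
gate_logits = rng.standard_normal((num_nodes, num_heads))
gate = sigmoid(gate_logits)

# One gate per node per head, broadcast over the feature dimension,
# then the gated heads are averaged.
gated = attn_out * gate[:, :, None]
merged = gated.mean(axis=1)
print(merged.shape)  # (3, 4)
```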
import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
import dgl.nn as dglnn
from dgl.base import DGLError
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class GraphGRUCell(nn.Module):
'''Graph GRU unit which can use any message passing
net to replace the linear layer in the original GRU
Parameter
==========
in_feats : int
number of input features
out_feats : int
number of output features
net : torch.nn.Module
message passing network
'''
def __init__(self, in_feats, out_feats, net):
super(GraphGRUCell, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
# net can be any GNN model
self.r_net = net(in_feats+out_feats, out_feats)
self.u_net = net(in_feats+out_feats, out_feats)
self.c_net = net(in_feats+out_feats, out_feats)
        # Manually add bias terms
self.r_bias = nn.Parameter(torch.rand(out_feats))
self.u_bias = nn.Parameter(torch.rand(out_feats))
self.c_bias = nn.Parameter(torch.rand(out_feats))
def forward(self, g, x, h):
r = torch.sigmoid(self.r_net(
g, torch.cat([x, h], dim=1)) + self.r_bias)
u = torch.sigmoid(self.u_net(
g, torch.cat([x, h], dim=1)) + self.u_bias)
h_ = r*h
c = torch.sigmoid(self.c_net(
g, torch.cat([x, h_], dim=1)) + self.c_bias)
new_h = u*h + (1-u)*c
return new_h
class StackedEncoder(nn.Module):
'''One step encoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth.
Parameter
==========
in_feats : int
        number of input features
out_feats : int
number of output features
num_layers : int
vertical depth of one step encoding unit
net : torch.nn.Module
message passing network for graph computation
'''
def __init__(self, in_feats, out_feats, num_layers, net):
super(StackedEncoder, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.num_layers = num_layers
self.net = net
self.layers = nn.ModuleList()
if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0! ")
self.layers.append(GraphGRUCell(
self.in_feats, self.out_feats, self.net))
for _ in range(self.num_layers-1):
self.layers.append(GraphGRUCell(
self.out_feats, self.out_feats, self.net))
    # hidden_states should be a list with one hidden state per layer
def forward(self, g, x, hidden_states):
hiddens = []
for i, layer in enumerate(self.layers):
x = layer(g, x, hidden_states[i])
hiddens.append(x)
return x, hiddens
class StackedDecoder(nn.Module):
'''One step decoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth.
Parameter
==========
in_feats : int
        number of input features
hid_feats : int
number of feature before the linear output layer
out_feats : int
number of output features
num_layers : int
vertical depth of one step encoding unit
net : torch.nn.Module
message passing network for graph computation
'''
def __init__(self, in_feats, hid_feats, out_feats, num_layers, net):
super(StackedDecoder, self).__init__()
self.in_feats = in_feats
self.hid_feats = hid_feats
self.out_feats = out_feats
self.num_layers = num_layers
self.net = net
self.out_layer = nn.Linear(self.hid_feats, self.out_feats)
self.layers = nn.ModuleList()
if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0!")
self.layers.append(GraphGRUCell(self.in_feats, self.hid_feats, net))
for _ in range(self.num_layers-1):
self.layers.append(GraphGRUCell(
self.hid_feats, self.hid_feats, net))
def forward(self, g, x, hidden_states):
hiddens = []
for i, layer in enumerate(self.layers):
x = layer(g, x, hidden_states[i])
hiddens.append(x)
x = self.out_layer(x)
return x, hiddens
class GraphRNN(nn.Module):
'''Graph Sequence to sequence prediction framework
Support multiple backbone GNN. Mainly used for traffic prediction.
Parameter
==========
in_feats : int
number of input features
out_feats : int
number of prediction output features
seq_len : int
input and predicted sequence length
num_layers : int
vertical number of layers in encoder and decoder unit
net : torch.nn.Module
Message passing GNN as backbone
decay_steps : int
number of steps for the teacher forcing probability to decay
'''
def __init__(self,
in_feats,
out_feats,
seq_len,
num_layers,
net,
decay_steps):
super(GraphRNN, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.seq_len = seq_len
self.num_layers = num_layers
self.net = net
self.decay_steps = decay_steps
self.encoder = StackedEncoder(self.in_feats,
self.out_feats,
self.num_layers,
self.net)
self.decoder = StackedDecoder(self.in_feats,
self.out_feats,
self.in_feats,
self.num_layers,
self.net)
# Threshold For Teacher Forcing
def compute_thresh(self, batch_cnt):
return self.decay_steps/(self.decay_steps + np.exp(batch_cnt / self.decay_steps))
def encode(self, g, inputs, device):
hidden_states = [torch.zeros(g.num_nodes(), self.out_feats).to(
device) for _ in range(self.num_layers)]
for i in range(self.seq_len):
_, hidden_states = self.encoder(g, inputs[i], hidden_states)
return hidden_states
def decode(self, g, teacher_states, hidden_states, batch_cnt, device):
outputs = []
inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device)
for i in range(self.seq_len):
if np.random.random() < self.compute_thresh(batch_cnt) and self.training:
inputs, hidden_states = self.decoder(
g, teacher_states[i], hidden_states)
else:
inputs, hidden_states = self.decoder(g, inputs, hidden_states)
outputs.append(inputs)
outputs = torch.stack(outputs)
return outputs
def forward(self, g, inputs, teacher_states, batch_cnt, device):
hidden = self.encode(g, inputs, device)
outputs = self.decode(g, teacher_states, hidden, batch_cnt, device)
return outputs
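`compute_thresh` above implements the teacher-forcing schedule: the probability of feeding the ground truth to the decoder starts near 1 and decays as the batch count grows. A standalone sketch of the same formula:

```python
import math

# Teacher-forcing threshold as in GraphRNN.compute_thresh:
# decay_steps / (decay_steps + exp(batch_cnt / decay_steps)).
def compute_thresh(batch_cnt, decay_steps=2000):
    return decay_steps / (decay_steps + math.exp(batch_cnt / decay_steps))

print(round(compute_thresh(0), 4))                 # 0.9995 (near 1 at the start)
print(compute_thresh(0) > compute_thresh(20000))   # True (monotonically decaying)
```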
from functools import partial
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import dgl
from model import GraphRNN
from dcrnn import DiffConv
from gaan import GatedGAT
from dataloading import METR_LAGraphDataset, METR_LATrainDataset,\
METR_LATestDataset, METR_LAValidDataset,\
PEMS_BAYGraphDataset, PEMS_BAYTrainDataset,\
PEMS_BAYValidDataset, PEMS_BAYTestDataset
from utils import NormalizationLayer, masked_mae_loss, get_learning_rate
batch_cnt = [0]
def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, device, args):
total_loss = []
graph = graph.to(device)
model.train()
batch_size = args.batch_size
for i, (x, y) in enumerate(dataloader):
optimizer.zero_grad()
        # Padding: since the diffusion graph is precomputed, we need to pad
        # the batch so that every batch has the same size
if x.shape[0] != batch_size:
x_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[:x.shape[0], :, :, :] = x
x_buff[x.shape[0]:, :, :,
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
y_buff[:x.shape[0], :, :, :] = y
y_buff[x.shape[0]:, :, :,
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
x = x_buff
y = y_buff
# Permute the dimension for shaping
x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y_norm = normalizer.normalize(y).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device)
batch_graph = dgl.batch([graph]*batch_size)
output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device)
# Denormalization for loss compute
y_pred = normalizer.denormalize(output)
loss = loss_fn(y_pred, y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
if get_learning_rate(optimizer) > args.minimum_lr:
scheduler.step()
total_loss.append(float(loss))
batch_cnt[0] += 1
print("Batch: ", i)
return np.mean(total_loss)
def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
total_loss = []
graph = graph.to(device)
model.eval()
batch_size = args.batch_size
for i, (x, y) in enumerate(dataloader):
        # Padding: since the diffusion graph is precomputed, we need to pad
        # the batch so that every batch has the same size
if x.shape[0] != batch_size:
x_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[:x.shape[0], :, :, :] = x
x_buff[x.shape[0]:, :, :,
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
y_buff[:x.shape[0], :, :, :] = y
y_buff[x.shape[0]:, :, :,
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
x = x_buff
y = y_buff
# Permute the order of dimension
x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y_norm = normalizer.normalize(y).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y = y.reshape(x.shape[0], -1, x.shape[3]).to(device)
batch_graph = dgl.batch([graph]*batch_size)
output = model(batch_graph, x_norm, y_norm, i, device)
y_pred = normalizer.denormalize(output)
loss = loss_fn(y_pred, y)
total_loss.append(float(loss))
return np.mean(total_loss)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Define the arguments
parser.add_argument('--batch_size', type=int, default=64,
help="Size of batch for minibatch Training")
parser.add_argument('--num_workers', type=int, default=0,
help="Number of workers for parallel dataloading")
    parser.add_argument('--model', type=str, default='dcrnn',
                        help="Which model to use: dcrnn or gaan")
    parser.add_argument('--gpu', type=int, default=-1,
                        help="GPU index; -1 for CPU training")
    parser.add_argument('--diffsteps', type=int, default=2,
                        help="Number of steps for constructing the diffusion matrix")
    parser.add_argument('--num_heads', type=int, default=2,
                        help="Number of attention heads")
    parser.add_argument('--decay_steps', type=int, default=2000,
                        help="Number of steps for the teacher forcing probability to decay")
parser.add_argument('--lr', type=float, default=0.01,
help="Initial learning rate")
parser.add_argument('--minimum_lr', type=float, default=2e-6,
help="Lower bound of learning rate")
parser.add_argument('--dataset', type=str, default='LA',
help="dataset LA for METR_LA; BAY for PEMS_BAY")
    parser.add_argument('--epochs', type=int, default=100,
                        help="Number of epochs for training")
parser.add_argument('--max_grad_norm', type=float, default=5.0,
help="Maximum gradient norm for update parameters")
args = parser.parse_args()
# Load the datasets
if args.dataset == 'LA':
g = METR_LAGraphDataset()
train_data = METR_LATrainDataset()
test_data = METR_LATestDataset()
valid_data = METR_LAValidDataset()
elif args.dataset == 'BAY':
g = PEMS_BAYGraphDataset()
train_data = PEMS_BAYTrainDataset()
test_data = PEMS_BAYTestDataset()
valid_data = PEMS_BAYValidDataset()
if args.gpu == -1:
device = torch.device('cpu')
else:
device = torch.device('cuda:{}'.format(args.gpu))
train_loader = DataLoader(
train_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
valid_loader = DataLoader(
valid_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
test_loader = DataLoader(
test_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
normalizer = NormalizationLayer(train_data.mean, train_data.std)
if args.model == 'dcrnn':
batch_g = dgl.batch([g]*args.batch_size).to(device)
out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps)
net = partial(DiffConv, k=args.diffsteps,
in_graph_list=in_gs, out_graph_list=out_gs)
elif args.model == 'gaan':
net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads)
dcrnn = GraphRNN(in_feats=2,
out_feats=64,
seq_len=12,
num_layers=2,
net=net,
decay_steps=args.decay_steps).to(device)
optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
loss_fn = masked_mae_loss
for e in range(args.epochs):
train_loss = train(dcrnn, g, train_loader, optimizer, scheduler,
normalizer, loss_fn, device, args)
valid_loss = eval(dcrnn, g, valid_loader,
normalizer, loss_fn, device, args)
test_loss = eval(dcrnn, g, test_loader,
normalizer, loss_fn, device, args)
print("Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(e,
train_loss,
valid_loss,
test_loss))
import dgl
import scipy.sparse as sparse
import numpy as np
import torch.nn as nn
import torch
class NormalizationLayer(nn.Module):
    def __init__(self, mean, std):
        super(NormalizationLayer, self).__init__()
        self.mean = mean
        self.std = std
    # Here we expect mean and std to be scalars
def normalize(self, x):
return (x-self.mean)/self.std
def denormalize(self, x):
return x*self.std + self.mean
def masked_mae_loss(y_pred, y_true):
mask = (y_true != 0).float()
mask /= mask.mean()
loss = torch.abs(y_pred - y_true)
loss = loss * mask
# trick for nans: https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3
loss[loss != loss] = 0
return loss.mean()
def get_learning_rate(optimizer):
for param in optimizer.param_groups:
return param['lr']
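`masked_mae_loss` treats zero readings as missing: it masks them out of the absolute error and rescales the mask so the mean effectively runs only over observed entries. A numpy sketch of the same computation:

```python
import numpy as np

# Masked MAE as in masked_mae_loss: zeros in y_true are treated as missing
# sensor readings and excluded; the mask is rescaled by its mean so the
# final mean averages only over observed entries.
def masked_mae(y_pred, y_true):
    mask = (y_true != 0).astype(float)
    mask /= mask.mean()
    loss = np.abs(y_pred - y_true) * mask
    return np.nan_to_num(loss).mean()

y_true = np.array([0.0, 2.0, 4.0, 0.0])   # two missing readings
y_pred = np.array([1.0, 3.0, 3.0, 9.0])
print(masked_mae(y_pred, y_true))  # 1.0
```

The large error on the last (missing) entry contributes nothing, so only the two observed absolute errors of 1.0 each are averaged.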