Unverified Commit 9b34b1c2 authored by jianmohuo's avatar jianmohuo Committed by GitHub
Browse files

[Model] Spatial-temporal Graph Neural Networks for Traffic Prediction (#1445)



* stgcn_wave model

* fix readme

* rm data file

* split sensors2graph

* rm dead code

* fix README

* rename class

* rm seed & dead code

* Update README.md

* rm dead code & networkx

* add num_layer papram, make model structure adjustable

* fix

* add model structure controller string, make code easier to understand and make model strcture more flexible

* Update main.py

* Update model.py

* fix

* Update README.md
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-14-255.ap-northeast-1.compute.internal>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarDa Zheng <zhengda1936@gmail.com>
parent ed60f401
Spatio-Temporal Graph Convolutional Networks
============
- Paper link: [arXiv](https://arxiv.org/pdf/1709.04875v4.pdf)
- Author's code repo: https://github.com/VeritasYin/STGCN_IJCAI-18.
Dependencies
------------
- PyTorch 1.1.0+
- sklearn
- dgl
How to run
----------
Please get the METR-LA dataset from [this Google drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX), and the auxiliary sensor graph files from [this GitHub repo](https://github.com/chnsh/DCRNN_PyTorch).
An experiment in default settings can be run with
```bash
python main.py
```
An experiment on the METR_LA dataset in customized settings can be run with
```bash
python main.py --lr <learning-rate> --disablecuda --batch_size <batch-size> --epochs <number-of-epochs>
```
If one wishes to adjust the model structure, you can change the arguments `control_str` and `channels`
```bash
python main.py --control_str <control-string> --channels <n-input-channel> <n-hidden-channels-1> <n-hidden-channels-2> ... <n-output-channels>
```
`<control-string>` is a string of the following characters representing a sequence of neural network modules:
* `T`: representing a dilated temporal convolution layer, working on the temporal dimension. The dilation factor is always twice as much as the previous temporal convolution layer.
* `S`: representing a graph convolution layer, working on the spatial dimension. The input channels and output channels are the same.
* `N`: a Layer Normalization.
The argument list following `--channels` represents the output channels on each temporal convolution layer. The list should have `N + 1` elements, where `N` is the number of `T`'s in `<control-string>`.
The activation function between two layers is always ReLU.
For example, the following command
```bash
python main.py --control_str TNTSTNTST --channels 1 16 32 32 64 128
```
specifies the following architecture:
```
+------------------------------------------------------------+
| Input |
+------------------------------------------------------------+
| 1D Conv, in_channel = 1, out_channel = 16, dilation = 1 |
+------------------------------------------------------------+
| Layer Normalization |
+------------------------------------------------------------+
| 1D Conv, in_channel = 16, out_channel = 32, dilation = 2 |
+------------------------------------------------------------+
| Graph Conv, in_channel = 32, out_channel = 32 |
+------------------------------------------------------------+
| 1D Conv, in_channel = 32, out_channel = 32, dilation = 4 |
+------------------------------------------------------------+
| Layer Normalization |
+------------------------------------------------------------+
| 1D Conv, in_channel = 32, out_channel = 64, dilation = 8 |
+------------------------------------------------------------+
| Graph Conv, in_channel = 64, out_channel = 64 |
+------------------------------------------------------------+
| 1D Conv, in_channel = 64, out_channel = 128, dilation = 16 |
+------------------------------------------------------------+
```
Results
-------
```bash
python main.py
```
METR_LA MAE: ~5.76
import torch
import numpy as np
import pandas as pd
def load_data(file_path, len_train, len_val):
    """Read a headerless CSV time series and split it chronologically.

    Parameters
    ----------
    file_path : str
        Path to a CSV file with no header row; rows are time steps.
    len_train : int
        Number of leading rows used for training.
    len_val : int
        Number of rows after the training split used for validation.

    Returns
    -------
    (train, val, test) : tuple of np.ndarray
        Consecutive, non-overlapping float arrays covering the whole file.
    """
    values = pd.read_csv(file_path, header=None).values.astype(float)
    train_part = values[:len_train]
    val_part = values[len_train:len_train + len_val]
    test_part = values[len_train + len_val:]
    return train_part, val_part, test_part
def data_transform(data, n_his, n_pred, device):
    """Slice a (time, node) series into (history window, target) pairs.

    Parameters
    ----------
    data : np.ndarray of shape (len, n_route)
        Standardized time series, one column per sensor.
    n_his : int
        Length of the history window fed to the model.
    n_pred : int
        Forecast horizon: the target is the observation `n_pred` steps
        after the end of the history window.
    device : torch.device
        Device the returned tensors are moved to.

    Returns
    -------
    (x, y) : tuple of torch.Tensor
        x has shape (num, 1, n_his, n_route); y has shape (num, n_route),
        where num = len(data) - n_his - n_pred.
    """
    n_route = data.shape[1]
    num = len(data) - n_his - n_pred
    x = np.zeros([num, 1, n_his, n_route])
    y = np.zeros([num, n_route])
    # The original kept a separate `cnt` counter that always equaled the
    # loop index; use the index directly.
    for i in range(num):
        head = i
        tail = i + n_his
        x[i, :, :, :] = data[head:tail].reshape(1, n_his, n_route)
        # target is the value n_pred steps after the history window
        y[i] = data[tail + n_pred - 1]
    return torch.Tensor(x).to(device), torch.Tensor(y).to(device)
import dgl
import random
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from load_data import *
from utils import *
from model import *
from sensors2graph import *
import torch.nn as nn
import argparse
import scipy.sparse as sp
# ---- command-line configuration ----
parser = argparse.ArgumentParser(description='STGCN_WAVE')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--disablecuda', action='store_true', help='Disable CUDA')
parser.add_argument('--batch_size', type=int, default=50, help='batch size for training and validation (default: 50)')
parser.add_argument('--epochs', type=int, default=50, help='epochs for training (default: 50)')
parser.add_argument('--num_layers', type=int, default=9, help='number of layers')
parser.add_argument('--window', type=int, default=144, help='window length')
parser.add_argument('--sensorsfilepath', type=str, default='./data/sensor_graph/graph_sensor_ids.txt', help='sensors file path')
parser.add_argument('--disfilepath', type=str, default='./data/sensor_graph/distances_la_2012.csv', help='distance file path')
parser.add_argument('--tsfilepath', type=str, default='./data/metr-la.h5', help='ts file path')
parser.add_argument('--savemodelpath', type=str, default='./save/stgcnwavemodel.pt', help='save model path')
parser.add_argument('--pred_len', type=int, default=5, help='how many steps away we want to predict')
parser.add_argument('--control_str', type=str, default='TNTSTNTST', help='model structure controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer')
parser.add_argument('--channels', type=int, nargs='+', default=[1, 16, 32, 64, 32, 128], help='output channels of each temporal conv layer (one more entry than the number of T layers)')
args = parser.parse_args()
device = torch.device("cuda") if torch.cuda.is_available() and not args.disablecuda else torch.device("cpu")

# ---- build the sensor graph from the distance file ----
with open(args.sensorsfilepath) as f:
    sensor_ids = f.read().strip().split(',')
distance_df = pd.read_csv(args.disfilepath, dtype={'from': 'str', 'to': 'str'})
adj_mx = get_adjacency_matrix(distance_df, sensor_ids)
sp_mx = sp.coo_matrix(adj_mx)
G = dgl.DGLGraph()
G.from_scipy_sparse_matrix(sp_mx)

# ---- load the traffic time series (rows: time steps, columns: sensors) ----
df = pd.read_hdf(args.tsfilepath)
num_samples, num_nodes = df.shape
tsdata = df.to_numpy()

n_his = args.window
save_path = args.savemodelpath
n_pred = args.pred_len
n_route = num_nodes
blocks = args.channels
drop_prob = 0
num_layers = args.num_layers
batch_size = args.batch_size
epochs = args.epochs
lr = args.lr
W = adj_mx

# ---- chronological 70/10/20 train/val/test split ----
len_val = round(num_samples * 0.1)
len_train = round(num_samples * 0.7)
train = df[: len_train]
val = df[len_train: len_train + len_val]
test = df[len_train + len_val:]

# standardize with statistics fitted on the training split only, so no
# information leaks from validation/test into training
scaler = StandardScaler()
train = scaler.fit_transform(train)
val = scaler.transform(val)
test = scaler.transform(test)

x_train, y_train = data_transform(train, n_his, n_pred, device)
x_val, y_val = data_transform(val, n_his, n_pred, device)
x_test, y_test = data_transform(test, n_his, n_pred, device)

train_data = torch.utils.data.TensorDataset(x_train, y_train)
train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)
val_data = torch.utils.data.TensorDataset(x_val, y_val)
val_iter = torch.utils.data.DataLoader(val_data, batch_size)
test_data = torch.utils.data.TensorDataset(x_test, y_test)
test_iter = torch.utils.data.DataLoader(test_data, batch_size)

loss = nn.MSELoss()
model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, args.control_str).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)

# ---- training loop: checkpoint the weights with the best validation loss ----
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    l_sum, n = 0.0, 0
    model.train()
    for x, y in train_iter:
        y_pred = model(x).view(len(x), -1)
        l = loss(y_pred, y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        l_sum += l.item() * y.shape[0]
        n += y.shape[0]
    scheduler.step()
    val_loss = evaluate_model(model, loss, val_iter)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
    print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)

# ---- final evaluation on the test split with the best checkpoint ----
# BUG FIX: the original call omitted the required `num_layers` positional
# argument (and `control_str`), raising a TypeError here; the rebuilt
# architecture must also match the saved state dict exactly.
best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, args.control_str).to(device)
best_model.load_state_dict(torch.load(save_path))
l = evaluate_model(best_model, loss, test_iter)
MAE, MAPE, RMSE = evaluate_metric(best_model, test_iter, scaler)
print("test loss:", l, "\nMAE:", MAE, ", MAPE:", MAPE, ", RMSE:", RMSE)
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.conv import ChebConv
class TemporalConvLayer(nn.Module):
    '''Dilated convolution along the temporal axis.

    arguments
    ---------
    c_in : int
        The number of input channels (features)
    c_out : int
        The number of output channels (features)
    dia : int
        The dilation size
    '''
    def __init__(self, c_in, c_out, dia=1):
        super(TemporalConvLayer, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        # kernel (2, 1): convolve over time only, leave the node axis alone
        self.conv = nn.Conv2d(c_in, c_out, (2, 1), 1, dilation=dia, padding=(0, 0))

    def forward(self, x):
        # x: (batch, c_in, time, nodes) -> (batch, c_out, shorter time, nodes)
        out = self.conv(x)
        return torch.relu(out)
class SpatioConvLayer(nn.Module):
    """Graph convolution over the spatial (sensor) dimension.

    Parameters
    ----------
    c : int
        Hidden dimension; input and output channel counts are equal.
    Lk : graph
        The sensor graph the convolution operates on (a DGL graph).
    """
    def __init__(self, c, Lk):  # c : hidden dimension Lk: graph matrix
        super(SpatioConvLayer, self).__init__()
        self.g = Lk
        self.gc = GraphConv(c, c, activation=F.relu)

    def init(self):
        # BUG FIX: the original referenced a nonexistent `self.W` and would
        # raise AttributeError if ever called; initialize the actual
        # GraphConv weight uniformly in [-stdv, stdv] instead.
        stdv = 1. / math.sqrt(self.gc.weight.size(1))
        self.gc.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: (batch, channel, time, nodes). GraphConv expects the node axis
        # first, so swap nodes to the front, convolve, then restore layout.
        x = x.transpose(0, 3)
        x = x.transpose(1, 3)
        output = self.gc(self.g, x)
        output = output.transpose(1, 3)
        output = output.transpose(0, 3)
        return torch.relu(output)
class FullyConvLayer(nn.Module):
    """Final 1x1 convolution projecting c channels down to a single channel."""

    def __init__(self, c):
        super(FullyConvLayer, self).__init__()
        # a 1x1 conv acts as a per-position fully connected layer: c -> 1
        self.conv = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        out = self.conv(x)
        return out
class OutputLayer(nn.Module):
    """Map the final feature block to one prediction per node.

    Applies a (T, 1) temporal conv that collapses the remaining T time
    steps to one, a layer norm over (nodes, channels), a 1x1 conv, and a
    final 1x1 projection to a single output channel.
    """
    def __init__(self, c, T, n):
        super(OutputLayer, self).__init__()
        self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, padding=(0, 0))
        self.fc = FullyConvLayer(c)

    def forward(self, x):
        # collapse the time axis with the (T, 1) kernel
        x_t1 = self.tconv1(x)
        # LayerNorm normalizes the trailing (nodes, channels) dims, so
        # permute to NHWC, normalize, and permute back to NCHW
        x_ln = self.ln(x_t1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)
class STGCN_WAVE(nn.Module):
    """Spatio-temporal graph conv network with a string-configurable stack.

    Parameters
    ----------
    c : list of int
        Output channels of the temporal conv layers; must contain one more
        entry than the number of 'T's in `control_str` (the first entry is
        the input channel count).
    T : int
        Input window length (number of history time steps).
    n : int
        Number of graph nodes (sensors).
    Lk : graph
        Sensor graph handed to the spatial conv layers.
    p : float
        Dropout probability — currently unused; kept for interface compatibility.
    num_layers : int
        Unused; the layer count is derived from `control_str`. Kept for
        interface compatibility.
    control_str : str
        Layer sequence: 'T' temporal conv (dilation doubles with each 'T'),
        'S' spatial graph conv, 'N' layer normalization.
    """
    def __init__(self, c, T, n, Lk, p, num_layers, control_str='TNTSTNTST'):
        super(STGCN_WAVE, self).__init__()
        self.control_str = control_str  # model structure controller
        self.num_layers = len(control_str)
        # BUG FIX: the original stored the layers in a plain Python list, so
        # they were invisible to model.parameters() (the optimizer never
        # trained them) and to .to(device); it also force-moved every layer
        # to CUDA, crashing on CPU-only machines and ignoring --disablecuda.
        # nn.ModuleList registers the sub-layers properly, and device
        # placement is left to the caller's .to(device).
        self.layers = nn.ModuleList()
        cnt = 0
        diapower = 0
        for i_layer in control_str:
            if i_layer == 'T':  # Temporal Layer: dilation doubles per 'T'
                self.layers.append(TemporalConvLayer(c[cnt], c[cnt + 1], dia=2 ** diapower))
                diapower += 1
                cnt += 1
            elif i_layer == 'S':  # Spatio Layer
                self.layers.append(SpatioConvLayer(c[cnt], Lk))
            elif i_layer == 'N':  # Norm Layer
                self.layers.append(nn.LayerNorm([n, c[cnt]]))
        # each 'T' (kernel 2, dilation 2**k) shrinks time by 2**k, so the
        # remaining window length is T - (2**diapower - 1)
        self.output = OutputLayer(c[cnt], T + 1 - 2 ** diapower, n)

    def forward(self, x):
        for i in range(self.num_layers):
            i_layer = self.control_str[i]
            if i_layer == 'N':
                # LayerNorm normalizes over (nodes, channels): permute to
                # NHWC, normalize, permute back to NCHW
                x = self.layers[i](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            else:
                x = self.layers[i](x)
        return self.output(x)
import numpy as np
def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
    """Build a thresholded Gaussian-kernel adjacency matrix from distances.

    :param distance_df: data frame with three columns: [from, to, distance].
    :param sensor_ids: list of sensor ids.
    :param normalized_k: entries that become lower than normalized_k after
        normalization are set to zero for sparsity.
    :return: adjacency matrix (num_sensors x num_sensors, float32).
    """
    num_sensors = len(sensor_ids)
    # sensor id -> matrix row/column index
    idx = {sensor_id: i for i, sensor_id in enumerate(sensor_ids)}
    # pairwise distances; unknown pairs stay at +inf and get zero weight below
    dist_mx = np.full((num_sensors, num_sensors), np.inf, dtype=np.float32)
    for row in distance_df.values:
        src, dst = row[0], row[1]
        if src in idx and dst in idx:
            dist_mx[idx[src], idx[dst]] = row[2]
    # Gaussian kernel with the std of the known distances as bandwidth
    known = dist_mx[~np.isinf(dist_mx)].flatten()
    std = known.std()
    adj_mx = np.exp(-np.square(dist_mx / std))
    # sparsify: drop weights below the normalized_k threshold
    adj_mx[adj_mx < normalized_k] = 0
    return adj_mx
\ No newline at end of file
import torch
import numpy as np
def evaluate_model(model, loss, data_iter):
    """Return the sample-weighted average loss of `model` over `data_iter`.

    Runs in eval mode with gradients disabled; each batch loss is weighted
    by its batch size so uneven final batches are averaged correctly.
    """
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, y in data_iter:
            pred = model(x).view(len(x), -1)
            batch_loss = loss(pred, y)
            total += batch_loss.item() * y.shape[0]
            count += y.shape[0]
    return total / count
def evaluate_metric(model, data_iter, scaler):
    """Compute MAE, MAPE and RMSE of `model` over `data_iter`.

    Predictions and targets are mapped back to the original data scale via
    `scaler.inverse_transform` before the per-element errors are pooled
    across all batches.
    """
    model.eval()
    with torch.no_grad():
        abs_errs, rel_errs, sq_errs = [], [], []
        for x, y in data_iter:
            y_true = scaler.inverse_transform(y.cpu().numpy()).reshape(-1)
            y_hat = scaler.inverse_transform(model(x).view(len(x), -1).cpu().numpy()).reshape(-1)
            diff = np.abs(y_true - y_hat)
            abs_errs += diff.tolist()
            rel_errs += (diff / y_true).tolist()
            sq_errs += (diff ** 2).tolist()
        MAE = np.array(abs_errs).mean()
        MAPE = np.array(rel_errs).mean()
        RMSE = np.sqrt(np.array(sq_errs).mean())
        return MAE, MAPE, RMSE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment