Unverified Commit f7039418 authored by HaoWei-TomTom's avatar HaoWei-TomTom Committed by GitHub
Browse files

[Bugfix][Pytorch] Fix model save and load bug of stgcn_wave (#3303)


Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent b81bb914
......@@ -3,12 +3,12 @@ Spatio-Temporal Graph Convolutional Networks
- Paper link: [arXiv](https://arxiv.org/pdf/1709.04875v4.pdf)
- Author's code repo: https://github.com/VeritasYin/STGCN_IJCAI-18.
Dependencies
------------
- PyTorch 1.1.0+
- sklearn
- dgl
- tables
- See [this blog](https://towardsdatascience.com/build-your-first-graph-neural-network-model-to-predict-traffic-speed-in-20-minutes-b593f8f838e5) for more details about running the code.
- Dependencies
- PyTorch 1.1.0+
- sklearn
- dgl
- tables
How to run
......
......@@ -116,7 +116,7 @@ for epoch in range(1, epochs + 1):
print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)
best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers).to(device)
best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str).to(device)
best_model.load_state_dict(torch.load(save_path))
......
......@@ -75,7 +75,7 @@ class STGCN_WAVE(nn.Module):
super(STGCN_WAVE, self).__init__()
self.control_str = control_str # model structure controller
self.num_layers = len(control_str)
self.layers = []
self.layers = nn.ModuleList([])
cnt = 0
diapower = 0
for i in range(self.num_layers):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment