Unverified Commit cbe4c28f authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Fix (#1596)

parent 36c7b771
...@@ -116,7 +116,7 @@ for epoch in range(1, epochs + 1): ...@@ -116,7 +116,7 @@ for epoch in range(1, epochs + 1):
print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss) print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)
best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob).to(device) best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers).to(device)
best_model.load_state_dict(torch.load(save_path)) best_model.load_state_dict(torch.load(save_path))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment