Unverified Commit 7f4086da authored by Tomohiro Endo's avatar Tomohiro Endo Committed by GitHub
Browse files

cuda() to to(device) (#3064)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 2f43cdb3
......@@ -91,7 +91,7 @@ test_iter = torch.utils.data.DataLoader(test_data, batch_size)
loss = nn.MSELoss()
G = G.to(device)
model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, args.control_str).to(device)
model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
......
......@@ -71,7 +71,7 @@ class OutputLayer(nn.Module):
return self.fc(x_t2)
class STGCN_WAVE(nn.Module):
def __init__(self, c, T, n, Lk, p, num_layers,control_str = 'TNTSTNTST'):
def __init__(self, c, T, n, Lk, p, num_layers, device, control_str = 'TNTSTNTST'):
super(STGCN_WAVE, self).__init__()
self.control_str = control_str # model structure controller
self.num_layers = len(control_str)
......@@ -90,7 +90,7 @@ class STGCN_WAVE(nn.Module):
self.layers.append(nn.LayerNorm([n,c[cnt]]))
self.output = OutputLayer(c[cnt], T + 1 - 2**(diapower), n)
for layer in self.layers:
layer = layer.cuda()
layer = layer.to(device)
def forward(self, x):
for i in range(self.num_layers):
i_layer = self.control_str[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment