"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "13c754c15d5952f9e160b952d4177f1b7b329a67"
Unverified Commit 7f4086da authored by Tomohiro Endo's avatar Tomohiro Endo Committed by GitHub
Browse files

cuda() to to(device) (#3064)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 2f43cdb3
...@@ -91,7 +91,7 @@ test_iter = torch.utils.data.DataLoader(test_data, batch_size) ...@@ -91,7 +91,7 @@ test_iter = torch.utils.data.DataLoader(test_data, batch_size)
loss = nn.MSELoss() loss = nn.MSELoss()
G = G.to(device) G = G.to(device)
model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, args.control_str).to(device) model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr) optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
......
...@@ -71,7 +71,7 @@ class OutputLayer(nn.Module): ...@@ -71,7 +71,7 @@ class OutputLayer(nn.Module):
return self.fc(x_t2) return self.fc(x_t2)
class STGCN_WAVE(nn.Module): class STGCN_WAVE(nn.Module):
def __init__(self, c, T, n, Lk, p, num_layers,control_str = 'TNTSTNTST'): def __init__(self, c, T, n, Lk, p, num_layers, device, control_str = 'TNTSTNTST'):
super(STGCN_WAVE, self).__init__() super(STGCN_WAVE, self).__init__()
self.control_str = control_str # model structure controller self.control_str = control_str # model structure controller
self.num_layers = len(control_str) self.num_layers = len(control_str)
...@@ -90,7 +90,7 @@ class STGCN_WAVE(nn.Module): ...@@ -90,7 +90,7 @@ class STGCN_WAVE(nn.Module):
self.layers.append(nn.LayerNorm([n,c[cnt]])) self.layers.append(nn.LayerNorm([n,c[cnt]]))
self.output = OutputLayer(c[cnt], T + 1 - 2**(diapower), n) self.output = OutputLayer(c[cnt], T + 1 - 2**(diapower), n)
for layer in self.layers: for layer in self.layers:
layer = layer.cuda() layer = layer.to(device)
def forward(self, x): def forward(self, x):
for i in range(self.num_layers): for i in range(self.num_layers):
i_layer = self.control_str[i] i_layer = self.control_str[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment