"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6a7a5467cab6df8bb24b20a7ad3f2223c1a2e8de"
Unverified Commit 2584f3af authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the HGT example. (#6069)

parent ee651811
......@@ -49,7 +49,6 @@ def get_n_params(model):
def train(model, G):
best_val_acc = torch.tensor(0)
best_test_acc = torch.tensor(0)
train_step = torch.tensor(0)
for epoch in np.arange(args.n_epoch) + 1:
model.train()
logits = model(G, "paper")
......@@ -59,8 +58,7 @@ def train(model, G):
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
train_step += 1
scheduler.step(train_step)
scheduler.step()
if epoch % 5 == 0:
model.eval()
logits = model(G, "paper")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment