"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "877fd0d0f647b829c0e009d718f86a1774f78a53"
Unverified Commit 2584f3af authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving the HGT example. (#6069)

parent ee651811
...@@ -49,7 +49,6 @@ def get_n_params(model): ...@@ -49,7 +49,6 @@ def get_n_params(model):
def train(model, G): def train(model, G):
best_val_acc = torch.tensor(0) best_val_acc = torch.tensor(0)
best_test_acc = torch.tensor(0) best_test_acc = torch.tensor(0)
train_step = torch.tensor(0)
for epoch in np.arange(args.n_epoch) + 1: for epoch in np.arange(args.n_epoch) + 1:
model.train() model.train()
logits = model(G, "paper") logits = model(G, "paper")
...@@ -59,8 +58,7 @@ def train(model, G): ...@@ -59,8 +58,7 @@ def train(model, G):
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step() optimizer.step()
train_step += 1 scheduler.step()
scheduler.step(train_step)
if epoch % 5 == 0: if epoch % 5 == 0:
model.eval() model.eval()
logits = model(G, "paper") logits = model(G, "paper")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment