"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "97f543f3bc917ecaefb5b194f5cd7c342fbf78c9"
Unverified Commit 364806f2 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify `node_classification` for benchmark (#6501)

parent e645d936
...@@ -38,6 +38,7 @@ main ...@@ -38,6 +38,7 @@ main
└───> All nodes set inference & Test set evaluation └───> All nodes set inference & Test set evaluation
""" """
import argparse import argparse
import time
import dgl.graphbolt as gb import dgl.graphbolt as gb
import dgl.nn as dglnn import dgl.nn as dglnn
...@@ -282,6 +283,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model): ...@@ -282,6 +283,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
) )
for epoch in range(args.epochs): for epoch in range(args.epochs):
t0 = time.time()
model.train() model.train()
total_loss = 0 total_loss = 0
for step, data in tqdm(enumerate(dataloader)): for step, data in tqdm(enumerate(dataloader)):
...@@ -304,11 +306,12 @@ def train(args, graph, features, train_set, valid_set, num_classes, model): ...@@ -304,11 +306,12 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
total_loss += loss.item() total_loss += loss.item()
t1 = time.time()
# Evaluate the model. # Evaluate the model.
acc = evaluate(args, model, graph, features, valid_set, num_classes) acc = evaluate(args, model, graph, features, valid_set, num_classes)
print( print(
f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | " f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
f"Accuracy {acc.item():.4f} " f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment