"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c250939ecb206d23d1c52291eba31d7659456081"
Unverified Commit 364806f2 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify `node_classification` for benchmark (#6501)

parent e645d936
......@@ -38,6 +38,7 @@ main
└───> All nodes set inference & Test set evaluation
"""
import argparse
import time
import dgl.graphbolt as gb
import dgl.nn as dglnn
......@@ -282,6 +283,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
)
for epoch in range(args.epochs):
t0 = time.time()
model.train()
total_loss = 0
for step, data in tqdm(enumerate(dataloader)):
......@@ -304,11 +306,12 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
total_loss += loss.item()
t1 = time.time()
# Evaluate the model.
acc = evaluate(args, model, graph, features, valid_set, num_classes)
print(
f"Epoch {epoch:05d} | Loss {total_loss / (step + 1):.4f} | "
f"Accuracy {acc.item():.4f} "
f"Accuracy {acc.item():.4f} | Time {t1 - t0:.4f}"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment