Unverified Commit f5c8c0b5 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Update multi-gpu example for regression benchmark. (#6519)

parent 84a01a16
......@@ -38,6 +38,7 @@ main
"""
import argparse
import os
import time
import dgl.graphbolt as gb
import dgl.nn as dglnn
......@@ -138,7 +139,9 @@ def create_dataloader(
# A CopyTo object copying data in the datapipe to a specified device.\
############################################################################
datapipe = datapipe.copy_to(device)
dataloader = gb.SingleProcessDataLoader(datapipe)
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=args.num_workers
)
# Return the fully-initialized DataLoader object.
return dataloader
......@@ -204,6 +207,8 @@ def train(
)
for epoch in range(args.epochs):
epoch_start = time.time()
model.train()
total_loss = torch.tensor(0, dtype=torch.float).to(device)
########################################################################
......@@ -273,11 +278,14 @@ def train(
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
epoch_end = time.time()
if rank == 0:
print(
f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item():.4f} "
f"Accuracy {acc.item():.4f} | "
f"Time {epoch_end - epoch_start:.4f}"
)
......@@ -342,7 +350,7 @@ def run(rank, world_size, args, devices, dataset):
)
dist.reduce(tensor=test_acc, dst=0)
if rank == 0:
print(f"Test Accuracy is {test_acc.item():.4f}")
print(f"Test Accuracy {test_acc.item():.4f}")
def parse_args():
......@@ -376,6 +384,9 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
return parser.parse_args()
......@@ -393,6 +404,9 @@ if __name__ == "__main__":
# Load and preprocess dataset.
dataset = gb.BuiltinDataset("ogbn-products").load()
# Thread limiting to avoid resource competition.
os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)
mp.set_sharing_strategy("file_system")
mp.spawn(
run,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment