Unverified Commit e3752754 authored by Ramon Zhou, committed by GitHub

[Misc] Align dgl multi-gpu example to graphbolt (#6605)

parent 096bbf96
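
In short, the example gains a third "benchmark" mode (CPU-resident graph and index tensors, non-fused sampling) and a --num_workers flag, matching the knobs of the graphbolt version of this example. A hypothetical invocation of the updated script (the script name and flag values are illustrative, not part of this commit):

    python3 multi_gpu_node_classification.py --mode benchmark --gpu 0,1 --num_workers 4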
@@ -123,12 +123,13 @@ class SAGE(nn.Module):
         return y
 
 
-def evaluate(model, g, num_classes, dataloader):
+def evaluate(device, model, g, num_classes, dataloader):
     model.eval()
     ys = []
     y_hats = []
     for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
         with torch.no_grad():
+            blocks = [block.to(device) for block in blocks]
             x = blocks[0].srcdata["feat"]
             ys.append(blocks[-1].dstdata["label"])
             y_hats.append(model(blocks, x))
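
A minimal sketch of the evaluation loop after this hunk, assuming the file's existing imports (torch, DGL) and with the metric computation at the end elided; the point of the added line is that blocks sampled on the CPU must be copied to the per-process GPU before the forward pass:

    def evaluate(device, model, g, num_classes, dataloader):
        model.eval()
        ys = []
        y_hats = []
        for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            with torch.no_grad():
                # Move each message-flow graph, along with its prefetched
                # "feat"/"label" data, onto the evaluation GPU; this is a
                # no-op when the blocks are already device-resident.
                blocks = [block.to(device) for block in blocks]
                x = blocks[0].srcdata["feat"]
                ys.append(blocks[-1].dstdata["label"])
                y_hats.append(model(blocks, x))
        # ...accuracy over torch.cat(y_hats) vs. torch.cat(ys), elided here.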
@@ -145,6 +146,8 @@ def layerwise_infer(
 ):
     model.eval()
     with torch.no_grad():
+        if not use_uva:
+            g = g.to(device)
         pred = model.module.inference(g, device, batch_size, use_uva)
         pred = pred[nid]
         labels = g.ndata["label"][nid].to(pred.device)
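
The new guard exists because layer-wise inference touches the whole graph: without UVA the graph has to be device-resident first, while with UVA it can stay in pinned host memory and be read from the GPU. A sketch of just that step (names verbatim from the hunk):

    with torch.no_grad():
        if not use_uva:
            # Full-graph inference without unified virtual addressing:
            # copy the graph to the GPU up front.
            g = g.to(device)
        pred = model.module.inference(g, device, batch_size, use_uva)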
@@ -159,17 +162,20 @@ def train(
     proc_id,
     nprocs,
     device,
+    args,
     g,
     num_classes,
     train_idx,
     val_idx,
     model,
     use_uva,
-    num_epochs,
 ):
     # Instantiate a neighbor sampler
     sampler = NeighborSampler(
-        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
+        [10, 10, 10],
+        prefetch_node_feats=["feat"],
+        prefetch_labels=["label"],
+        fused=(args.mode != "benchmark"),
     )
     train_dataloader = DataLoader(
         g,
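
A sketch of the new sampler construction with the reasoning spelled out in comments; the kwargs are verbatim from the hunk, while the rationale for disabling fused sampling in benchmark mode is my reading of the commit title (keeping the measured sampling path comparable with the graphbolt example):

    sampler = NeighborSampler(
        [10, 10, 10],                      # fanout of 10 at each of the 3 layers
        prefetch_node_feats=["feat"],      # fetch node features along with the blocks
        prefetch_labels=["label"],         # fetch labels along with the blocks
        fused=(args.mode != "benchmark"),  # plain (non-fused) kernels when benchmarking
    )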
@@ -179,7 +185,7 @@ def train(
         batch_size=1024,
         shuffle=True,
         drop_last=False,
-        num_workers=0,
+        num_workers=args.num_workers,
         use_ddp=True,  # To split the set for each process
         use_uva=use_uva,
     )
@@ -191,12 +197,12 @@ def train(
         batch_size=1024,
         shuffle=True,
         drop_last=False,
-        num_workers=0,
+        num_workers=args.num_workers,
         use_ddp=True,
         use_uva=use_uva,
     )
     opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
-    for epoch in range(num_epochs):
+    for epoch in range(args.num_epochs):
         t0 = time.time()
         model.train()
         total_loss = 0
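
Both the training and validation dataloaders now take their worker count from the new flag. One caveat worth noting: DGL's UVA-based dataloading generally requires num_workers=0, so a nonzero --num_workers is mainly useful when use_uva is False, e.g. in the new benchmark mode. A sketch with the arguments visible in this hunk (the device keyword is assumed from the surrounding file, not shown in the hunk):

    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,                 # assumed; outside the hunk
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,  # previously hard-coded to 0
        use_ddp=True,                  # split the validation set across processes
        use_uva=use_uva,
    )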
@@ -224,7 +230,8 @@ def train(
         # for more information.
         #####################################################################
         acc = (
-            evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
+            evaluate(device, model, g, num_classes, val_dataloader).to(device)
+            / nprocs
         )
         t1 = time.time()
         # Reduce `acc` tensors to process 0.
@@ -236,7 +243,7 @@ def train(
         )
 
 
-def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
+def run(proc_id, nprocs, devices, g, data, args):
     # Find corresponding device for current process.
     device = devices[proc_id]
     torch.cuda.set_device(device)
@@ -263,9 +270,10 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
         rank=proc_id,
     )
     num_classes, train_idx, val_idx, test_idx = data
-    train_idx = train_idx.to(device)
-    val_idx = val_idx.to(device)
-    g = g.to(device if mode == "puregpu" else "cpu")
+    if args.mode != "benchmark":
+        train_idx = train_idx.to(device)
+        val_idx = val_idx.to(device)
+    g = g.to(device if args.mode == "puregpu" else "cpu")
     in_size = g.ndata["feat"].shape[1]
     model = SAGE(in_size, 256, num_classes).to(device)
     model = DistributedDataParallel(
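
Spelled out, the device placement per mode after this hunk is: 'benchmark' keeps the graph and the index tensors on the CPU (sampling on the host), 'mixed' keeps the graph on the CPU but moves the indices to the GPU (UVA sampling), and 'puregpu' moves everything to the GPU. A sketch, with names taken from the diff:

    if args.mode != "benchmark":
        train_idx = train_idx.to(device)  # mixed/puregpu sample from GPU indices
        val_idx = val_idx.to(device)
    g = g.to(device if args.mode == "puregpu" else "cpu")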
@@ -273,20 +281,21 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
     )
 
     # Training.
-    use_uva = mode == "mixed"
+    use_uva = args.mode == "mixed"
     if proc_id == 0:
         print("Training...")
     train(
         proc_id,
         nprocs,
         device,
+        args,
         g,
         num_classes,
         train_idx,
         val_idx,
         model,
         use_uva,
-        num_epochs,
     )
 
     # Testing.
@@ -303,7 +312,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--mode",
         default="mixed",
-        choices=["mixed", "puregpu"],
+        choices=["mixed", "puregpu", "benchmark"],
         help="Training mode. 'mixed' for CPU-GPU mixed training, "
         "'puregpu' for pure-GPU training.",
     )
@@ -317,7 +326,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--num_epochs",
         type=int,
-        default=20,
+        default=10,
         help="Number of epochs for train.",
     )
     parser.add_argument(
@@ -332,6 +341,12 @@ if __name__ == "__main__":
         default="dataset",
         help="Root directory of dataset.",
     )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=0,
+        help="Number of workers",
+    )
     args = parser.parse_args()
     devices = list(map(int, args.gpu.split(",")))
     nprocs = len(devices)
@@ -364,6 +379,6 @@ if __name__ == "__main__":
     # To use DDP with n GPUs, spawn up n processes.
     mp.spawn(
         run,
-        args=(nprocs, devices, g, data, args.mode, args.num_epochs),
+        args=(nprocs, devices, g, data, args),
         nprocs=nprocs,
     )
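
For context on the last hunk: torch.multiprocessing.spawn invokes run(proc_id, *args) in each of nprocs child processes, so passing the argparse Namespace wholesale replaces the old (mode, num_epochs) pair and lets run forward every flag to train. A sketch:

    import torch.multiprocessing as mp

    mp.spawn(
        run,                                    # called as run(proc_id, nprocs, devices, g, data, args)
        args=(nprocs, devices, g, data, args),  # proc_id is prepended by spawn
        nprocs=nprocs,
    )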