Unverified commit e3752754, authored by Ramon Zhou and committed by GitHub
Browse files

[Misc] Align dgl multi-gpu example to graphbolt (#6605)

parent 096bbf96
...@@ -123,12 +123,13 @@ class SAGE(nn.Module): ...@@ -123,12 +123,13 @@ class SAGE(nn.Module):
return y return y
def evaluate(model, g, num_classes, dataloader): def evaluate(device, model, g, num_classes, dataloader):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader): for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
with torch.no_grad(): with torch.no_grad():
blocks = [block.to(device) for block in blocks]
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"]) ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x)) y_hats.append(model(blocks, x))
...@@ -145,6 +146,8 @@ def layerwise_infer( ...@@ -145,6 +146,8 @@ def layerwise_infer(
): ):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
if not use_uva:
g = g.to(device)
pred = model.module.inference(g, device, batch_size, use_uva) pred = model.module.inference(g, device, batch_size, use_uva)
pred = pred[nid] pred = pred[nid]
labels = g.ndata["label"][nid].to(pred.device) labels = g.ndata["label"][nid].to(pred.device)
...@@ -159,17 +162,20 @@ def train( ...@@ -159,17 +162,20 @@ def train(
proc_id, proc_id,
nprocs, nprocs,
device, device,
args,
g, g,
num_classes, num_classes,
train_idx, train_idx,
val_idx, val_idx,
model, model,
use_uva, use_uva,
num_epochs,
): ):
# Instantiate a neighbor sampler # Instantiate a neighbor sampler
sampler = NeighborSampler( sampler = NeighborSampler(
[10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"] [10, 10, 10],
prefetch_node_feats=["feat"],
prefetch_labels=["label"],
fused=(args.mode != "benchmark"),
) )
train_dataloader = DataLoader( train_dataloader = DataLoader(
g, g,
...@@ -179,7 +185,7 @@ def train( ...@@ -179,7 +185,7 @@ def train(
batch_size=1024, batch_size=1024,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=0, num_workers=args.num_workers,
use_ddp=True, # To split the set for each process use_ddp=True, # To split the set for each process
use_uva=use_uva, use_uva=use_uva,
) )
...@@ -191,12 +197,12 @@ def train( ...@@ -191,12 +197,12 @@ def train(
batch_size=1024, batch_size=1024,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=0, num_workers=args.num_workers,
use_ddp=True, use_ddp=True,
use_uva=use_uva, use_uva=use_uva,
) )
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
for epoch in range(num_epochs): for epoch in range(args.num_epochs):
t0 = time.time() t0 = time.time()
model.train() model.train()
total_loss = 0 total_loss = 0
...@@ -224,7 +230,8 @@ def train( ...@@ -224,7 +230,8 @@ def train(
# for more information. # for more information.
##################################################################### #####################################################################
acc = ( acc = (
evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs evaluate(device, model, g, num_classes, val_dataloader).to(device)
/ nprocs
) )
t1 = time.time() t1 = time.time()
# Reduce `acc` tensors to process 0. # Reduce `acc` tensors to process 0.
...@@ -236,7 +243,7 @@ def train( ...@@ -236,7 +243,7 @@ def train(
) )
def run(proc_id, nprocs, devices, g, data, mode, num_epochs): def run(proc_id, nprocs, devices, g, data, args):
# Find corresponding device for current process. # Find corresponding device for current process.
device = devices[proc_id] device = devices[proc_id]
torch.cuda.set_device(device) torch.cuda.set_device(device)
...@@ -263,9 +270,10 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs): ...@@ -263,9 +270,10 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
rank=proc_id, rank=proc_id,
) )
num_classes, train_idx, val_idx, test_idx = data num_classes, train_idx, val_idx, test_idx = data
train_idx = train_idx.to(device) if args.mode != "benchmark":
val_idx = val_idx.to(device) train_idx = train_idx.to(device)
g = g.to(device if mode == "puregpu" else "cpu") val_idx = val_idx.to(device)
g = g.to(device if args.mode == "puregpu" else "cpu")
in_size = g.ndata["feat"].shape[1] in_size = g.ndata["feat"].shape[1]
model = SAGE(in_size, 256, num_classes).to(device) model = SAGE(in_size, 256, num_classes).to(device)
model = DistributedDataParallel( model = DistributedDataParallel(
...@@ -273,20 +281,21 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs): ...@@ -273,20 +281,21 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
) )
# Training. # Training.
use_uva = mode == "mixed" use_uva = args.mode == "mixed"
if proc_id == 0: if proc_id == 0:
print("Training...") print("Training...")
train( train(
proc_id, proc_id,
nprocs, nprocs,
device, device,
args,
g, g,
num_classes, num_classes,
train_idx, train_idx,
val_idx, val_idx,
model, model,
use_uva, use_uva,
num_epochs,
) )
# Testing. # Testing.
...@@ -303,7 +312,7 @@ if __name__ == "__main__": ...@@ -303,7 +312,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="mixed", default="mixed",
choices=["mixed", "puregpu"], choices=["mixed", "puregpu", "benchmark"],
help="Training mode. 'mixed' for CPU-GPU mixed training, " help="Training mode. 'mixed' for CPU-GPU mixed training, "
"'puregpu' for pure-GPU training.", "'puregpu' for pure-GPU training.",
) )
...@@ -317,7 +326,7 @@ if __name__ == "__main__": ...@@ -317,7 +326,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--num_epochs", "--num_epochs",
type=int, type=int,
default=20, default=10,
help="Number of epochs for train.", help="Number of epochs for train.",
) )
parser.add_argument( parser.add_argument(
...@@ -332,6 +341,12 @@ if __name__ == "__main__": ...@@ -332,6 +341,12 @@ if __name__ == "__main__":
default="dataset", default="dataset",
help="Root directory of dataset.", help="Root directory of dataset.",
) )
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="Number of workers",
)
args = parser.parse_args() args = parser.parse_args()
devices = list(map(int, args.gpu.split(","))) devices = list(map(int, args.gpu.split(",")))
nprocs = len(devices) nprocs = len(devices)
...@@ -364,6 +379,6 @@ if __name__ == "__main__": ...@@ -364,6 +379,6 @@ if __name__ == "__main__":
# To use DDP with n GPUs, spawn up n processes. # To use DDP with n GPUs, spawn up n processes.
mp.spawn( mp.spawn(
run, run,
args=(nprocs, devices, g, data, args.mode, args.num_epochs), args=(nprocs, devices, g, data, args),
nprocs=nprocs, nprocs=nprocs,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment