You need to sign in or sign up before continuing.
Unverified Commit 24de7391 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Modify multi-gpu example to make use of the persistent_workers (#6603)

parent c63a926d
...@@ -148,20 +148,10 @@ def create_dataloader( ...@@ -148,20 +148,10 @@ def create_dataloader(
@torch.no_grad() @torch.no_grad()
def evaluate(rank, args, model, graph, features, itemset, num_classes, device): def evaluate(rank, model, dataloader, num_classes, device):
model.eval() model.eval()
y = [] y = []
y_hats = [] y_hats = []
dataloader = create_dataloader(
args,
graph,
features,
itemset,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
device=device,
)
for step, data in ( for step, data in (
tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader) tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
...@@ -185,26 +175,13 @@ def train( ...@@ -185,26 +175,13 @@ def train(
world_size, world_size,
rank, rank,
args, args,
graph, train_dataloader,
features, valid_dataloader,
train_set,
valid_set,
num_classes, num_classes,
model, model,
device, device,
): ):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Create training data loader.
dataloader = create_dataloader(
args,
graph,
features,
train_set,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
)
for epoch in range(args.epochs): for epoch in range(args.epochs):
epoch_start = time.time() epoch_start = time.time()
...@@ -227,9 +204,9 @@ def train( ...@@ -227,9 +204,9 @@ def train(
######################################################################## ########################################################################
with Join([model]): with Join([model]):
for step, data in ( for step, data in (
tqdm.tqdm(enumerate(dataloader)) tqdm.tqdm(enumerate(train_dataloader))
if rank == 0 if rank == 0
else enumerate(dataloader) else enumerate(train_dataloader)
): ):
# The input features are from the source nodes in the first # The input features are from the source nodes in the first
# layer's computation graph. # layer's computation graph.
...@@ -258,11 +235,8 @@ def train( ...@@ -258,11 +235,8 @@ def train(
acc = ( acc = (
evaluate( evaluate(
rank, rank,
args,
model, model,
graph, valid_dataloader,
features,
valid_set,
num_classes, num_classes,
device, device,
) )
...@@ -305,6 +279,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -305,6 +279,7 @@ def run(rank, world_size, args, devices, dataset):
features = dataset.feature features = dataset.feature
train_set = dataset.tasks[0].train_set train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(","))) args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"] num_classes = dataset.tasks[0].metadata["num_classes"]
...@@ -316,6 +291,38 @@ def run(rank, world_size, args, devices, dataset): ...@@ -316,6 +291,38 @@ def run(rank, world_size, args, devices, dataset):
model = SAGE(in_size, hidden_size, out_size).to(device) model = SAGE(in_size, hidden_size, out_size).to(device)
model = DDP(model) model = DDP(model)
# Create data loaders.
train_dataloader = create_dataloader(
args,
graph,
features,
train_set,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
)
valid_dataloader = create_dataloader(
args,
graph,
features,
valid_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
)
test_dataloader = create_dataloader(
args,
graph,
features,
test_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
)
# Model training. # Model training.
if rank == 0: if rank == 0:
print("Training...") print("Training...")
...@@ -323,10 +330,8 @@ def run(rank, world_size, args, devices, dataset): ...@@ -323,10 +330,8 @@ def run(rank, world_size, args, devices, dataset):
world_size, world_size,
rank, rank,
args, args,
graph, train_dataloader,
features, valid_dataloader,
train_set,
valid_set,
num_classes, num_classes,
model, model,
device, device,
...@@ -335,17 +340,13 @@ def run(rank, world_size, args, devices, dataset): ...@@ -335,17 +340,13 @@ def run(rank, world_size, args, devices, dataset):
# Test the model. # Test the model.
if rank == 0: if rank == 0:
print("Testing...") print("Testing...")
test_set = dataset.tasks[0].test_set
test_acc = ( test_acc = (
evaluate( evaluate(
rank, rank,
args,
model, model,
graph, test_dataloader,
features, num_classes,
itemset=test_set, device,
num_classes=num_classes,
device=device,
) )
/ world_size / world_size
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment