Unverified Commit 24de7391 authored by Ramon Zhou, committed by GitHub

[GraphBolt] Modify multi-gpu example to make use of the persistent_workers (#6603)

parent c63a926d
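
Why this change helps: a dataloader's persistent_workers option keeps its worker processes alive between epochs, which only pays off if the same dataloader object is reused. This commit therefore creates the train/validation/test dataloaders once in run() and passes them into train() and evaluate() instead of rebuilding them on every call. The snippet below is a minimal sketch of that behavior using the generic torch.utils.data.DataLoader flag; it is illustrative only and does not reproduce the GraphBolt create_dataloader helper or its exact arguments.

# Minimal sketch (plain PyTorch, not the GraphBolt helper from this diff)
# of why the example now builds each dataloader once and reuses it.
import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    dataset = TensorDataset(torch.arange(1000))
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=True,
        num_workers=2,
        persistent_workers=True,  # only valid when num_workers > 0
    )
    for epoch in range(3):
        # The same loader object is iterated every epoch, so its worker
        # processes are spawned once and reused instead of being recreated
        # per epoch (or per train()/evaluate() call, as in the old example).
        for batch in loader:
            pass


if __name__ == "__main__":
    main()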
@@ -148,20 +148,10 @@ def create_dataloader(
 @torch.no_grad()
-def evaluate(rank, args, model, graph, features, itemset, num_classes, device):
+def evaluate(rank, model, dataloader, num_classes, device):
     model.eval()
     y = []
     y_hats = []
-    dataloader = create_dataloader(
-        args,
-        graph,
-        features,
-        itemset,
-        drop_last=False,
-        shuffle=False,
-        drop_uneven_inputs=False,
-        device=device,
-    )
     for step, data in (
         tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
@@ -185,26 +175,13 @@ def train(
     world_size,
     rank,
     args,
-    graph,
-    features,
-    train_set,
-    valid_set,
+    train_dataloader,
+    valid_dataloader,
     num_classes,
     model,
     device,
 ):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-    # Create training data loader.
-    dataloader = create_dataloader(
-        args,
-        graph,
-        features,
-        train_set,
-        device,
-        drop_last=False,
-        shuffle=True,
-        drop_uneven_inputs=False,
-    )
     for epoch in range(args.epochs):
         epoch_start = time.time()
@@ -227,9 +204,9 @@ def train(
         ########################################################################
         with Join([model]):
             for step, data in (
-                tqdm.tqdm(enumerate(dataloader))
+                tqdm.tqdm(enumerate(train_dataloader))
                 if rank == 0
-                else enumerate(dataloader)
+                else enumerate(train_dataloader)
             ):
                 # The input features are from the source nodes in the first
                 # layer's computation graph.
@@ -258,11 +235,8 @@ def train(
         acc = (
             evaluate(
                 rank,
-                args,
                 model,
-                graph,
-                features,
-                valid_set,
+                valid_dataloader,
                 num_classes,
                 device,
             )
@@ -305,6 +279,7 @@ def run(rank, world_size, args, devices, dataset):
     features = dataset.feature
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
+    test_set = dataset.tasks[0].test_set
     args.fanout = list(map(int, args.fanout.split(",")))
     num_classes = dataset.tasks[0].metadata["num_classes"]
@@ -316,6 +291,38 @@ def run(rank, world_size, args, devices, dataset):
     model = SAGE(in_size, hidden_size, out_size).to(device)
     model = DDP(model)
+    # Create data loaders.
+    train_dataloader = create_dataloader(
+        args,
+        graph,
+        features,
+        train_set,
+        device,
+        drop_last=False,
+        shuffle=True,
+        drop_uneven_inputs=False,
+    )
+    valid_dataloader = create_dataloader(
+        args,
+        graph,
+        features,
+        valid_set,
+        device,
+        drop_last=False,
+        shuffle=False,
+        drop_uneven_inputs=False,
+    )
+    test_dataloader = create_dataloader(
+        args,
+        graph,
+        features,
+        test_set,
+        device,
+        drop_last=False,
+        shuffle=False,
+        drop_uneven_inputs=False,
+    )
     # Model training.
     if rank == 0:
         print("Training...")
@@ -323,10 +330,8 @@ def run(rank, world_size, args, devices, dataset):
         world_size,
         rank,
         args,
-        graph,
-        features,
-        train_set,
-        valid_set,
+        train_dataloader,
+        valid_dataloader,
         num_classes,
         model,
         device,
@@ -335,17 +340,13 @@ def run(rank, world_size, args, devices, dataset):
     # Test the model.
     if rank == 0:
         print("Testing...")
-    test_set = dataset.tasks[0].test_set
     test_acc = (
         evaluate(
             rank,
-            args,
             model,
-            graph,
-            features,
-            itemset=test_set,
-            num_classes=num_classes,
-            device=device,
+            test_dataloader,
+            num_classes,
+            device,
         )
         / world_size
     )