Unverified Commit 2da6acef authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Refine the multi-GPU example (#6980)

parent b3841c25
......@@ -89,9 +89,7 @@ def create_dataloader(
features,
itemset,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train,
):
############################################################################
# [HIGHLIGHT]
......@@ -122,9 +120,9 @@ def create_dataloader(
datapipe = gb.DistributedItemSampler(
item_set=itemset,
batch_size=args.batch_size,
drop_last=drop_last,
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
drop_last=is_train,
shuffle=is_train,
drop_uneven_inputs=is_train,
)
############################################################################
# [Note]:
......@@ -190,7 +188,7 @@ def train(
epoch_start = time.time()
model.train()
total_loss = torch.tensor(0, dtype=torch.float).to(device)
total_loss = torch.tensor(0, dtype=torch.float, device=device)
########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
#
......@@ -230,21 +228,18 @@ def train(
loss.backward()
optimizer.step()
total_loss += loss
total_loss += loss.detach()
# Evaluate the model.
if rank == 0:
print("Validating...")
acc = (
evaluate(
acc = evaluate(
rank,
model,
valid_dataloader,
num_classes,
device,
)
/ world_size
)
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
# obtain overall average values.
......@@ -255,14 +250,13 @@ def train(
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
dist.barrier()
epoch_end = time.time()
if rank == 0:
print(
f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item():.4f} | "
f"Accuracy {acc.item() / world_size:.4f} | "
f"Time {epoch_end - epoch_start:.4f}"
)
......@@ -304,9 +298,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature,
train_set,
device,
drop_last=False,
shuffle=True,
drop_uneven_inputs=False,
is_train=True,
)
valid_dataloader = create_dataloader(
args,
......@@ -314,9 +306,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature,
valid_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)
test_dataloader = create_dataloader(
args,
......@@ -324,9 +314,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature,
test_set,
device,
drop_last=False,
shuffle=False,
drop_uneven_inputs=False,
is_train=False,
)
# Model training.
......@@ -357,7 +345,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size
)
dist.reduce(tensor=test_acc, dst=0)
dist.barrier()
torch.cuda.synchronize()
if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment