Unverified commit 2da6acef, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] Refine the multi-GPU example (#6980)

parent b3841c25
...@@ -89,9 +89,7 @@ def create_dataloader( ...@@ -89,9 +89,7 @@ def create_dataloader(
features, features,
itemset, itemset,
device, device,
drop_last=False, is_train,
shuffle=True,
drop_uneven_inputs=False,
): ):
############################################################################ ############################################################################
# [HIGHLIGHT] # [HIGHLIGHT]
...@@ -122,9 +120,9 @@ def create_dataloader( ...@@ -122,9 +120,9 @@ def create_dataloader(
datapipe = gb.DistributedItemSampler( datapipe = gb.DistributedItemSampler(
item_set=itemset, item_set=itemset,
batch_size=args.batch_size, batch_size=args.batch_size,
drop_last=drop_last, drop_last=is_train,
shuffle=shuffle, shuffle=is_train,
drop_uneven_inputs=drop_uneven_inputs, drop_uneven_inputs=is_train,
) )
############################################################################ ############################################################################
# [Note]: # [Note]:
...@@ -190,7 +188,7 @@ def train( ...@@ -190,7 +188,7 @@ def train(
epoch_start = time.time() epoch_start = time.time()
model.train() model.train()
total_loss = torch.tensor(0, dtype=torch.float).to(device) total_loss = torch.tensor(0, dtype=torch.float, device=device)
######################################################################## ########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem. # (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
# #
...@@ -230,20 +228,17 @@ def train( ...@@ -230,20 +228,17 @@ def train(
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss total_loss += loss.detach()
# Evaluate the model. # Evaluate the model.
if rank == 0: if rank == 0:
print("Validating...") print("Validating...")
acc = ( acc = evaluate(
evaluate( rank,
rank, model,
model, valid_dataloader,
valid_dataloader, num_classes,
num_classes, device,
device,
)
/ world_size
) )
######################################################################## ########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
...@@ -255,14 +250,13 @@ def train( ...@@ -255,14 +250,13 @@ def train(
dist.reduce(tensor=acc, dst=0) dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1 total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0) dist.reduce(tensor=total_loss, dst=0)
dist.barrier()
epoch_end = time.time() epoch_end = time.time()
if rank == 0: if rank == 0:
print( print(
f"Epoch {epoch:05d} | " f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | " f"Average Loss {total_loss.item() / world_size:.4f} | "
f"Accuracy {acc.item():.4f} | " f"Accuracy {acc.item() / world_size:.4f} | "
f"Time {epoch_end - epoch_start:.4f}" f"Time {epoch_end - epoch_start:.4f}"
) )
...@@ -304,9 +298,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -304,9 +298,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature, dataset.feature,
train_set, train_set,
device, device,
drop_last=False, is_train=True,
shuffle=True,
drop_uneven_inputs=False,
) )
valid_dataloader = create_dataloader( valid_dataloader = create_dataloader(
args, args,
...@@ -314,9 +306,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -314,9 +306,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature, dataset.feature,
valid_set, valid_set,
device, device,
drop_last=False, is_train=False,
shuffle=False,
drop_uneven_inputs=False,
) )
test_dataloader = create_dataloader( test_dataloader = create_dataloader(
args, args,
...@@ -324,9 +314,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -324,9 +314,7 @@ def run(rank, world_size, args, devices, dataset):
dataset.feature, dataset.feature,
test_set, test_set,
device, device,
drop_last=False, is_train=False,
shuffle=False,
drop_uneven_inputs=False,
) )
# Model training. # Model training.
...@@ -357,7 +345,7 @@ def run(rank, world_size, args, devices, dataset): ...@@ -357,7 +345,7 @@ def run(rank, world_size, args, devices, dataset):
/ world_size / world_size
) )
dist.reduce(tensor=test_acc, dst=0) dist.reduce(tensor=test_acc, dst=0)
dist.barrier() torch.cuda.synchronize()
if rank == 0: if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}") print(f"Test Accuracy {test_acc.item():.4f}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment