Unverified Commit 8909d1ff authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] refine the node_classification examples. (#7136)


Co-authored-by: default avatarMingbang Wang <100203018+Skeleton003@users.noreply.github.com>
parent 3e139033
...@@ -145,6 +145,23 @@ def create_dataloader( ...@@ -145,6 +145,23 @@ def create_dataloader(
return dataloader return dataloader
def weighted_reduce(tensor, weight, dst=0):
    ########################################################################
    # (HIGHLIGHT) Aggregate per-process metric values into one weighted
    # average on the destination rank.
    #
    # Each rank may have processed a different number of items, so a plain
    # mean of per-rank averages would be biased. We therefore sum both the
    # (already item-weighted) metric tensor and the item counts across all
    # ranks -- `torch.distributed.reduce` uses ReduceOp.SUM by default --
    # and divide the two sums.
    #
    # NOTE: per `torch.distributed.reduce` semantics, only rank `dst`
    # receives the reduced sums; the quotient on other ranks is not
    # meaningful.
    ########################################################################
    total_weight = torch.tensor(weight, device=tensor.device)
    dist.reduce(tensor=tensor, dst=dst)
    dist.reduce(tensor=total_weight, dst=dst)
    return tensor / total_weight
@torch.no_grad() @torch.no_grad()
def evaluate(rank, model, dataloader, num_classes, device): def evaluate(rank, model, dataloader, num_classes, device):
model.eval() model.eval()
...@@ -164,11 +181,10 @@ def evaluate(rank, model, dataloader, num_classes, device): ...@@ -164,11 +181,10 @@ def evaluate(rank, model, dataloader, num_classes, device):
num_classes=num_classes, num_classes=num_classes,
) )
return res.to(device) return res.to(device), sum(y_i.size(0) for y_i in y)
def train( def train(
world_size,
rank, rank,
args, args,
train_dataloader, train_dataloader,
...@@ -184,6 +200,7 @@ def train( ...@@ -184,6 +200,7 @@ def train(
model.train() model.train()
total_loss = torch.tensor(0, dtype=torch.float, device=device) total_loss = torch.tensor(0, dtype=torch.float, device=device)
num_train_items = 0
######################################################################## ########################################################################
# (HIGHLIGHT) Use Join Context Manager to solve uneven input problem. # (HIGHLIGHT) Use Join Context Manager to solve uneven input problem.
# #
...@@ -199,10 +216,8 @@ def train( ...@@ -199,10 +216,8 @@ def train(
# uneven inputs. # uneven inputs.
######################################################################## ########################################################################
with Join([model]): with Join([model]):
for step, data in ( for data in (
tqdm.tqdm(enumerate(train_dataloader)) tqdm.tqdm(train_dataloader) if rank == 0 else train_dataloader
if rank == 0
else enumerate(train_dataloader)
): ):
# The input features are from the source nodes in the first # The input features are from the source nodes in the first
# layer's computation graph. # layer's computation graph.
...@@ -223,35 +238,31 @@ def train( ...@@ -223,35 +238,31 @@ def train(
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss.detach() total_loss += loss.detach() * y.size(0)
num_train_items += y.size(0)
# Evaluate the model. # Evaluate the model.
if rank == 0: if rank == 0:
print("Validating...") print("Validating...")
acc = evaluate( acc, num_val_items = evaluate(
rank, rank,
model, model,
valid_dataloader, valid_dataloader,
num_classes, num_classes,
device, device,
) )
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
# obtain overall average values.
#
# `torch.distributed.reduce` is used to reduce tensors from all the
# sub-processes to a specified process, ReduceOp.SUM is used by default.
########################################################################
dist.reduce(tensor=acc, dst=0)
total_loss /= step + 1
dist.reduce(tensor=total_loss, dst=0)
total_loss = weighted_reduce(total_loss, num_train_items)
acc = weighted_reduce(acc * num_val_items, num_val_items)
# We synchronize before measuring the epoch time.
torch.cuda.synchronize()
epoch_end = time.time() epoch_end = time.time()
if rank == 0: if rank == 0:
print( print(
f"Epoch {epoch:05d} | " f"Epoch {epoch:05d} | "
f"Average Loss {total_loss.item() / world_size:.4f} | " f"Average Loss {total_loss.item():.4f} | "
f"Accuracy {acc.item() / world_size:.4f} | " f"Accuracy {acc.item():.4f} | "
f"Time {epoch_end - epoch_start:.4f}" f"Time {epoch_end - epoch_start:.4f}"
) )
...@@ -325,7 +336,6 @@ def run(rank, world_size, args, devices, dataset): ...@@ -325,7 +336,6 @@ def run(rank, world_size, args, devices, dataset):
if rank == 0: if rank == 0:
print("Training...") print("Training...")
train( train(
world_size,
rank, rank,
args, args,
train_dataloader, train_dataloader,
...@@ -338,18 +348,15 @@ def run(rank, world_size, args, devices, dataset): ...@@ -338,18 +348,15 @@ def run(rank, world_size, args, devices, dataset):
# Test the model. # Test the model.
if rank == 0: if rank == 0:
print("Testing...") print("Testing...")
test_acc = ( test_acc, num_test_items = evaluate(
evaluate( rank,
rank, model,
model, test_dataloader,
test_dataloader, num_classes,
num_classes, device,
device,
)
/ world_size
) )
dist.reduce(tensor=test_acc, dst=0) test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)
torch.cuda.synchronize()
if rank == 0: if rank == 0:
print(f"Test Accuracy {test_acc.item():.4f}") print(f"Test Accuracy {test_acc.item():.4f}")
...@@ -394,6 +401,14 @@ def parse_args(): ...@@ -394,6 +401,14 @@ def parse_args():
default=0, default=0,
help="The capacity of the GPU cache, the number of features to store.", help="The capacity of the GPU cache, the number of features to store.",
) )
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
help="The dataset we can use for node classification example. Currently"
" ogbn-products, ogbn-arxiv, ogbn-papers100M datasets are supported.",
)
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="pinned-cuda", default="pinned-cuda",
...@@ -417,7 +432,7 @@ if __name__ == "__main__": ...@@ -417,7 +432,7 @@ if __name__ == "__main__":
print(f"Training with {world_size} gpus.") print(f"Training with {world_size} gpus.")
# Load and preprocess dataset. # Load and preprocess dataset.
dataset = gb.BuiltinDataset("ogbn-products").load() dataset = gb.BuiltinDataset(args.dataset).load()
# Thread limiting to avoid resource competition. # Thread limiting to avoid resource competition.
os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size) os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)
......
...@@ -365,8 +365,9 @@ def parse_args(): ...@@ -365,8 +365,9 @@ def parse_args():
"--dataset", "--dataset",
type=str, type=str,
default="ogbn-products", default="ogbn-products",
choices=["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"],
help="The dataset we can use for node classification example. Currently" help="The dataset we can use for node classification example. Currently"
"dataset ogbn-products, ogbn-arxiv, ogbn-papers100M is supported.", " ogbn-products, ogbn-arxiv, ogbn-papers100M datasets are supported.",
) )
parser.add_argument( parser.add_argument(
"--mode", "--mode",
......
...@@ -186,7 +186,6 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device): ...@@ -186,7 +186,6 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device):
def train( def train(
world_size,
rank, rank,
graph, graph,
features, features,
...@@ -233,7 +232,7 @@ def train( ...@@ -233,7 +232,7 @@ def train(
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss * y.size(0) total_loss += loss.detach() * y.size(0)
num_train_items += y.size(0) num_train_items += y.size(0)
# Evaluate the model. # Evaluate the model.
...@@ -304,7 +303,6 @@ def run(rank, world_size, devices, dataset): ...@@ -304,7 +303,6 @@ def run(rank, world_size, devices, dataset):
if rank == 0: if rank == 0:
print("Training...") print("Training...")
train( train(
world_size,
rank, rank,
graph, graph,
features, features,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment