Unverified Commit b452043c authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] Add compare-to-graphbolt mode for regression test (#6569)

parent 23649071
...@@ -75,7 +75,7 @@ class SAGE(nn.Module): ...@@ -75,7 +75,7 @@ class SAGE(nn.Module):
hidden_x = self.dropout(hidden_x) hidden_x = self.dropout(hidden_x)
return hidden_x return hidden_x
def inference(self, g, device, batch_size): def inference(self, g, device, batch_size, fused_sampling: bool = True):
"""Conduct layer-wise inference to get all the node embeddings.""" """Conduct layer-wise inference to get all the node embeddings."""
feat = g.ndata["feat"] feat = g.ndata["feat"]
##################################################################### #####################################################################
...@@ -109,7 +109,9 @@ class SAGE(nn.Module): ...@@ -109,7 +109,9 @@ class SAGE(nn.Module):
# │ │ │ # │ │ │
# └─Compute1 └─Compute2 └─Compute3 # └─Compute1 └─Compute2 └─Compute3
##################################################################### #####################################################################
sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"]) sampler = MultiLayerFullNeighborSampler(
1, prefetch_node_feats=["feat"], fused=fused_sampling
)
dataloader = DataLoader( dataloader = DataLoader(
g, g,
...@@ -167,18 +169,22 @@ def evaluate(model, graph, dataloader, num_classes): ...@@ -167,18 +169,22 @@ def evaluate(model, graph, dataloader, num_classes):
@torch.no_grad() @torch.no_grad()
def layerwise_infer(device, graph, nid, model, num_classes, batch_size): def layerwise_infer(
device, graph, nid, model, num_classes, batch_size, fused_sampling
):
model.eval() model.eval()
pred = model.inference(graph, device, batch_size) # pred in buffer_device. pred = model.inference(
graph, device, batch_size, fused_sampling
) # pred in buffer_device.
pred = pred[nid] pred = pred[nid]
label = graph.ndata["label"][nid].to(pred.device) label = graph.ndata["label"][nid].to(pred.device)
return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes) return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes)
def train(args, device, g, dataset, model, num_classes, use_uva): def train(device, g, dataset, model, num_classes, use_uva, fused_sampling):
# Create sampler & dataloader. # Create sampler & dataloader.
train_idx = dataset.train_idx.to(device) train_idx = dataset.train_idx.to(g.device if not use_uva else device)
val_idx = dataset.val_idx.to(device) val_idx = dataset.val_idx.to(g.device if not use_uva else device)
##################################################################### #####################################################################
# (HIGHLIGHT) Instantiate a NeighborSampler object for efficient # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
# training of Graph Neural Networks (GNNs) on large-scale graphs. # training of Graph Neural Networks (GNNs) on large-scale graphs.
...@@ -197,6 +203,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva): ...@@ -197,6 +203,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
[10, 10, 10], # fanout for [layer-0, layer-1, layer-2] [10, 10, 10], # fanout for [layer-0, layer-1, layer-2]
prefetch_node_feats=["feat"], prefetch_node_feats=["feat"],
prefetch_labels=["label"], prefetch_labels=["label"],
fused=fused_sampling,
) )
train_dataloader = DataLoader( train_dataloader = DataLoader(
...@@ -267,7 +274,7 @@ if __name__ == "__main__": ...@@ -267,7 +274,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="mixed", default="mixed",
choices=["cpu", "mixed", "gpu"], choices=["cpu", "mixed", "gpu", "compare-to-graphbolt"],
help="Training mode. 'cpu' for CPU training, 'mixed' for " help="Training mode. 'cpu' for CPU training, 'mixed' for "
"CPU-GPU mixed training, 'gpu' for pure-GPU training.", "CPU-GPU mixed training, 'gpu' for pure-GPU training.",
) )
...@@ -285,6 +292,7 @@ if __name__ == "__main__": ...@@ -285,6 +292,7 @@ if __name__ == "__main__":
# Whether use Unified Virtual Addressing (UVA) for CUDA computation. # Whether use Unified Virtual Addressing (UVA) for CUDA computation.
use_uva = args.mode == "mixed" use_uva = args.mode == "mixed"
device = torch.device("cpu" if args.mode == "cpu" else "cuda") device = torch.device("cpu" if args.mode == "cpu" else "cuda")
fused_sampling = args.mode != "compare-to-graphbolt"
# Create GraphSAGE model. # Create GraphSAGE model.
in_size = g.ndata["feat"].shape[1] in_size = g.ndata["feat"].shape[1]
...@@ -293,11 +301,17 @@ if __name__ == "__main__": ...@@ -293,11 +301,17 @@ if __name__ == "__main__":
# Model training. # Model training.
print("Training...") print("Training...")
train(args, device, g, dataset, model, num_classes, use_uva) train(device, g, dataset, model, num_classes, use_uva, fused_sampling)
# Test the model. # Test the model.
print("Testing...") print("Testing...")
acc = layerwise_infer( acc = layerwise_infer(
device, g, dataset.test_idx, model, num_classes, batch_size=4096 device,
g,
dataset.test_idx,
model,
num_classes,
batch_size=4096,
fused_sampling=fused_sampling,
) )
print(f"Test accuracy {acc.item():.4f}") print(f"Test accuracy {acc.item():.4f}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment