".github/vscode:/vscode.git/clone" did not exist on "f79c8e984e7544ee6f457951980e4e6656caeb94"
Unverified commit b452043c authored by Mingbang Wang, committed by GitHub

[Misc] Add compare-to-graphbolt mode for regression test (#6569)

parent 23649071
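In short, the patch threads a fused_sampling flag from the command line down to DGL's NeighborSampler / MultiLayerFullNeighborSampler: the existing cpu, mixed, and gpu modes keep fused sampling enabled, while the new compare-to-graphbolt mode disables it, presumably so that accuracy and timing can be compared against GraphBolt-based runs in the regression test. A minimal, self-contained sketch of the non-fused path is given after the diff.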
@@ -75,7 +75,7 @@ class SAGE(nn.Module):
                 hidden_x = self.dropout(hidden_x)
         return hidden_x
 
-    def inference(self, g, device, batch_size):
+    def inference(self, g, device, batch_size, fused_sampling: bool = True):
         """Conduct layer-wise inference to get all the node embeddings."""
         feat = g.ndata["feat"]
         #####################################################################
@@ -109,7 +109,9 @@ class SAGE(nn.Module):
         # │          │          │
         # └─Compute1 └─Compute2 └─Compute3
         #####################################################################
-        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
+        sampler = MultiLayerFullNeighborSampler(
+            1, prefetch_node_feats=["feat"], fused=fused_sampling
+        )
         dataloader = DataLoader(
             g,
@@ -167,18 +169,22 @@ def evaluate(model, graph, dataloader, num_classes):
 
 @torch.no_grad()
-def layerwise_infer(device, graph, nid, model, num_classes, batch_size):
+def layerwise_infer(
+    device, graph, nid, model, num_classes, batch_size, fused_sampling
+):
     model.eval()
-    pred = model.inference(graph, device, batch_size)  # pred in buffer_device.
+    pred = model.inference(
+        graph, device, batch_size, fused_sampling
+    )  # pred in buffer_device.
     pred = pred[nid]
     label = graph.ndata["label"][nid].to(pred.device)
     return MF.accuracy(pred, label, task="multiclass", num_classes=num_classes)
 
 
-def train(args, device, g, dataset, model, num_classes, use_uva):
+def train(device, g, dataset, model, num_classes, use_uva, fused_sampling):
     # Create sampler & dataloader.
-    train_idx = dataset.train_idx.to(device)
-    val_idx = dataset.val_idx.to(device)
+    train_idx = dataset.train_idx.to(g.device if not use_uva else device)
+    val_idx = dataset.val_idx.to(g.device if not use_uva else device)
     #####################################################################
     # (HIGHLIGHT) Instantiate a NeighborSampler object for efficient
     # training of Graph Neural Networks (GNNs) on large-scale graphs.
@@ -197,6 +203,7 @@ def train(args, device, g, dataset, model, num_classes, use_uva):
         [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
         prefetch_node_feats=["feat"],
         prefetch_labels=["label"],
+        fused=fused_sampling,
     )
     train_dataloader = DataLoader(
@@ -267,7 +274,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--mode",
         default="mixed",
-        choices=["cpu", "mixed", "gpu"],
+        choices=["cpu", "mixed", "gpu", "compare-to-graphbolt"],
         help="Training mode. 'cpu' for CPU training, 'mixed' for "
         "CPU-GPU mixed training, 'gpu' for pure-GPU training.",
     )
@@ -285,6 +292,7 @@ if __name__ == "__main__":
     # Whether use Unified Virtual Addressing (UVA) for CUDA computation.
     use_uva = args.mode == "mixed"
     device = torch.device("cpu" if args.mode == "cpu" else "cuda")
+    fused_sampling = args.mode != "compare-to-graphbolt"
 
     # Create GraphSAGE model.
     in_size = g.ndata["feat"].shape[1]
@@ -293,11 +301,17 @@ if __name__ == "__main__":
     # Model training.
     print("Training...")
-    train(args, device, g, dataset, model, num_classes, use_uva)
+    train(device, g, dataset, model, num_classes, use_uva, fused_sampling)
 
     # Test the model.
     print("Testing...")
     acc = layerwise_infer(
-        device, g, dataset.test_idx, model, num_classes, batch_size=4096
+        device,
+        g,
+        dataset.test_idx,
+        model,
+        num_classes,
+        batch_size=4096,
+        fused_sampling=fused_sampling,
     )
     print(f"Test accuracy {acc.item():.4f}")