Unverified Commit 3f3652e0 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Polish the quickstart node classification tasks. (#6530)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 6dcdaf59
"""
[Semi-Supervised Classification with Graph Convolutional Networks]
(https://arxiv.org/abs/1609.02907)
This example shows how to create a GraphBolt dataloader to sample and train a
node classification model with the Cora dataset.
"""
import dgl.graphbolt as gb
import dgl.nn as dglnn
......@@ -15,7 +15,7 @@ import torchmetrics.functional as MF
############################################################################
def create_dataloader(dateset, itemset, device):
# Sample seed nodes from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=8)
datapipe = gb.ItemSampler(itemset, batch_size=16)
# Sample neighbors for the seed nodes.
datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])
......@@ -73,6 +73,7 @@ def evaluate(model, dataset, itemset, device):
def train(model, dataset, device):
# The first of two tasks in the dataset is node classification.
task = dataset.tasks[0]
dataloader = create_dataloader(dataset, task.train_set, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
......@@ -119,9 +120,12 @@ if __name__ == "__main__":
# Load and preprocess dataset.
print("Loading data...")
dataset = gb.OnDiskDataset(
"examples/sampling/graphbolt/quickstart/cora/"
).load()
dataset = gb.BuiltinDataset("cora").load()
# Uncomment to use the example cora dataset.
# dataset = gb.OnDiskDataset(
# "examples/sampling/graphbolt/quickstart/cora/"
# ).load()
in_size = dataset.feature.size("node", None, "feat")[0]
out_size = dataset.tasks[0].metadata["num_classes"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment