"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "48eeb3dc79ac29c5d759186d646c2acb38763aff"
Unverified Commit 3f3652e0 authored by Hongzhi (Steve), Chen; committed by GitHub
Browse files

[Graphbolt] Polish the quickstart node classification tasks. (#6530)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent 6dcdaf59
""" """
[Semi-Supervised Classification with Graph Convolutional Networks] This example shows how to create a GraphBolt dataloader to sample and train a
(https://arxiv.org/abs/1609.02907) node classification model with the Cora dataset.
""" """
import dgl.graphbolt as gb import dgl.graphbolt as gb
import dgl.nn as dglnn import dgl.nn as dglnn
...@@ -15,7 +15,7 @@ import torchmetrics.functional as MF ...@@ -15,7 +15,7 @@ import torchmetrics.functional as MF
############################################################################ ############################################################################
def create_dataloader(dateset, itemset, device): def create_dataloader(dateset, itemset, device):
# Sample seed nodes from the itemset. # Sample seed nodes from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=8) datapipe = gb.ItemSampler(itemset, batch_size=16)
# Sample neighbors for the seed nodes. # Sample neighbors for the seed nodes.
datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2]) datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])
...@@ -73,6 +73,7 @@ def evaluate(model, dataset, itemset, device): ...@@ -73,6 +73,7 @@ def evaluate(model, dataset, itemset, device):
def train(model, dataset, device): def train(model, dataset, device):
# The first of two tasks in the dataset is node classification.
task = dataset.tasks[0] task = dataset.tasks[0]
dataloader = create_dataloader(dataset, task.train_set, device) dataloader = create_dataloader(dataset, task.train_set, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
...@@ -119,9 +120,12 @@ if __name__ == "__main__": ...@@ -119,9 +120,12 @@ if __name__ == "__main__":
# Load and preprocess dataset. # Load and preprocess dataset.
print("Loading data...") print("Loading data...")
dataset = gb.OnDiskDataset( dataset = gb.BuiltinDataset("cora").load()
"examples/sampling/graphbolt/quickstart/cora/"
).load() # Uncomment to use the example cora dataset.
# dataset = gb.OnDiskDataset(
# "examples/sampling/graphbolt/quickstart/cora/"
# ).load()
in_size = dataset.feature.size("node", None, "feat")[0] in_size = dataset.feature.size("node", None, "feat")[0]
out_size = dataset.tasks[0].metadata["num_classes"] out_size = dataset.tasks[0].metadata["num_classes"]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment