OpenDAS / dgl · Commits · 2550ac5c

Commit 2550ac5c, authored Mar 01, 2024 by Ramon Zhou; committed by RhettYing on Mar 01, 2024.

[GraphBolt][PyG] Modify PyG example with `to_pyg_data` (#7123)

Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>

Parent: 9d7fe9d3

Showing 1 changed file with 134 additions and 89 deletions:
examples/sampling/pyg/node_classification.py (+134, -89)
"""
"""
This script demonstrates node classification with GraphSAGE on large graphs,
This script demonstrates node classification with GraphSAGE on large graphs,
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently manages
merging GraphBolt (GB) and PyTorch Geometric (PyG). GraphBolt efficiently
data loading for large datasets, crucial for mini-batch processing. Post data
manages data loading for large datasets, crucial for mini-batch processing.
loading, PyG's user-friendly framework takes over for training, showcasing seamless
Post data loading, PyG's user-friendly framework takes over for training,
integration with GraphBolt. This combination offers an efficient alternative to
showcasing seamless integration with GraphBolt. This combination offers an
traditional Deep Graph Library (DGL) methods, highlighting adaptability and
efficient alternative to traditional Deep Graph Library (DGL) methods,
scalability in handling large-scale graph data for diverse real-world applications.
highlighting adaptability and scalability in handling large-scale graph data
for diverse real-world applications.
Key Features:
Key Features:
- Implements the GraphSAGE model, a scalable GNN, for node classification on large graphs.
- Implements the GraphSAGE model, a scalable GNN, for node classification on
large graphs.
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Utilizes GraphBolt, an efficient framework for large-scale graph data processing.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- Integrates with PyTorch Geometric for building and training the GraphSAGE model.
- The script is well-documented, providing clear explanations at each step.
- The script is well-documented, providing clear explanations at each step.
...
@@ -38,6 +38,8 @@ main
 │     │
 │     ├───> Forward and backward passes
 │     │
+│     ├───> Convert GraphBolt MiniBatch to PyG Data
+│     │
 │     └───> Parameters optimization
 │
 └───> Evaluate the model
...
@@ -56,6 +58,7 @@ import torch
 import torch.nn.functional as F
 import torchmetrics.functional as MF
 from torch_geometric.nn import SAGEConv
+from tqdm import tqdm


 class GraphSAGE(torch.nn.Module):
...
@@ -67,6 +70,8 @@ class GraphSAGE(torch.nn.Module):
     # - 'in_size', 'hidden_size', 'out_size' are the sizes of
     #   the input, hidden, and output features, respectively.
     # - The forward method defines the computation performed at every call.
+    # - It's adopted from the official PyG example which can be found at
+    #   https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py
     #####################################################################
     def __init__(self, in_size, hidden_size, out_size):
         super(GraphSAGE, self).__init__()
...
@@ -75,87 +80,83 @@ class GraphSAGE(torch.nn.Module):
         self.layers.append(SAGEConv(hidden_size, hidden_size))
         self.layers.append(SAGEConv(hidden_size, out_size))

-    def forward(self, blocks, x, device):
-        h = x
-        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
-            src, dst = block.edges()
-            edge_index = torch.stack([src, dst], dim=0)
-            h_src, h_dst = h, h[: block.number_of_dst_nodes()]
-            h = layer((h_src, h_dst), edge_index)
-            if i != len(blocks) - 1:
-                h = F.relu(h)
-        return h
+    def forward(self, x, edge_index):
+        for i, layer in enumerate(self.layers):
+            x = layer(x, edge_index)
+            if i != len(self.layers) - 1:
+                x = x.relu()
+                x = F.dropout(x, p=0.5, training=self.training)
+        return x
+
+    def inference(self, args, dataloader, x_all, device):
+        """Conduct layer-wise inference to get all the node embeddings."""
+        for i, layer in tqdm(enumerate(self.layers), "inference"):
+            xs = []
+            for minibatch in dataloader:
+                # Call `to_pyg_data` to convert GB Minibatch to PyG Data.
+                pyg_data = minibatch.to_pyg_data()
+                n_ids = minibatch.node_ids().to("cpu")
+                x = x_all[n_ids].to(device)
+                edge_index = pyg_data.edge_index
+                x = layer(x, edge_index)
+                x = x[: 4 * args.batch_size]
+                if i != len(self.layers) - 1:
+                    x = x.relu()
+                xs.append(x.cpu())
+            x_all = torch.cat(xs, dim=0)
+        return x_all


-def create_dataloader(dataset_set, graph, feature, device, is_train):
-    #####################################################################
-    # (HIGHLIGHT) Create a data loader for efficiently loading graph data.
-    #
-    # - 'ItemSampler' samples mini-batches of node IDs from the dataset.
-    # - 'sample_neighbor' performs neighbor sampling on the graph.
-    # - 'FeatureFetcher' fetches node features based on the sampled subgraph.
-    # - 'CopyTo' copies the fetched data to the specified device.
-    #####################################################################
-    # Create a datapipe for mini-batch sampling with a specific neighbor fanout.
-    # Here, [10, 10, 10] specifies the number of neighbors sampled for each node at each layer.
-    # We're using `sample_neighbor` for consistency with DGL's sampling API.
-    # Note: GraphBolt offers additional sampling methods, such as `sample_layer_neighbor`,
-    # which could provide further optimization and efficiency for GNN training.
-    # Users are encouraged to explore these advanced features for potentially improved performance.
+def create_dataloader(
+    dataset_set, graph, feature, batch_size, fanout, device, job
+):
     # Initialize an ItemSampler to sample mini-batches from the dataset.
     datapipe = gb.ItemSampler(
-        dataset_set, batch_size=1024, shuffle=is_train, drop_last=is_train
+        dataset_set,
+        batch_size=batch_size,
+        shuffle=(job == "train"),
+        drop_last=(job == "train"),
     )
     # Sample neighbors for each node in the mini-batch.
-    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
+    datapipe = datapipe.sample_neighbor(
+        graph, fanout if job != "infer" else [-1]
+    )
+    # Copy the data to the specified device.
+    datapipe = datapipe.copy_to(device=device, extra_attrs=["input_nodes"])
     # Fetch node features for the sampled subgraph.
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
-    # Copy the data to the specified device.
-    datapipe = datapipe.copy_to(device=device)
     # Create and return a DataLoader to handle data loading.
     dataloader = gb.DataLoader(datapipe, num_workers=0)
     return dataloader


-def train(model, dataloader, optimizer, criterion, device, num_classes):
-    #####################################################################
-    # (HIGHLIGHT) Train the model for one epoch.
-    #
-    # - Iterates over the data loader, fetching mini-batches of graph data.
-    # - For each mini-batch, it performs a forward pass, computes loss, and
-    #   updates the model parameters.
-    # - The function returns the average loss and accuracy for the epoch.
-    #
-    # Parameters:
-    #   model: The GraphSAGE model.
-    #   dataloader: DataLoader that provides mini-batches of graph data.
-    #   optimizer: Optimizer used for updating model parameters.
-    #   criterion: Loss function used for training.
-    #   device: The device (CPU/GPU) to run the training on.
-    #####################################################################
+def train(model, dataloader, optimizer):
     model.train()  # Set the model to training mode
     total_loss = 0  # Accumulator for the total loss
     total_correct = 0  # Accumulator for the total number of correct predictions
     total_samples = 0  # Accumulator for the total number of samples processed
     num_batches = 0  # Counter for the number of mini-batches processed
-    for minibatch in dataloader:
-        node_features = minibatch.node_features["feat"]
-        labels = minibatch.labels
+    for _, minibatch in tqdm(enumerate(dataloader), "training"):
+        #####################################################################
+        # (HIGHLIGHT) Convert GraphBolt MiniBatch to PyG Data class.
+        #
+        # Call `MiniBatch.to_pyg_data()` and it will return a PyG Data class
+        # with necessary data and information.
+        #####################################################################
+        pyg_data = minibatch.to_pyg_data()
         optimizer.zero_grad()
-        out = model(minibatch.blocks, node_features, device)
-        loss = criterion(out, labels)
-        total_loss += loss.item()
-        total_correct += MF.accuracy(
-            out, labels, task="multiclass", num_classes=num_classes
-        ) * labels.size(0)
-        total_samples += labels.size(0)
+        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
+        y = pyg_data.y
+        loss = F.cross_entropy(out, y)
         loss.backward()
         optimizer.step()
+        total_loss += float(loss)
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
+        total_samples += y.shape[0]
         num_batches += 1
     avg_loss = total_loss / num_batches
     avg_accuracy = total_correct / total_samples
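
Note on the conversion: judging from how the new code uses the returned object, `to_pyg_data()` packs the sampled subgraph into a PyG `Data` whose `x` holds the fetched input features, `edge_index` the sampled edges, and `y` the seed-node labels; the slicing `[: pyg_data.y.shape[0]]` suggests the seed (labeled) nodes occupy the first rows of the model output. A minimal sketch of one training step under that assumption (a distillation of the loop above, not part of the diff):

import torch.nn.functional as F


def train_step(model, minibatch, optimizer):
    # `minibatch` is assumed to be a GraphBolt MiniBatch yielded by the
    # dataloader; `to_pyg_data()` is the conversion this commit introduces.
    pyg_data = minibatch.to_pyg_data()  # PyG Data with x, edge_index, y
    optimizer.zero_grad()
    # Assumption from the diff: seed nodes come first, so predictions are
    # truncated to the first y.shape[0] rows before computing the loss.
    out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
    loss = F.cross_entropy(out, pyg_data.y)
    loss.backward()
    optimizer.step()
    return float(loss)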
...
@@ -163,16 +164,16 @@ def train(model, dataloader, optimizer, criterion, device, num_classes):
 @torch.no_grad()
-def evaluate(model, dataloader, device, num_classes):
+def evaluate(model, dataloader, num_classes):
     model.eval()
     y_hats = []
     ys = []
-    for minibatch in dataloader:
-        node_features = minibatch.node_features["feat"]
-        labels = minibatch.labels
-        out = model(minibatch.blocks, node_features, device)
+    for _, minibatch in tqdm(enumerate(dataloader), "evaluating"):
+        pyg_data = minibatch.to_pyg_data()
+        out = model(pyg_data.x, pyg_data.edge_index)[: pyg_data.y.shape[0]]
+        y = pyg_data.y
         y_hats.append(out)
-        ys.append(labels)
+        ys.append(y)

     return MF.accuracy(
         torch.cat(y_hats),
...
@@ -182,6 +183,24 @@ def evaluate(model, dataloader, device, num_classes):
     )


+@torch.no_grad()
+def layerwise_infer(
+    model, args, infer_dataloader, test_set, feature, num_classes, device
+):
+    model.eval()
+    features = feature.read("node", None, "feat")
+    pred = model.inference(args, infer_dataloader, features, device)
+    pred = pred[test_set._items[0]]
+    label = test_set._items[1].to(pred.device)
+    return MF.accuracy(
+        pred,
+        label,
+        task="multiclass",
+        num_classes=num_classes,
+    )
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Which dataset are you going to use?"
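
The new `layerwise_infer` (together with the `inference` method added above) scores the test split from embeddings computed one layer at a time over all nodes, which sidesteps the multiplicative neighbor blow-up of sampling a full multi-hop neighborhood per test node. A minimal sketch of that pattern with hypothetical names (`n_ids`, `num_seeds`, and `make_one_hop_loader` are illustrative, not the GraphBolt API):

import torch


@torch.no_grad()
def layerwise_embeddings(layers, x_all, make_one_hop_loader):
    # Apply each layer to every node before moving to the next layer, so
    # only one-hop neighborhoods are ever materialized.
    for i, layer in enumerate(layers):
        outs = []
        for batch in make_one_hop_loader():
            h = layer(x_all[batch.n_ids], batch.edge_index)
            if i != len(layers) - 1:  # no activation after the last layer
                h = h.relu()
            outs.append(h[: batch.num_seeds].cpu())  # keep seed rows only
        x_all = torch.cat(outs, dim=0)  # input features for the next layer
    return x_all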
@@ -189,45 +208,71 @@ def main():
...
@@ -189,45 +208,71 @@ def main():
parser
.
add_argument
(
parser
.
add_argument
(
"--dataset"
,
"--dataset"
,
type
=
str
,
type
=
str
,
default
=
"ogbn-
arxiv
"
,
default
=
"ogbn-
products
"
,
help
=
'Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")'
,
help
=
'Name of the dataset to use (e.g., "ogbn-products", "ogbn-arxiv")'
,
)
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
10
,
help
=
"Number of training epochs."
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
1024
,
help
=
"Batch size for training."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
dataset_name
=
args
.
dataset
dataset_name
=
args
.
dataset
dataset
=
gb
.
BuiltinDataset
(
dataset_name
).
load
()
dataset
=
gb
.
BuiltinDataset
(
dataset_name
).
load
()
graph
=
dataset
.
graph
graph
=
dataset
.
graph
feature
=
dataset
.
feature
feature
=
dataset
.
feature
.
pin_memory_
()
train_set
=
dataset
.
tasks
[
0
].
train_set
train_set
=
dataset
.
tasks
[
0
].
train_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
test_set
=
dataset
.
tasks
[
0
].
test_set
test_set
=
dataset
.
tasks
[
0
].
test_set
all_nodes_set
=
dataset
.
all_nodes_set
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
train_dataloader
=
create_dataloader
(
train_dataloader
=
create_dataloader
(
train_set
,
graph
,
feature
,
device
,
is_train
=
True
train_set
,
graph
,
feature
,
args
.
batch_size
,
[
5
,
10
,
15
],
device
,
job
=
"train"
,
)
)
valid_dataloader
=
create_dataloader
(
valid_dataloader
=
create_dataloader
(
valid_set
,
graph
,
feature
,
device
,
is_train
=
False
valid_set
,
graph
,
feature
,
args
.
batch_size
,
[
5
,
10
,
15
],
device
,
job
=
"evaluate"
,
)
)
test_dataloader
=
create_dataloader
(
infer_dataloader
=
create_dataloader
(
test_set
,
graph
,
feature
,
device
,
is_train
=
False
all_nodes_set
,
graph
,
feature
,
4
*
args
.
batch_size
,
[
-
1
],
device
,
job
=
"infer"
,
)
)
in_channels
=
feature
.
size
(
"node"
,
None
,
"feat"
)[
0
]
in_channels
=
feature
.
size
(
"node"
,
None
,
"feat"
)[
0
]
hidden_channels
=
128
hidden_channels
=
256
model
=
GraphSAGE
(
in_channels
,
hidden_channels
,
num_classes
).
to
(
device
)
model
=
GraphSAGE
(
in_channels
,
hidden_channels
,
num_classes
).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.01
,
weight_decay
=
5e-4
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
0.003
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
10
):
train_loss
,
train_accuracy
=
train
(
model
,
train_dataloader
,
optimizer
)
train_loss
,
train_accuracy
=
train
(
model
,
train_dataloader
,
optimizer
,
criterion
,
device
,
num_classes
)
valid_accuracy
=
evaluate
(
model
,
valid_dataloader
,
device
,
num_classes
)
valid_accuracy
=
evaluate
(
model
,
valid_dataloader
,
num_classes
)
print
(
print
(
f
"Epoch
{
epoch
}
, Train Loss:
{
train_loss
:.
4
f
}
, Train Accuracy:
{
train_accuracy
:.
4
f
}
, "
f
"Epoch
{
epoch
}
, Train Loss:
{
train_loss
:.
4
f
}
, "
f
"Train Accuracy:
{
train_accuracy
:.
4
f
}
, "
f
"Valid Accuracy:
{
valid_accuracy
:.
4
f
}
"
f
"Valid Accuracy:
{
valid_accuracy
:.
4
f
}
"
)
)
test_accuracy
=
evaluate
(
model
,
test_dataloader
,
device
,
num_classes
)
test_accuracy
=
layerwise_infer
(
model
,
args
,
infer_dataloader
,
test_set
,
feature
,
num_classes
,
device
)
print
(
f
"Test Accuracy:
{
test_accuracy
:.
4
f
}
"
)
print
(
f
"Test Accuracy:
{
test_accuracy
:.
4
f
}
"
)
...
...
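
With the new flags in place, a run on the default dataset could be kicked off as below; this is a sketch, the flag names and defaults are taken from the diff, and the builtin dataset is downloaded on first use:

import subprocess

# Equivalent to: python node_classification.py --dataset ogbn-products \
#     --epochs 10 --batch-size 1024
subprocess.run(
    [
        "python",
        "examples/sampling/pyg/node_classification.py",
        "--dataset", "ogbn-products",
        "--epochs", "10",
        "--batch-size", "1024",
    ],
    check=True,
)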