Commit c0ac2f60 (unverified)
Authored Sep 20, 2023 by Muhammed Fatih BALIN; committed by GitHub on Sep 21, 2023
[GraphBolt] Improving the node prediction example (#6353)
Parent: 6a9142c2

Showing 2 changed files, with 60 additions and 42 deletions:

  examples/sampling/graphbolt/lightning/README.md                (+1, -1)
  examples/sampling/graphbolt/lightning/node_classification.py   (+59, -41)
examples/sampling/graphbolt/lightning/README.md

@@ -9,5 +9,5 @@ python3 node_classification.py
 ### Results
 ```
-Valid Accuracy: 0.877
+Valid Accuracy: 0.907
 ```
\ No newline at end of file
examples/sampling/graphbolt/lightning/node_classification.py

@@ -10,7 +10,8 @@ main
 │      │
 │      └───> ItemSampler (Distribute data to minibatchs)
 │      │
-│      └───> sample_neighbor (Sample a subgraph for a minibatch)
+│      └───> sample_neighbor or sample_layer_neighbor
+│            (Sample a subgraph for a minibatch)
 │      │
 │      └───> fetch_feature (Fetch features for the sampled subgraph)
 │
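Note (not part of the commit): the diagram above corresponds to a short datapipe chain. A minimal sketch, using the graphbolt (`gb`) calls exactly as they appear later in this diff, and assuming graphbolt is importable as `dgl.graphbolt`; `train_set`, `graph`, `feature_store`, and `fanouts` are placeholders for the objects the example constructs:

import dgl.graphbolt as gb

# Distribute items into minibatches (placeholder objects; see the
# DataModule changes below for the real construction).
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
# Sample a subgraph per minibatch; the commit uses sample_layer_neighbor
# during training and sample_neighbor during validation.
datapipe = datapipe.sample_layer_neighbor(graph, fanouts)
# Fetch input features for the sampled subgraph.
datapipe = datapipe.fetch_feature(feature_store, ["feat"])
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)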
@@ -29,7 +30,7 @@ main
 │
 └───> Trainer[HIGHLIGHT]
       │
-      ├───> SAGE.forward (RGCN model forward pass)
+      ├───> SAGE.forward (GraphSAGE model forward pass)
       │
       └───> Validate
 """
@@ -42,7 +43,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torchmetrics import Accuracy
@@ -69,6 +70,28 @@ class SAGE(LightningModule):
                 h = self.dropout(h)
         return h
 
+    def log_node_and_edge_counts(self, blocks):
+        node_counts = [block.num_src_nodes() for block in blocks] + [
+            blocks[-1].num_dst_nodes()
+        ]
+        edge_counts = [block.num_edges() for block in blocks]
+        for i, c in enumerate(node_counts):
+            self.log(
+                f"num_nodes/{i}",
+                float(c),
+                prog_bar=True,
+                on_step=True,
+                on_epoch=False,
+            )
+            if i < len(edge_counts):
+                self.log(
+                    f"num_edges/{i}",
+                    float(edge_counts[i]),
+                    prog_bar=True,
+                    on_step=True,
+                    on_epoch=False,
+                )
+
     def training_step(self, batch, batch_idx):
         # TODO: Move this to the data pipeline as a stage.
         blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
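Note (not part of the commit): the new `log_node_and_edge_counts` helper surfaces the sampler's footprint in the progress bar, which is what makes the two samplers comparable. A sketch of the metric names it emits for a model with three sampled blocks:

# For len(blocks) == 3, each training step logs:
#   num_nodes/0 .. num_nodes/2 : input (source) nodes of each sampled block
#   num_nodes/3                : seed (destination) nodes of the last block
#   num_edges/0 .. num_edges/2 : edges of each sampled block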
@@ -84,6 +107,7 @@ class SAGE(LightningModule):
on_step
=
True
,
on_epoch
=
False
,
)
self
.
log_node_and_edge_counts
(
blocks
)
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
...
...
@@ -96,10 +120,11 @@ class SAGE(LightningModule):
"val_acc"
,
self
.
val_acc
,
prog_bar
=
True
,
on_step
=
Tru
e
,
on_step
=
Fals
e
,
on_epoch
=
True
,
sync_dist
=
True
,
)
self
.
log_node_and_edge_counts
(
blocks
)
def
configure_optimizers
(
self
):
optimizer
=
torch
.
optim
.
Adam
(
...
...
@@ -114,75 +139,67 @@ class DataModule(LightningDataModule):
self
.
fanouts
=
fanouts
self
.
batch_size
=
batch_size
self
.
num_workers
=
num_workers
# TODO: Update with a publicly accessible URL once the dataset has been
# uploaded.
dataset
=
gb
.
OnDiskDataset
(
"/home/ubuntu/workspace/example_ogbn_products/"
)
dataset
.
load
()
dataset
=
gb
.
BuiltinDataset
(
"ogbn-products"
).
load
()
self
.
feature_store
=
dataset
.
feature
self
.
graph
=
dataset
.
graph
self
.
train_set
=
dataset
.
tasks
[
0
].
train_set
self
.
valid_set
=
dataset
.
tasks
[
0
].
validation_set
self
.
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
########################################################################
# (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
# essential components of the Lightning framework, defining how data is
# loaded during training and validation. In this example, we utilize a
# specialized 'graphbolt dataloader', which are concatenated by a series
# of datappipes, for these purposes.
########################################################################
def
train_dataloader
(
self
):
def
create_dataloader
(
self
,
node_set
,
is_train
):
datapipe
=
gb
.
ItemSampler
(
self
.
train
_set
,
node
_set
,
batch_size
=
self
.
batch_size
,
shuffle
=
True
,
drop_last
=
True
,
)
datapipe
=
datapipe
.
sample_neighbor
(
self
.
graph
,
self
.
fanouts
)
sampler
=
(
datapipe
.
sample_layer_neighbor
if
is_train
else
datapipe
.
sample_neighbor
)
datapipe
=
sampler
(
self
.
graph
,
self
.
fanouts
)
datapipe
=
datapipe
.
fetch_feature
(
self
.
feature_store
,
[
"feat"
])
dataloader
=
gb
.
MultiProcessDataLoader
(
datapipe
,
num_workers
=
self
.
num_workers
)
return
dataloader
########################################################################
# (HIGHLIGHT) The 'train_dataloader' and 'val_dataloader' hooks are
# essential components of the Lightning framework, defining how data is
# loaded during training and validation. In this example, we utilize a
# specialized 'graphbolt dataloader', which are concatenated by a series
# of datapipes, for these purposes.
########################################################################
def
train_dataloader
(
self
):
return
self
.
create_dataloader
(
self
.
train_set
,
is_train
=
True
)
def
val_dataloader
(
self
):
datapipe
=
gb
.
ItemSampler
(
self
.
valid_set
,
batch_size
=
self
.
batch_size
,
shuffle
=
True
,
drop_last
=
True
,
)
datapipe
=
datapipe
.
sample_neighbor
(
self
.
graph
,
self
.
fanouts
)
datapipe
=
datapipe
.
fetch_feature
(
self
.
feature_store
,
[
"feat"
])
dataloader
=
gb
.
MultiProcessDataLoader
(
datapipe
,
num_workers
=
self
.
num_workers
)
return
dataloader
return
self
.
create_dataloader
(
self
.
valid_set
,
is_train
=
False
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"GNN baselines on ogb
gmol*
data with
Pytorch Geometrics
"
description
=
"GNN baselines on ogb
n-products
data with
GraphBolt
"
)
parser
.
add_argument
(
"--num_gpus"
,
type
=
int
,
default
=
4
,
help
=
"number of GPUs used for computing (default:
4
)"
,
default
=
1
,
help
=
"number of GPUs used for computing (default:
1
)"
,
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
32
,
help
=
"input batch size for training (default:
32
)"
,
default
=
1024
,
help
=
"input batch size for training (default:
1024
)"
,
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
1
0
,
help
=
"number of epochs to train (default:
10
0)"
,
default
=
4
0
,
help
=
"number of epochs to train (default:
4
0)"
,
)
parser
.
add_argument
(
"--num_workers"
,
...
...
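Note (not part of the commit): after this hunk both Lightning hooks delegate to the shared `create_dataloader` helper, training switches to `sample_layer_neighbor` while validation keeps `sample_neighbor`, and the hard-coded local dataset path is replaced by the built-in `ogbn-products` dataset. Under the new defaults the README's invocation, `python3 node_classification.py`, is equivalent to passing `--num_gpus 1 --batch_size 1024 --epochs 40`. A usage sketch with hypothetical constructor values:

# Hypothetical values; the real script builds these from its argparse flags.
dm = DataModule(fanouts=[10, 10, 10], batch_size=1024, num_workers=0)
train_loader = dm.create_dataloader(dm.train_set, is_train=True)   # layer-neighbor sampling
val_loader = dm.create_dataloader(dm.valid_set, is_train=False)    # uniform neighbor sampling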
@@ -196,7 +213,8 @@ if __name__ == "__main__":
     model = SAGE(100, 256, datamodule.num_classes).to(torch.double)
 
     # Train.
-    checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
+    checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
+    early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max")
     ########################################################################
     # (HIGHLIGHT) The `Trainer` is the key Class in lightning, which automates
     # everything after defining `LightningDataModule` and
@@ -207,6 +225,6 @@ if __name__ == "__main__":
accelerator
=
"gpu"
,
devices
=
args
.
num_gpus
,
max_epochs
=
args
.
epochs
,
callbacks
=
[
checkpoint_callback
],
callbacks
=
[
checkpoint_callback
,
early_stopping_callback
],
)
trainer
.
fit
(
model
,
datamodule
=
datamodule
)
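Note (not part of the commit): both callbacks monitor `val_acc` with `mode="max"`, so `ModelCheckpoint` keeps the best-accuracy checkpoint (its default `save_top_k` is 1, which is why the explicit argument could be dropped) and `EarlyStopping` halts training once the metric stops improving. A minimal sketch, assuming Lightning's default `patience` of 3 validation checks:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Keep the single best checkpoint by validation accuracy.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
# Stop once val_acc has not improved for `patience` validation epochs
# (patience=3 is Lightning's default, written out here for clarity).
early_stopping_callback = EarlyStopping(monitor="val_acc", mode="max", patience=3)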