OpenDAS / dgl · Commit 7e4c226d (unverified)

Authored Aug 09, 2023 by Muhammed Fatih BALIN; committed by GitHub on Aug 10, 2023

[Model] Simplify labor example, add proper inference code (#6104)

Parent: cf829dc3

Showing 3 changed files with 151 additions and 297 deletions:

- examples/pytorch/labor/README.md (+64, -0)
- examples/pytorch/labor/model.py (+35, -158)
- examples/pytorch/labor/train_lightning.py (+52, -139)
examples/pytorch/labor/README.md · 0 → 100644 (new file) @ 7e4c226d
Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
============

Paper link: [https://arxiv.org/abs/2210.13339](https://arxiv.org/abs/2210.13339)

This is the official Labor sampling example to reproduce the results in the original
paper with the GraphSAGE GNN model. The model can be changed to any other model where
NeighborSampler can be used.
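To make the drop-in nature concrete, here is a minimal sketch of swapping uniform neighbor sampling for LABOR sampling in a standard DGL mini-batch loop. It assumes DGL's `dgl.dataloading.NeighborSampler`, `dgl.dataloading.LaborSampler`, and `dgl.dataloading.DataLoader` APIs; the toy graph and node IDs are illustrative placeholders for a real dataset.

```python
import dgl
import torch as th

# Toy graph standing in for a real dataset such as ogbn-products.
g = dgl.rand_graph(1000, 5000)
train_nids = th.arange(100)

fanouts = [10, 10, 10]  # neighbors to sample per layer

# Baseline: uniform neighbor sampling.
sampler = dgl.dataloading.NeighborSampler(fanouts)
# LABOR sampling is intended as a drop-in replacement for NeighborSampler.
sampler = dgl.dataloading.LaborSampler(fanouts)

dataloader = dgl.dataloading.DataLoader(
    g, train_nids, sampler, batch_size=64, shuffle=True, drop_last=False
)
for input_nodes, output_nodes, blocks in dataloader:
    # `blocks` are the same message-flow graphs a NeighborSampler-based model
    # consumes, so GraphSAGE (or any similar model) trains unchanged.
    break
```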
Requirements
------------

```bash
pip install requests lightning==2.0.6 ogb
```
How to run
-------

### Minibatch training for node classification

Train w/ mini-batch sampling on the GPU for node classification on "ogbn-products":

```bash
python3 train_lightning.py --dataset=ogbn-products
```

Results:

```
Test Accuracy: 0.797
```
Any integer passed as the `--importance-sampling=i` argument runs the corresponding
LABOR-i variant. `--importance-sampling=-1` runs the LABOR-* variant.
The `--vertex-limit` argument is used if a vertex sampling budget is needed. It adjusts
the batch size at the end of every epoch so that the average number of sampled vertices
converges to the provided vertex limit, and it can be used to replicate the vertex sampling
budget experiments in the Labor paper.
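The adjustment itself is just a proportional update that the example's `BatchSizeCallback` applies at the end of each epoch. A minimal sketch of the rule (the names `vertex_limit` and `avg_sampled_vertices` are illustrative):

```python
def adjust_batch_size(batch_size, vertex_limit, avg_sampled_vertices):
    # Scale the batch size so that the running average of sampled vertices
    # drifts toward the requested budget.
    return int(batch_size * vertex_limit / avg_sampled_vertices)


# If batches of 1024 seeds sample ~600k vertices on average while the budget is
# 500k, the next epoch runs with a proportionally smaller batch size.
print(adjust_batch_size(1024, 500_000, 600_000))  # 853
```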
During training runs, statistics about the number of sampled vertices, sampled edges, and
cache miss rates are reported. One can use TensorBoard to inspect their plots
during/after training:

```bash
tensorboard --logdir tb_logs
```
## Utilize a GPU feature cache for UVA training

```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000
```
## Reduce GPU feature cache miss rate for UVA training

```bash
python3 train_lightning.py --dataset=ogbn-products --use-uva --cache-size=500000 --batch-dependency=64
```
## Force all layers to share the same neighborhood for shared vertices

```bash
python3 train_lightning.py --dataset=ogbn-products --layer-dependency
```
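Taken together, the flags above map onto the LABOR sampler's constructor arguments. A minimal sketch of that mapping, assuming DGL's `dgl.dataloading.LaborSampler` exposes `importance_sampling`, `layer_dependency`, and `batch_dependency` keyword arguments (the last one is a newer option, so check your DGL version):

```python
import dgl

sampler = dgl.dataloading.LaborSampler(
    [10, 10, 10],             # fanouts, as in --fan-out=10,10,10
    importance_sampling=-1,   # --importance-sampling=-1 selects the LABOR-* variant
    layer_dependency=True,    # --layer-dependency: one shared neighborhood across layers
    batch_dependency=64,      # --batch-dependency=64: correlate consecutive batches to
                              # improve the GPU feature cache hit rate
)
```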
examples/pytorch/labor/model.py @ 7e4c226d
@@ -5,84 +5,7 @@ import sklearn.metrics as skm
```python
import torch as th
import torch.functional as F
import torch.nn as nn

from dgl.nn import GATv2Conv


class GATv2(nn.Module):
    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GATv2, self).__init__()
        self.num_layers = num_layers
        self.gatv2_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gatv2_layers.append(
            GATv2Conv(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
                True,
                bias=False,
                share_weights=True,
            )
        )
        # hidden layers
        for l in range(1, num_layers - 1):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gatv2_layers.append(
                GATv2Conv(
                    num_hidden * heads[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                    True,
                    bias=False,
                    share_weights=True,
                )
            )
        # output projection
        self.gatv2_layers.append(
            GATv2Conv(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
                True,
                bias=False,
                share_weights=True,
            )
        )

    def forward(self, mfgs, h):
        for l, mfg in enumerate(mfgs):
            h = self.gatv2_layers[l](mfg, h)
            h = h.flatten(1) if l < self.num_layers - 1 else h.mean(1)
        return h


import tqdm


class SAGE(nn.Module):
```
@@ -124,87 +47,41 @@ class SAGE(nn.Module):
```python
        h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        g.ndata["h"] = g.ndata["features"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["h"]
        )
        dataloader = dgl.dataloading.DataLoader(
            g,
            th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
        )
        if buffer_device is None:
            buffer_device = device


class RGAT(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        num_etypes,
        num_layers,
        num_heads,
        dropout,
        pred_ntype,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()
        self.train(False)
        self.convs.append(
            nn.ModuleList(
                [
                    dglnn.GATConv(
                        in_channels,
                        hidden_channels // num_heads,
                        num_heads,
                        allow_zero_in_degree=True,
                    )
                    for _ in range(num_etypes)
                ]
            )
        )
        self.norms.append(nn.BatchNorm1d(hidden_channels))
        self.skips.append(nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(
                nn.ModuleList(
                    [
                        dglnn.GATConv(
                            hidden_channels,
                            hidden_channels // num_heads,
                            num_heads,
                            allow_zero_in_degree=True,
                        )
                        for _ in range(num_etypes)
                    ]
                )
            )
            self.norms.append(nn.BatchNorm1d(hidden_channels))
            self.skips.append(nn.Linear(hidden_channels, hidden_channels))
        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.BatchNorm1d(hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, out_channels),
        )
        self.dropout = nn.Dropout(dropout)
        self.hidden_channels = hidden_channels
        self.pred_ntype = pred_ntype
        self.num_etypes = num_etypes

    def forward(self, mfgs, x):
        for i in range(len(mfgs)):
            mfg = mfgs[i]
            x_dst = x[mfg.dst_in_src]
            for data in [mfg.srcdata, mfg.dstdata]:
                for k in list(data.keys()):
                    if k not in ["features", "labels"]:
                        data.pop(k)
            mfg = dgl.block_to_graph(mfg)
            x_skip = self.skips[i](x_dst)
            for j in range(self.num_etypes):
                subg = mfg.edge_subgraph(
                    mfg.edata["etype"] == j, relabel_nodes=False
                )
                x_skip += self.convs[i][j](subg, (x, x_dst)).view(
                    -1, self.hidden_channels
                )
            x = self.norms[i](x_skip)
            x = th.nn.functional.elu(x)
            x = self.dropout(x)
        return self.mlp(x)

        for l, layer in enumerate(self.layers):
            y = th.zeros(
                g.num_nodes(),
                self.n_hidden
                if l != len(self.layers) - 1
                else self.n_classes,
                device=buffer_device,
            )
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            g.ndata["h"] = y
        return y
```
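The `inference` method added above is what the training script calls after testing to compute full-graph predictions layer by layer with prefetching. A minimal usage sketch (mirroring the call at the end of `train_lightning.py`; `model`, `graph`, and `test_nid` stand for a trained `SAGELightning` instance, the full graph with "features"/"labels" node data, and a node-ID split, respectively):

```python
import torch as th

with th.no_grad():
    pred = model.module.inference(
        graph,          # full DGLGraph with "features" node data
        "cuda:0",       # device that runs the per-layer computation
        4096,           # inference batch size (the value the script picks)
        0,              # num_workers
        graph.device,   # buffer_device: keep the large per-layer outputs off the GPU
    )
    # `pred` holds logits for every node; index with any split to evaluate it.
    test_logits = pred[test_nid]
```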
examples/pytorch/labor/train_lightning.py @ 7e4c226d
@@ -33,7 +33,7 @@ import torch.nn.functional as F
```python
from ladies_sampler import LadiesSampler, normalized_edata, PoissonLadiesSampler
from load_graph import load_dataset
from model import GATv2, RGAT, SAGE
from model import SAGE
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
```

@@ -41,14 +41,6 @@ from pytorch_lightning.loggers import TensorBoardLogger
```python
from torchmetrics.classification import MulticlassF1Score, MultilabelF1Score


def cuda_index_tensor(tensor, idx):
    assert idx.device != th.device("cpu")
    if tensor.is_pinned():
        return dgl.utils.gather_pinned_tensor_rows(tensor, idx)
    else:
        return tensor[idx.long()]


class SAGELightning(LightningModule):
    def __init__(
        self,
```

@@ -56,7 +48,6 @@ class SAGELightning(LightningModule):
```python
        n_hidden,
        n_classes,
        n_layers,
        model,
        activation,
        dropout,
        lr,
```

@@ -64,54 +55,15 @@ class SAGELightning(LightningModule):
```python
    ):
        super().__init__()
        self.save_hyperparameters()
        if model in ["sage"]:
            self.module = (
                SAGE(
        self.module = SAGE(
            in_feats, n_hidden, n_classes, n_layers, activation, dropout
        )
                if in_feats != 768
                else RGAT(
                    in_feats,
                    n_classes,
                    n_hidden,
                    5,
                    n_layers,
                    4,
                    args.dropout,
                    "paper",
                )
            )
        else:
            heads = ([8] * n_layers) + [1]
            self.module = GATv2(
                n_layers,
                in_feats,
                n_hidden,
                n_classes,
                heads,
                activation,
                dropout,
                dropout,
                0.2,
                True,
            )
        self.lr = lr
        f1score_class = (
        self.f1score_class = lambda: (
            MulticlassF1Score if not multilabel else MultilabelF1Score
        )
        self.train_acc = f1score_class(n_classes, average="micro")
        self.val_acc = nn.ModuleList(
            [
                f1score_class(n_classes, average="micro"),
                f1score_class(n_classes, average="micro"),
            ]
        )
        self.test_acc = nn.ModuleList(
            [
                f1score_class(n_classes, average="micro"),
                f1score_class(n_classes, average="micro"),
            ]
        )
        )(n_classes, average="micro")
        self.train_acc = self.f1score_class()
        self.val_acc = self.f1score_class()
        self.num_steps = 0
        self.cum_sampled_nodes = [0 for _ in range(n_layers + 1)]
        self.cum_sampled_edges = [0 for _ in range(n_layers)]
```
@@ -225,10 +177,10 @@ class SAGELightning(LightningModule):
```python
        batch_labels = mfgs[-1].dstdata["labels"]
        batch_pred = self.module(mfgs, batch_inputs)
        loss = self.loss_fn(batch_pred, batch_labels)
        self.val_acc[dataloader_idx](batch_pred, batch_labels.int())
        self.val_acc(batch_pred, batch_labels.int())
        self.log(
            "val_acc",
            self.val_acc[dataloader_idx],
            self.val_acc,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
```

@@ -244,32 +196,6 @@ class SAGELightning(LightningModule):
```python
            batch_size=batch_labels.shape[0],
        )

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        input_nodes, output_nodes, mfgs = batch
        mfgs = [mfg.int().to(device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata["features"]
        batch_labels = mfgs[-1].dstdata["labels"]
        batch_pred = self.module(mfgs, batch_inputs)
        loss = self.loss_fn(batch_pred, batch_labels)
        self.test_acc[dataloader_idx](batch_pred, batch_labels.int())
        self.log(
            "test_acc",
            self.test_acc[dataloader_idx],
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
            batch_size=batch_labels.shape[0],
        )
        self.log(
            "test_loss",
            loss,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
            batch_size=batch_labels.shape[0],
        )

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
```
@@ -331,13 +257,6 @@ class DataModule(LightningDataModule):
```python
            prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
            prefetch_labels=["labels"],
        )
        full_sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            len(fanouts),
            prefetch_node_feats=["features"],
            prefetch_edge_feats=["etype"] if "etype" in g.edata else [],
            prefetch_labels=["labels"],
        )
        unbiased_sampler = sampler
        dataloader_device = th.device("cpu")
        g = g.formats(["csc"])
```

@@ -363,8 +282,6 @@ class DataModule(LightningDataModule):
```python
            test_nid,
        )
        self.sampler = sampler
        self.unbiased_sampler = unbiased_sampler
        self.full_sampler = full_sampler
        self.device = dataloader_device
        self.use_uva = use_uva
        self.batch_size = batch_size
```

@@ -389,28 +306,10 @@ class DataModule(LightningDataModule):
```python
        )

    def val_dataloader(self):
        return [
            dgl.dataloading.DataLoader(
        return dgl.dataloading.DataLoader(
            self.g,
            self.val_nid,
            sampler,
            device=self.device,
            use_uva=self.use_uva,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
            gpu_cache=self.gpu_cache_arg,
        )
            for sampler in [self.unbiased_sampler]
        ]

    def test_dataloader(self):
        return [
            dgl.dataloading.DataLoader(
                self.g,
                self.test_nid,
                sampler,
            self.sampler,
            device=self.device,
            use_uva=self.use_uva,
            batch_size=self.batch_size,
```

@@ -419,8 +318,6 @@ class DataModule(LightningDataModule):
```python
            num_workers=self.num_workers,
            gpu_cache=self.gpu_cache_arg,
        )
            for sampler in [self.full_sampler]
        ]


class BatchSizeCallback(Callback):
```
@@ -476,8 +373,10 @@ class BatchSizeCallback(Callback):
```python
            trainer.datamodule.batch_size = int(
                trainer.datamodule.batch_size * self.limit / self.m
            )
            trainer.reset_train_dataloader()
            trainer.reset_val_dataloader()
            loop = trainer._active_loop
            assert loop is not None
            loop._combined_loader = None
            loop.setup_data()
            self.clear()
```

@@ -500,7 +399,6 @@ if __name__ == "__main__":
```python
    argparser.add_argument("--batch-size", type=int, default=1024)
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument("--independent-batches", type=int, default=1)
    argparser.add_argument(
        "--num-workers",
        type=int,
```

@@ -515,8 +413,12 @@ if __name__ == "__main__":
```python
        "be undesired if they cannot fit in GPU memory at once. "
        "This flag disables that.",
    )
    argparser.add_argument("--model", type=str, default="sage")
    argparser.add_argument("--sampler", type=str, default="labor")
    argparser.add_argument(
        "--sampler",
        type=str,
        default="labor",
        choices=["neighbor", "labor", "ladies", "poisson-ladies"],
    )
    argparser.add_argument("--importance-sampling", type=int, default=0)
    argparser.add_argument("--layer-dependency", action="store_true")
    argparser.add_argument("--batch-dependency", type=int, default=1)
```
@@ -547,7 +449,7 @@ if __name__ == "__main__":
```python
        [int(_) for _ in args.fan_out.split(",")],
        [int(_) for _ in args.lad_out.split(",")],
        device,
        args.batch_size // args.independent_batches,
        args.batch_size,
        args.num_workers,
        args.sampler,
        args.importance_sampling,
```

@@ -560,7 +462,6 @@ if __name__ == "__main__":
```python
        args.num_hidden,
        datamodule.n_classes,
        args.num_layers,
        args.model,
        F.relu,
        args.dropout,
        args.lr,
```

@@ -570,12 +471,10 @@ if __name__ == "__main__":
```python
    # Train
    callbacks = []
    if not args.disable_checkpoint:
        # callbacks.append(ModelCheckpoint(monitor='val_acc/dataloader_idx_0', save_top_k=1, mode='max'))
        callbacks.append(
            ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")
        )
    callbacks.append(BatchSizeCallback(args.vertex_limit))
    # callbacks.append(EarlyStopping(monitor='val_acc/dataloader_idx_0', stopping_threshold=args.val_acc_target, mode='max', patience=args.early_stopping_patience))
    callbacks.append(
        EarlyStopping(
            monitor="val_acc",
```

@@ -584,19 +483,17 @@ if __name__ == "__main__":
```python
            patience=args.early_stopping_patience,
        )
    )
    subdir = "{}_{}_{}_{}_{}_{}".format(
    subdir = "{}_{}_{}_{}_{}".format(
        args.dataset,
        args.sampler,
        args.importance_sampling,
        args.layer_dependency,
        args.batch_dependency,
        args.independent_batches,
    )
    logger = TensorBoardLogger(args.logdir, name=subdir)
    trainer = Trainer(
        accelerator="gpu" if args.gpu != -1 else "cpu",
        devices=[args.gpu],
        accumulate_grad_batches=args.independent_batches,
        max_epochs=args.num_epochs,
        max_steps=args.num_steps,
        min_steps=args.min_steps,
```
@@ -618,5 +515,21 @@ if __name__ == "__main__":
```python
        checkpoint_path=ckpt,
        hparams_file=os.path.join(logdir, "hparams.yaml"),
    ).to(device)
    test_acc = trainer.test(model, datamodule=datamodule)
    print("Test accuracy:", test_acc)
    with th.no_grad():
        graph = datamodule.g
        pred = model.module.inference(
            graph,
            f"cuda:{args.gpu}" if args.gpu != -1 else "cpu",
            4096,
            args.num_workers,
            graph.device,
        )
        for nid, split_name in zip(
            [datamodule.train_nid, datamodule.val_nid, datamodule.test_nid],
            ["Train", "Validation", "Test"],
        ):
            pred_nid = pred[nid]
            label = graph.ndata["labels"][nid]
            f1score = model.f1score_class().to(pred.device)
            acc = f1score(pred_nid, label)
            print(f"{split_name} accuracy:", acc.item())
```