Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c9b26dda
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "e3921d5decacd10636b22e9a42ea32eebda69cb9"
Unverified
Commit
c9b26dda
authored
Nov 30, 2023
by
Rhett Ying
Committed by
GitHub
Nov 30, 2023
Browse files
[doc] add multi-gpu node classification tutorial (#6624)
parent
1d2a1cdc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
408 additions
and
13 deletions
+408
-13
tutorials/multi/1_graph_classification.py
tutorials/multi/1_graph_classification.py
+12
-13
tutorials/multi/2_node_classification.py
tutorials/multi/2_node_classification.py
+396
-0
No files found.
tutorials/multi/1_graph_classification.py
View file @
c9b26dda
...
@@ -251,17 +251,16 @@ def main(rank, world_size, dataset, seed=0):
...
@@ -251,17 +251,16 @@ def main(rank, world_size, dataset, seed=0):
###############################################################################
###############################################################################
# Finally we load the dataset and launch the processes.
# Finally we load the dataset and launch the processes.
#
#
# .. code:: python
#
# if __name__ == '__main__':
# import torch.multiprocessing as mp
#
# from dgl.data import GINDataset
#
# num_gpus = 4
# procs = []
# dataset = GINDataset(name='IMDBBINARY', self_loop=False)
# mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)
# Thumbnail credits: DGL
import
torch.multiprocessing
as
mp
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'
from
dgl.data
import
GINDataset
if
__name__
==
"__main__"
:
if
not
torch
.
cuda
.
is_available
():
print
(
"No GPU found!"
)
exit
(
0
)
num_gpus
=
torch
.
cuda
.
device_count
()
procs
=
[]
dataset
=
GINDataset
(
name
=
"IMDBBINARY"
,
self_loop
=
False
)
mp
.
spawn
(
main
,
args
=
(
num_gpus
,
dataset
),
nprocs
=
num_gpus
)
tutorials/multi/2_node_classification.py
0 → 100644
View file @
c9b26dda
"""
Single Machine Multi-GPU Minibatch Node Classification
======================================================
In this tutorial, you will learn how to use multiple GPUs in training a
graph neural network (GNN) for node classification.
This tutorial assumes that you have read the `Stochastic GNN Training for Node
Classification in DGL <../../notebooks/stochastic_training/node_classification.ipynb>`__.
It also assumes that you know the basics of training general
models with multi-GPU with ``DistributedDataParallel``.
.. note::
See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
from PyTorch for general multi-GPU training with ``DistributedDataParallel``. Also,
see the first section of :doc:`the multi-GPU graph classification
tutorial <1_graph_classification>`
for an overview of using ``DistributedDataParallel`` with DGL.
"""
######################################################################
# Importing Packages
# ---------------
#
# We use ``torch.distributed`` to initialize a distributed training context
# and ``torch.multiprocessing`` to spawn multiple processes for each GPU.
#
import
os
os
.
environ
[
"DGLBACKEND"
]
=
"pytorch"
import
time
import
dgl.graphbolt
as
gb
import
dgl.nn
as
dglnn
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torchmetrics.functional
as
MF
import
tqdm
from
torch.distributed.algorithms.join
import
Join
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
######################################################################
# Defining Model
# --------------
#
# The model will be again identical to `Stochastic GNN Training for Node
# Classification in DGL <../../notebooks/stochastic_training/node_classification.ipynb>`__.
#
class
SAGE
(
nn
.
Module
):
def
__init__
(
self
,
in_size
,
hidden_size
,
out_size
):
super
().
__init__
()
self
.
layers
=
nn
.
ModuleList
()
# Three-layer GraphSAGE-mean.
self
.
layers
.
append
(
dglnn
.
SAGEConv
(
in_size
,
hidden_size
,
"mean"
))
self
.
layers
.
append
(
dglnn
.
SAGEConv
(
hidden_size
,
hidden_size
,
"mean"
))
self
.
layers
.
append
(
dglnn
.
SAGEConv
(
hidden_size
,
out_size
,
"mean"
))
self
.
dropout
=
nn
.
Dropout
(
0.5
)
self
.
hidden_size
=
hidden_size
self
.
out_size
=
out_size
# Set the dtype for the layers manually.
self
.
set_layer_dtype
(
torch
.
float32
)
def
set_layer_dtype
(
self
,
dtype
):
for
layer
in
self
.
layers
:
for
param
in
layer
.
parameters
():
param
.
data
=
param
.
data
.
to
(
dtype
)
def
forward
(
self
,
blocks
,
x
):
hidden_x
=
x
for
layer_idx
,
(
layer
,
block
)
in
enumerate
(
zip
(
self
.
layers
,
blocks
)):
hidden_x
=
layer
(
block
,
hidden_x
)
is_last_layer
=
layer_idx
==
len
(
self
.
layers
)
-
1
if
not
is_last_layer
:
hidden_x
=
F
.
relu
(
hidden_x
)
hidden_x
=
self
.
dropout
(
hidden_x
)
return
hidden_x
######################################################################
# Mini-batch Data Loading
# -----------------------
#
# The major difference from the previous tutorial is that we will use
# ``DistributedItemSampler`` instead of ``ItemSampler`` to sample mini-batches
# of nodes. ``DistributedItemSampler`` is a distributed version of
# ``ItemSampler`` that works with ``DistributedDataParallel``. It is
# implemented as a wrapper around ``ItemSampler`` and will sample the same
# minibatch on all replicas. It also supports dropping the last non-full
# minibatch to avoid the need for padding.
#
def
create_dataloader
(
args
,
graph
,
features
,
itemset
,
device
,
drop_last
=
False
,
shuffle
=
True
,
drop_uneven_inputs
=
False
,
):
datapipe
=
gb
.
DistributedItemSampler
(
item_set
=
itemset
,
batch_size
=
args
.
batch_size
,
drop_last
=
drop_last
,
shuffle
=
shuffle
,
drop_uneven_inputs
=
drop_uneven_inputs
,
)
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
args
.
fanout
)
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
datapipe
=
datapipe
.
to_dgl
()
datapipe
=
datapipe
.
copy_to
(
device
)
dataloader
=
gb
.
MultiProcessDataLoader
(
datapipe
,
num_workers
=
args
.
num_workers
)
return
dataloader
######################################################################
# Evaluation Loop
# ---------------
#
# The evaluation loop is almost identical to the previous tutorial.
#
@
torch
.
no_grad
()
def
evaluate
(
rank
,
args
,
model
,
graph
,
features
,
itemset
,
num_classes
,
device
):
model
.
eval
()
y
=
[]
y_hats
=
[]
dataloader
=
create_dataloader
(
args
,
graph
,
features
,
itemset
,
drop_last
=
False
,
shuffle
=
False
,
drop_uneven_inputs
=
False
,
device
=
device
,
)
for
step
,
data
in
(
tqdm
.
tqdm
(
enumerate
(
dataloader
))
if
rank
==
0
else
enumerate
(
dataloader
)
):
blocks
=
data
.
blocks
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
y_hats
.
append
(
model
.
module
(
blocks
,
x
))
res
=
MF
.
accuracy
(
torch
.
cat
(
y_hats
),
torch
.
cat
(
y
),
task
=
"multiclass"
,
num_classes
=
num_classes
,
)
return
res
.
to
(
device
)
######################################################################
# Training Loop
# -------------
#
# The training loop is also almost identical to the previous tutorial except
# that we use Join Context Manager to solve the uneven input problem. The
# mechanics of Distributed Data Parallel (DDP) training in PyTorch requires
# the number of inputs are the same for all ranks, otherwise the program may
# error or hang. To solve it, PyTorch provides Join Context Manager. Please
# refer to `this tutorial <https://pytorch.org/tutorials/advanced/generic_join.html>`__
# for detailed information.
#
def
train
(
world_size
,
rank
,
args
,
graph
,
features
,
train_set
,
valid_set
,
num_classes
,
model
,
device
,
):
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
# Create training data loader.
dataloader
=
create_dataloader
(
args
,
graph
,
features
,
train_set
,
device
,
drop_last
=
False
,
shuffle
=
True
,
drop_uneven_inputs
=
False
,
)
for
epoch
in
range
(
args
.
epochs
):
epoch_start
=
time
.
time
()
model
.
train
()
total_loss
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float
).
to
(
device
)
with
Join
([
model
]):
for
step
,
data
in
(
tqdm
.
tqdm
(
enumerate
(
dataloader
))
if
rank
==
0
else
enumerate
(
dataloader
)
):
# The input features are from the source nodes in the first
# layer's computation graph.
x
=
data
.
node_features
[
"feat"
]
# The ground truth labels are from the destination nodes
# in the last layer's computation graph.
y
=
data
.
labels
blocks
=
data
.
blocks
y_hat
=
model
(
blocks
,
x
)
# Compute loss.
loss
=
F
.
cross_entropy
(
y_hat
,
y
)
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
total_loss
+=
loss
# Evaluate the model.
if
rank
==
0
:
print
(
"Validating..."
)
acc
=
(
evaluate
(
rank
,
args
,
model
,
graph
,
features
,
valid_set
,
num_classes
,
device
,
)
/
world_size
)
########################################################################
# (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
# obtain overall average values.
#
# `torch.distributed.reduce` is used to reduce tensors from all the
# sub-processes to a specified process, ReduceOp.SUM is used by default.
########################################################################
dist
.
reduce
(
tensor
=
acc
,
dst
=
0
)
total_loss
/=
step
+
1
dist
.
reduce
(
tensor
=
total_loss
,
dst
=
0
)
dist
.
barrier
()
epoch_end
=
time
.
time
()
if
rank
==
0
:
print
(
f
"Epoch
{
epoch
:
05
d
}
| "
f
"Average Loss
{
total_loss
.
item
()
/
world_size
:.
4
f
}
| "
f
"Accuracy
{
acc
.
item
():.
4
f
}
| "
f
"Time
{
epoch_end
-
epoch_start
:.
4
f
}
"
)
######################################################################
# Defining Traning and Evaluation Procedures
# ------------------------------------------
#
# The following code defines the main function for each process. It is
# similar to the previous tutorial except that we need to initialize a
# distributed training context with ``torch.distributed`` and wrap the model
# with ``torch.nn.parallel.DistributedDataParallel``.
#
def
run
(
rank
,
world_size
,
args
,
devices
,
dataset
):
# Set up multiprocessing environment.
device
=
devices
[
rank
]
torch
.
cuda
.
set_device
(
device
)
dist
.
init_process_group
(
backend
=
"nccl"
,
# Use NCCL backend for distributed GPU training
init_method
=
"tcp://127.0.0.1:12345"
,
world_size
=
world_size
,
rank
=
rank
,
)
graph
=
dataset
.
graph
features
=
dataset
.
feature
train_set
=
dataset
.
tasks
[
0
].
train_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
args
.
fanout
=
list
(
map
(
int
,
args
.
fanout
.
split
(
","
)))
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
in_size
=
features
.
size
(
"node"
,
None
,
"feat"
)[
0
]
hidden_size
=
256
out_size
=
num_classes
# Create GraphSAGE model. It should be copied onto a GPU as a replica.
model
=
SAGE
(
in_size
,
hidden_size
,
out_size
).
to
(
device
)
model
=
DDP
(
model
)
# Model training.
if
rank
==
0
:
print
(
"Training..."
)
train
(
world_size
,
rank
,
args
,
graph
,
features
,
train_set
,
valid_set
,
num_classes
,
model
,
device
,
)
# Test the model.
if
rank
==
0
:
print
(
"Testing..."
)
test_set
=
dataset
.
tasks
[
0
].
test_set
test_acc
=
(
evaluate
(
rank
,
args
,
model
,
graph
,
features
,
itemset
=
test_set
,
num_classes
=
num_classes
,
device
=
device
,
)
/
world_size
)
dist
.
reduce
(
tensor
=
test_acc
,
dst
=
0
)
dist
.
barrier
()
if
rank
==
0
:
print
(
f
"Test Accuracy
{
test_acc
.
item
():.
4
f
}
"
)
######################################################################
# Spawning Trainer Processes
# --------------------------
#
# The following code spawns a process for each GPU and calls the ``run``
# function defined above.
#
if
__name__
==
"__main__"
:
if
not
torch
.
cuda
.
is_available
():
print
(
"No GPU found!"
)
exit
(
0
)
args
=
{
"epochs"
:
5
,
"lr"
:
0.01
,
"batch_size"
:
1024
,
"fanout"
:
"10,10,10"
,
"num_workers"
:
0
,
}
devices
=
torch
.
arange
(
torch
.
cuda
.
device_count
())
world_size
=
len
(
devices
)
print
(
f
"Training with
{
world_size
}
gpus."
)
# Load and preprocess dataset.
dataset
=
gb
.
BuiltinDataset
(
"ogbn-arxiv"
).
load
()
# Thread limiting to avoid resource competition.
os
.
environ
[
"OMP_NUM_THREADS"
]
=
str
(
mp
.
cpu_count
()
//
2
//
world_size
)
mp
.
set_sharing_strategy
(
"file_system"
)
mp
.
spawn
(
run
,
args
=
(
world_size
,
args
,
devices
,
dataset
),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment