Commit 7c7113f6 (unverified)
Authored May 17, 2021 by Quan (Andy) Gan, committed by GitHub on May 17, 2021
add use_ddp to dataloaders (#2911)
Parent: b03077b6

Showing 6 changed files with 218 additions and 46 deletions (+218 -46)
examples/pytorch/GATNE-T/src/main_sparse_multi_gpus.py      +23  -11
examples/pytorch/gcmc/train_sampling.py                     +16  -12
examples/pytorch/graphsage/train_cv_multi_gpu.py            +18  -10
examples/pytorch/graphsage/train_sampling_multi_gpu.py       +4   -3
examples/pytorch/graphsage/train_sampling_unsupervised.py    +3   -6
python/dgl/dataloading/pytorch/__init__.py                 +154   -4
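The common pattern the new use_ddp option enables, shown as a minimal sketch based on the docstring examples added in python/dgl/dataloading/pytorch/__init__.py below; g, train_nid, model, and train_on are placeholders for your own graph, training node IDs, module, and training step, and the addresses and hyperparameters are illustrative only:

    import torch
    import torch.distributed as dist
    import dgl

    def run(rank, world_size, g, train_nid, model):
        # One process per GPU; the process group must exist before the
        # dataloader builds its DistributedSampler.
        dist.init_process_group('nccl', init_method='tcp://127.0.0.1:12345',
                                world_size=world_size, rank=rank)
        torch.cuda.set_device(rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model.to(rank), device_ids=[rank])

        sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
        dataloader = dgl.dataloading.NodeDataLoader(
            g, train_nid, sampler, use_ddp=True,
            batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

        for epoch in range(10):
            # Reshuffle this rank's partition of the training set each epoch.
            dataloader.set_epoch(epoch)
            for input_nodes, output_nodes, blocks in dataloader:
                train_on(model, input_nodes, output_nodes, blocks)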
examples/pytorch/GATNE-T/src/main_sparse_multi_gpus.py  (+23 -11)

@@ -264,18 +264,28 @@ def run(proc_id, n_gpus, args, devices, data):
     neighbor_samples = args.neighbor_samples
     num_workers = args.workers
-    train_pairs = torch.split(torch.tensor(train_pairs), math.ceil(len(train_pairs) / n_gpus))[proc_id]
     neighbor_sampler = NeighborSampler(g, [neighbor_samples])
-    train_dataloader = torch.utils.data.DataLoader(
-        train_pairs,
-        batch_size=batch_size,
-        collate_fn=neighbor_sampler.sample,
-        shuffle=True,
-        num_workers=num_workers,
-        pin_memory=True,
-    )
+    if n_gpus > 1:
+        train_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_pairs,
+            num_replicas=world_size, rank=proc_id, shuffle=True, drop_last=False)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_pairs,
+            batch_size=batch_size,
+            collate_fn=neighbor_sampler.sample,
+            num_workers=num_workers,
+            sampler=train_sampler,
+            pin_memory=True,
+        )
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_pairs,
+            batch_size=batch_size,
+            collate_fn=neighbor_sampler.sample,
+            num_workers=num_workers,
+            shuffle=True,
+            drop_last=False,
+            pin_memory=True,
+        )
     model = DGLGATNE(
         num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a,

@@ -333,6 +343,8 @@ def run(proc_id, n_gpus, args, devices, data):
     start = time.time()
     for epoch in range(epochs):
+        if n_gpus > 1:
+            train_sampler.set_epoch(epoch)
         model.train()
         data_iter = train_dataloader
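This example builds the DistributedSampler by hand because it feeds a plain torch.utils.data.DataLoader rather than a DGL dataloader. A minimal standalone sketch of the same pattern, assuming torch.distributed is already initialized and that world_size, proc_id, and num_epochs come from the surrounding multi-GPU launcher (the dataset contents are made up for illustration):

    import torch
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    dataset = list(range(1000))                      # stand-in for train_pairs
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=proc_id,
                                 shuffle=True, drop_last=False)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler, pin_memory=True)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)   # otherwise every epoch reuses the epoch-0 ordering
        for batch in loader:
            pass                   # training step goes here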
examples/pytorch/gcmc/train_sampling.py  (+16 -12)

@@ -182,6 +182,17 @@ def config():
 def run(proc_id, n_gpus, args, devices, dataset):
     dev_id = devices[proc_id]
+    if n_gpus > 1:
+        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
+            master_ip='127.0.0.1', master_port='12345')
+        world_size = n_gpus
+        th.distributed.init_process_group(backend="nccl",
+                                          init_method=dist_init_method,
+                                          world_size=world_size,
+                                          rank=dev_id)
+    if n_gpus > 0:
+        th.cuda.set_device(dev_id)
+
     train_labels = dataset.train_labels
     train_truths = dataset.train_truths
     num_edges = train_truths.shape[0]

@@ -196,6 +207,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
             dataset.train_enc_graph.number_of_edges(etype=to_etype_name(k)))
             for k in dataset.possible_rating_values},
         sampler,
+        use_ddp=n_gpus > 1,
         batch_size=args.minibatch_size,
         shuffle=True,
         drop_last=False)

@@ -218,17 +230,6 @@ def run(proc_id, n_gpus, args, devices, dataset):
         shuffle=False,
         drop_last=False)

-    if n_gpus > 1:
-        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
-            master_ip='127.0.0.1', master_port='12345')
-        world_size = n_gpus
-        th.distributed.init_process_group(backend="nccl",
-                                          init_method=dist_init_method,
-                                          world_size=world_size,
-                                          rank=dev_id)
-    if n_gpus > 0:
-        th.cuda.set_device(dev_id)
-
     nd_possible_rating_values = \
         th.FloatTensor(dataset.possible_rating_values)
     nd_possible_rating_values = nd_possible_rating_values.to(dev_id)

@@ -254,6 +255,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
     iter_idx = 1
     for epoch in range(1, args.train_max_epoch):
+        if n_gpus > 1:
+            dataloader.set_epoch(epoch)
         if epoch > 1:
             t0 = time.time()
         net.train()

@@ -340,7 +343,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if n_gpus > 1:
             th.distributed.barrier()
-        print(logging_str)
+        if proc_id == 0:
+            print(logging_str)
     if proc_id == 0:
         print('Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
             best_epoch, best_valid_rmse, best_test_rmse))
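The process-group setup is moved ahead of dataloader construction because, with use_ddp=n_gpus > 1, the dataloader internally creates a DistributedSampler, and DistributedSampler reads the default process group to infer the number of replicas and the rank. A rough sketch of the required ordering, with illustrative argument values only:

    import torch as th
    import torch.multiprocessing as mp

    def run(proc_id, n_gpus, devices):
        dev_id = devices[proc_id]
        if n_gpus > 1:
            th.distributed.init_process_group(
                backend='nccl',
                init_method='tcp://127.0.0.1:12345',
                world_size=n_gpus, rank=dev_id)   # 1. process group first
        if n_gpus > 0:
            th.cuda.set_device(dev_id)
        # 2. only now build dataloaders with use_ddp=n_gpus > 1
        # 3. then wrap the model in DistributedDataParallel and train

    if __name__ == '__main__':
        devices = [0, 1]
        mp.spawn(run, args=(len(devices), devices), nprocs=len(devices))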
examples/pytorch/graphsage/train_cv_multi_gpu.py  (+18 -10)

@@ -234,20 +234,26 @@ def run(proc_id, n_gpus, args, devices, data):
     train_nid = train_mask.nonzero().squeeze()
     val_nid = val_mask.nonzero().squeeze()

-    # Split train_nid
-    train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
-
     # Create sampler
     sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])

     # Create PyTorch DataLoader for constructing blocks
-    dataloader = DataLoader(
-        dataset=train_nid.numpy(),
-        batch_size=args.batch_size,
-        collate_fn=sampler.sample_blocks,
-        shuffle=True,
-        drop_last=False,
-        num_workers=args.num_workers_per_gpu)
+    if n_gpus > 1:
+        dist_sampler = torch.utils.data.distributed.DistributedSampler(
+            train_nid.numpy(), shuffle=True, drop_last=False)
+        dataloader = DataLoader(
+            dataset=train_nid.numpy(),
+            batch_size=args.batch_size,
+            collate_fn=sampler.sample_blocks,
+            sampler=dist_sampler,
+            num_workers=args.num_workers_per_gpu)
+    else:
+        dataloader = DataLoader(
+            dataset=train_nid.numpy(),
+            batch_size=args.batch_size,
+            collate_fn=sampler.sample_blocks,
+            shuffle=True,
+            drop_last=False,
+            num_workers=args.num_workers_per_gpu)

     # Define model
     model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu)

@@ -274,6 +280,8 @@ def run(proc_id, n_gpus, args, devices, data):
     avg = 0
     iter_tput = []
     for epoch in range(args.num_epochs):
+        if n_gpus > 1:
+            dist_sampler.set_epoch(epoch)
         tic = time.time()
         model.train()
         for step, (blocks, hist_blocks) in enumerate(dataloader):
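The dist_sampler.set_epoch(epoch) call is what makes the shuffle differ across epochs: DistributedSampler seeds its shuffle with seed + epoch, so without it every epoch sees the epoch-0 ordering. A small self-contained illustration of that behavior (single process, so rank 0 of 1; purely for demonstration):

    import torch
    from torch.utils.data.distributed import DistributedSampler

    data = list(range(8))
    sampler = DistributedSampler(data, num_replicas=1, rank=0, shuffle=True)

    sampler.set_epoch(0)
    print(list(sampler))   # permutation drawn with seed 0 + epoch 0
    sampler.set_epoch(1)
    print(list(sampler))   # reshuffled with a new seed (seed + epoch)
    # Skipping set_epoch would repeat the epoch-0 permutation every time.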
examples/pytorch/graphsage/train_sampling_multi_gpu.py  (+4 -3)

@@ -85,9 +85,7 @@ def run(proc_id, n_gpus, args, devices, data):
     train_nid = train_mask.nonzero().squeeze()
     val_nid = val_mask.nonzero().squeeze()
     test_nid = test_mask.nonzero().squeeze()
-
-    # Split train_nid
-    train_nid = th.split(train_nid, math.ceil(len(train_nid) / n_gpus))[proc_id]
+    train_nid = train_nid[:n_gpus * args.batch_size + 1]

     # Create PyTorch DataLoader for constructing blocks
     sampler = dgl.dataloading.MultiLayerNeighborSampler(

@@ -96,6 +94,7 @@ def run(proc_id, n_gpus, args, devices, data):
         train_g,
         train_nid,
         sampler,
+        use_ddp=n_gpus > 1,
         batch_size=args.batch_size,
         shuffle=True,
         drop_last=False,

@@ -113,6 +112,8 @@ def run(proc_id, n_gpus, args, devices, data):
     avg = 0
     iter_tput = []
     for epoch in range(args.num_epochs):
+        if n_gpus > 1:
+            dataloader.set_epoch(epoch)
         tic = time.time()

         # Loop over the dataloader to sample the computation dependency graph as a list of
examples/pytorch/graphsage/train_sampling_unsupervised.py  (+3 -6)

@@ -76,12 +76,6 @@ def run(proc_id, n_gpus, args, devices, data):
     # Create PyTorch DataLoader for constructing blocks
     n_edges = g.num_edges()
     train_seeds = np.arange(n_edges)
-    if n_gpus > 0:
-        num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus
-        train_seeds = train_seeds[proc_id * num_per_gpu :
-                                  (proc_id + 1) * num_per_gpu \
-                                  if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
-                                  else train_seeds.shape[0]]

     # Create sampler
     sampler = dgl.dataloading.MultiLayerNeighborSampler(

@@ -93,6 +87,7 @@ def run(proc_id, n_gpus, args, devices, data):
             th.arange(n_edges // 2, n_edges),
             th.arange(0, n_edges // 2)]),
         negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
+        use_ddp=n_gpus > 1,
         batch_size=args.batch_size,
         shuffle=True,
         drop_last=False,

@@ -116,6 +111,8 @@ def run(proc_id, n_gpus, args, devices, data):
     best_eval_acc = 0
     best_test_acc = 0
     for epoch in range(args.num_epochs):
+        if n_gpus > 1:
+            dataloader.set_epoch(epoch)
        tic = time.time()

         # Loop over the dataloader to sample the computation dependency graph as a list of
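For edge (link-prediction) training the same flag applies to EdgeDataLoader. A condensed sketch following the docstring example this commit adds; g, train_eid, reverse_eids, n_epochs, and train_on are placeholders, and the Uniform negative sampler is chosen here only for illustration:

    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
        reverse_eids=reverse_eids,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
        batch_size=1024, shuffle=True, drop_last=False, num_workers=4)

    for epoch in range(n_epochs):
        dataloader.set_epoch(epoch)
        for input_nodes, pair_graph, neg_pair_graph, blocks in dataloader:
            train_on(input_nodes, pair_graph, neg_pair_graph, blocks)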
python/dgl/dataloading/pytorch/__init__.py  (+154 -4)

@@ -2,6 +2,7 @@
 import inspect
 import torch as th
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
 from ...distributed import DistGraph
 from ...distributed import DistDataLoader

@@ -272,6 +273,12 @@ class NodeDataLoader:
     device : device context, optional
         The device of the generated MFGs in each iteration, which should be a
         PyTorch device object (e.g., ``torch.device``).
+    use_ddp : boolean, optional
+        If True, tells the DataLoader to split the training set for each
+        participating process appropriately using
+        :mod:`torch.utils.data.distributed.DistributedSampler`.
+
+        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
     kwargs : dict
         Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

@@ -288,6 +295,21 @@ class NodeDataLoader:
     >>> for input_nodes, output_nodes, blocks in dataloader:
     ...     train_on(input_nodes, output_nodes, blocks)

+    **Using with Distributed Data Parallel**
+
+    If you are using PyTorch's distributed training (e.g. when using
+    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning
+    on the `use_ddp` option:
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> dataloader = dgl.dataloading.NodeDataLoader(
+    ...     g, train_nid, sampler, use_ddp=True,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for epoch in range(start_epoch, n_epochs):
+    ...     dataloader.set_epoch(epoch)
+    ...     for input_nodes, output_nodes, blocks in dataloader:
+    ...         train_on(input_nodes, output_nodes, blocks)
+
     Notes
     -----
     Please refer to

@@ -296,7 +318,7 @@ class NodeDataLoader:
     """
     collator_arglist = inspect.getfullargspec(NodeCollator).args

-    def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
+    def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():

@@ -347,10 +369,21 @@ class NodeDataLoader:
             dataloader_kwargs['shuffle'] = False
             dataloader_kwargs['drop_last'] = False

+            self.use_ddp = use_ddp
+            if use_ddp:
+                self.dist_sampler = DistributedSampler(
+                    dataset,
+                    shuffle=dataloader_kwargs['shuffle'],
+                    drop_last=dataloader_kwargs['drop_last'])
+                dataloader_kwargs['shuffle'] = False
+                dataloader_kwargs['drop_last'] = False
+                dataloader_kwargs['sampler'] = self.dist_sampler
+
             self.dataloader = DataLoader(
                 dataset,
                 collate_fn=self.collator.collate,
                 **dataloader_kwargs)
             self.is_distributed = False

@@ -371,6 +404,24 @@ class NodeDataLoader:
         """Return the number of batches of the data loader."""
         return len(self.dataloader)

+    def set_epoch(self, epoch):
+        """Sets the epoch number for the underlying sampler which ensures all replicas
+        to use a different ordering for each epoch.
+
+        Only available when :attr:`use_ddp` is True.
+
+        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
+
+        Parameters
+        ----------
+        epoch : int
+            The epoch number.
+        """
+        if self.use_ddp:
+            self.dist_sampler.set_epoch(epoch)
+        else:
+            raise DGLError('set_epoch is only available when use_ddp is True.')
+
 class EdgeDataLoader:
     """PyTorch dataloader for batch-iterating over a set of edges, generating the list
     of message flow graphs (MFGs) as computation dependency of the said minibatch for

@@ -442,6 +493,15 @@ class EdgeDataLoader:
         See the description of the argument with the same name in the docstring of
         :class:`~dgl.dataloading.EdgeCollator` for more details.
+    use_ddp : boolean, optional
+        If True, tells the DataLoader to split the training set for each
+        participating process appropriately using
+        :mod:`torch.utils.data.distributed.DistributedSampler`.
+
+        The dataloader will have a :attr:`dist_sampler` attribute to set the
+        epoch number, as recommended by PyTorch.
+
+        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
     kwargs : dict
         Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

@@ -524,6 +584,22 @@ class EdgeDataLoader:
     >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
     ...     train_on(input_nodes, pair_graph, neg_pair_graph, blocks)

+    **Using with Distributed Data Parallel**
+
+    If you are using PyTorch's distributed training (e.g. when using
+    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
+    turning on the :attr:`use_ddp` option:
+
+    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
+    >>> dataloader = dgl.dataloading.EdgeDataLoader(
+    ...     g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
+    ...     reverse_eids=reverse_eids,
+    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for epoch in range(start_epoch, n_epochs):
+    ...     dataloader.set_epoch(epoch)
+    ...     for input_nodes, pair_graph, blocks in dataloader:
+    ...         train_on(input_nodes, pair_graph, blocks)
+
     See also
     --------
     dgl.dataloading.dataloader.EdgeCollator

@@ -544,7 +620,7 @@ class EdgeDataLoader:
     """
     collator_arglist = inspect.getfullargspec(EdgeCollator).args

-    def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
+    def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():

@@ -553,12 +629,27 @@ class EdgeDataLoader:
             else:
                 dataloader_kwargs[k] = v

         self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)
+        dataset = self.collator.dataset

         assert not isinstance(g, DistGraph), \
             'EdgeDataLoader does not support DistGraph for now. ' \
             + 'Please use DistDataLoader directly.'
+
+        self.use_ddp = use_ddp
+        if use_ddp:
+            self.dist_sampler = DistributedSampler(
+                dataset,
+                shuffle=dataloader_kwargs['shuffle'],
+                drop_last=dataloader_kwargs['drop_last'])
+            dataloader_kwargs['shuffle'] = False
+            dataloader_kwargs['drop_last'] = False
+            dataloader_kwargs['sampler'] = self.dist_sampler
+
         self.dataloader = DataLoader(
-            self.collator.dataset,
+            dataset,
             collate_fn=self.collator.collate,
             **dataloader_kwargs)
         self.device = device

@@ -574,6 +665,24 @@ class EdgeDataLoader:
         """Return the number of batches of the data loader."""
         return len(self.dataloader)

+    def set_epoch(self, epoch):
+        """Sets the epoch number for the underlying sampler which ensures all replicas
+        to use a different ordering for each epoch.
+
+        Only available when :attr:`use_ddp` is True.
+
+        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
+
+        Parameters
+        ----------
+        epoch : int
+            The epoch number.
+        """
+        if self.use_ddp:
+            self.dist_sampler.set_epoch(epoch)
+        else:
+            raise DGLError('set_epoch is only available when use_ddp is True.')
+
 class GraphDataLoader:
     """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
     graph and corresponding label tensor (if provided) of the said minibatch.

@@ -595,10 +704,23 @@ class GraphDataLoader:
     ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
     >>> for batched_graph, labels in dataloader:
     ...     train_on(batched_graph, labels)
+
+    **Using with Distributed Data Parallel**
+
+    If you are using PyTorch's distributed training (e.g. when using
+    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
+    turning on the :attr:`use_ddp` option:
+
+    >>> dataloader = dgl.dataloading.GraphDataLoader(
+    ...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
+    >>> for epoch in range(start_epoch, n_epochs):
+    ...     dataloader.set_epoch(epoch)
+    ...     for batched_graph, labels in dataloader:
+    ...         train_on(batched_graph, labels)
     """
     collator_arglist = inspect.getfullargspec(GraphCollator).args

-    def __init__(self, dataset, collate_fn=None, **kwargs):
+    def __init__(self, dataset, collate_fn=None, use_ddp=False, **kwargs):
         collator_kwargs = {}
         dataloader_kwargs = {}
         for k, v in kwargs.items():

@@ -612,6 +734,16 @@ class GraphDataLoader:
         else:
             self.collate = collate_fn

+        self.use_ddp = use_ddp
+        if use_ddp:
+            self.dist_sampler = DistributedSampler(
+                dataset,
+                shuffle=dataloader_kwargs['shuffle'],
+                drop_last=dataloader_kwargs['drop_last'])
+            dataloader_kwargs['shuffle'] = False
+            dataloader_kwargs['drop_last'] = False
+            dataloader_kwargs['sampler'] = self.dist_sampler
+
         self.dataloader = DataLoader(dataset=dataset,
                                      collate_fn=self.collate,
                                      **dataloader_kwargs)

@@ -623,3 +755,21 @@ class GraphDataLoader:
     def __len__(self):
         """Return the number of batches of the data loader."""
         return len(self.dataloader)
+
+    def set_epoch(self, epoch):
+        """Sets the epoch number for the underlying sampler which ensures all replicas
+        to use a different ordering for each epoch.
+
+        Only available when :attr:`use_ddp` is True.
+
+        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.
+
+        Parameters
+        ----------
+        epoch : int
+            The epoch number.
+        """
+        if self.use_ddp:
+            self.dist_sampler.set_epoch(epoch)
+        else:
+            raise DGLError('set_epoch is only available when use_ddp is True.')
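For whole-graph (e.g. graph classification) training, GraphDataLoader gets the same treatment. A minimal sketch following the docstring example above, where dataset is any map-style DGL graph dataset and n_epochs and train_on are placeholders; note that calling set_epoch on a loader created with use_ddp=False raises DGLError:

    dataloader = dgl.dataloading.GraphDataLoader(
        dataset, use_ddp=True, batch_size=1024,
        shuffle=True, drop_last=False, num_workers=4)

    for epoch in range(n_epochs):
        dataloader.set_epoch(epoch)        # reshuffle this rank's shard
        for batched_graph, labels in dataloader:
            train_on(batched_graph, labels)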