OpenDAS / dgl · Commit 0a87bc6a (unverified)

[GraphBolt] Add DistributedItemSampler to support multi-gpu training (#6341)

Authored Sep 22, 2023 by Ramon Zhou; committed by GitHub on Sep 22, 2023.
Parent: e22bd78f
Showing 2 changed files with 291 additions and 19 deletions:
- python/dgl/graphbolt/item_sampler.py: +164, -19
- tests/python/pytorch/graphbolt/test_item_sampler.py: +127, -0
python/dgl/graphbolt/item_sampler.py (view file @ 0a87bc6a)
@@ -4,6 +4,7 @@ from collections.abc import Mapping
 from functools import partial
 from typing import Callable, Iterator, Optional
 
+import torch.distributed as dist
 from torch.utils.data import default_collate
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
@@ -14,7 +15,7 @@ from ..heterograph import DGLGraph
 from .itemset import ItemSet, ItemSetDict
 from .minibatch import MiniBatch
 
-__all__ = ["ItemSampler", "minibatcher_default"]
+__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
 
 
 def minibatcher_default(batch, names):
@@ -280,12 +281,30 @@ class ItemSampler(IterDataPipe):
         shuffle: Optional[bool] = False,
     ) -> None:
         super().__init__()
-        self._item_set = item_set
+        self._names = item_set.names
+        self._item_set = IterableWrapper(item_set)
         self._batch_size = batch_size
         self._minibatcher = minibatcher
         self._drop_last = drop_last
         self._shuffle = shuffle
 
+    def _organize_items(self, data_pipe) -> None:
+        # Shuffle before batch.
+        if self._shuffle:
+            # `torchdata.datapipes.iter.Shuffler` works with stream too.
+            # To ensure randomness, make sure the buffer size is at least 10
+            # times the batch size.
+            buffer_size = max(10000, 10 * self._batch_size)
+            data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
+        # Batch.
+        data_pipe = data_pipe.batch(
+            batch_size=self._batch_size,
+            drop_last=self._drop_last,
+        )
+        return data_pipe
+
     @staticmethod
     def _collate(batch):
         """Collate items into a batch. For internal use only."""
@@ -306,27 +325,153 @@ class ItemSampler(IterDataPipe):
         return default_collate(batch)
 
     def __iter__(self) -> Iterator:
-        data_pipe = IterableWrapper(self._item_set)
-        # Shuffle before batch.
-        if self._shuffle:
-            # `torchdata.datapipes.iter.Shuffler` works with stream too.
-            # To ensure randomness, make sure the buffer size is at least 10
-            # times the batch size.
-            buffer_size = max(10000, 10 * self._batch_size)
-            data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
-        # Batch.
-        data_pipe = data_pipe.batch(
-            batch_size=self._batch_size,
-            drop_last=self._drop_last,
-        )
+        # Organize items.
+        data_pipe = self._organize_items(self._item_set)
         # Collate.
         data_pipe = data_pipe.collate(collate_fn=self._collate)
         # Map to minibatch.
-        data_pipe = data_pipe.map(
-            partial(self._minibatcher, names=self._item_set.names)
-        )
+        data_pipe = data_pipe.map(partial(self._minibatcher, names=self._names))
         return iter(data_pipe)
+
+
+class DistributedItemSampler(ItemSampler):
+    """Distributed Item Sampler.
+
+    This sampler creates a distributed subset of items from the given data set,
+    which can be used for training with PyTorch's Distributed Data Parallel
+    (DDP). The items can be node IDs, node pairs with or without labels, node
+    pairs with negative sources/destinations, DGLGraphs, or heterogeneous
+    counterparts. The original item set is sharded such that each replica
+    (process) receives an exclusive subset.
+
+    Note: The items will first be sharded onto each replica, then shuffled
+    (if needed) and batched. Therefore, each replica will always get the same
+    set of items.
+
+    Note: This class `DistributedItemSampler` is not decorated with
+    `torchdata.datapipes.functional_datapipe` on purpose. This indicates it
+    does not support function-like calls. But any iterable datapipes from
+    `torchdata` can be further appended.
+
+    Parameters
+    ----------
+    item_set : ItemSet or ItemSetDict
+        Data to be sampled.
+    batch_size : int
+        The size of each batch.
+    minibatcher : Optional[Callable]
+        A callable that takes in a list of items and returns a `MiniBatch`.
+    drop_last : bool
+        Option to drop the last batch if it's not full.
+    shuffle : bool
+        Option to shuffle before sampling.
+    num_replicas : int
+        The number of model replicas that will be created during Distributed
+        Data Parallel (DDP) training. It should be the same as the real world
+        size, otherwise it could cause errors. By default, it is retrieved from
+        the current distributed group.
+    drop_uneven_inputs : bool
+        Option to make sure the numbers of batches for each replica are the
+        same. If some of the replicas have more batches than the others, the
+        redundant batches of those replicas will be dropped. If the drop_last
+        parameter is also set to True, the last batch will be dropped before
+        the redundant batches are dropped.
+        Note: When using Distributed Data Parallel (DDP) training, the program
+        may hang or error if a replica has fewer inputs. It is recommended to
+        use the Join Context Manager provided by PyTorch to solve this problem.
+        Please refer to
+        https://pytorch.org/tutorials/advanced/generic_join.html. However, this
+        option can be used if the Join Context Manager is not helpful for any
+        reason.
+
+    Examples
+    --------
+    1. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = False,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 7, 8, 9]
+       - Replica#0 gets [[0, 4], [8]]
+       - Replica#1 gets [[1, 5], [9]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    2. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 7, 8, 9]
+       - Replica#0 gets [[0, 4]]
+       - Replica#1 gets [[1, 5]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    3. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 11, 12, 13]
+       - Replica#0 gets [[0, 4], [8, 12]]
+       - Replica#1 gets [[1, 5], [9, 13]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    4. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = False,
+       drop_uneven_inputs = True, item_set = [0, 1, 2, ..., 11, 12, 13]
+       - Replica#0 gets [[0, 4], [8, 12]]
+       - Replica#1 gets [[1, 5], [9, 13]]
+       - Replica#2 gets [[2, 6], [10]]
+       - Replica#3 gets [[3, 7], [11]]
+
+    5. num_replicas = 4, batch_size = 2, shuffle = False, drop_last = True,
+       drop_uneven_inputs = True, item_set = [0, 1, 2, ..., 11, 12, 13]
+       - Replica#0 gets [[0, 4]]
+       - Replica#1 gets [[1, 5]]
+       - Replica#2 gets [[2, 6]]
+       - Replica#3 gets [[3, 7]]
+
+    6. num_replicas = 4, batch_size = 2, shuffle = True, drop_last = True,
+       drop_uneven_inputs = False, item_set = [0, 1, 2, ..., 11, 12, 13]
+       One possible output:
+       - Replica#0 gets [[8, 0], [12, 4]]
+       - Replica#1 gets [[13, 1], [9, 5]]
+       - Replica#2 gets [[10, 2]]
+       - Replica#3 gets [[7, 11]]
+    """
+
+    def __init__(
+        self,
+        item_set: ItemSet or ItemSetDict,
+        batch_size: int,
+        minibatcher: Optional[Callable] = minibatcher_default,
+        drop_last: Optional[bool] = False,
+        shuffle: Optional[bool] = False,
+        num_replicas: Optional[int] = None,
+        drop_uneven_inputs: Optional[bool] = False,
+    ) -> None:
+        super().__init__(
+            item_set, batch_size, minibatcher, drop_last, shuffle
+        )
+        self._drop_uneven_inputs = drop_uneven_inputs
+        # Apply a sharding filter to distribute the items.
+        self._item_set = self._item_set.sharding_filter()
+        # Get world size.
+        if num_replicas is None:
+            assert (
+                dist.is_available()
+            ), "Requires distributed package to be available."
+            num_replicas = dist.get_world_size()
+        if self._drop_uneven_inputs:
+            # If the len() method of the item_set is not available, it will
+            # throw an exception.
+            total_len = len(item_set)
+            # Calculate the number of batches after dropping uneven batches
+            # for each replica.
+            self._num_evened_batches = total_len // (
+                num_replicas * batch_size
+            ) + (
+                (not drop_last)
+                and (total_len % (num_replicas * batch_size) >= num_replicas)
+            )
+
+    def _organize_items(self, data_pipe) -> None:
+        data_pipe = super()._organize_items(data_pipe)
+
+        # If drop_uneven_inputs is True, drop the excessive inputs by limiting
+        # the length of the datapipe.
+        if self._drop_uneven_inputs:
+            data_pipe = data_pipe.header(self._num_evened_batches)
+
+        return data_pipe
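To illustrate how the new sampler slots into a DDP launch, here is a minimal sketch that is not part of the commit: it assumes four CPU processes on the Gloo backend, a toy `seed_nodes` item set, and a hypothetical `run` worker; the address/port, `num_ids`, and `batch_size` values are placeholders, and the sampler is iterated directly rather than through a full GraphBolt pipeline.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb


def run(rank, world_size, num_ids=1000, batch_size=32):
    # One process per replica; Gloo keeps the sketch CPU-only.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:12345",  # placeholder address/port
        world_size=world_size,
        rank=rank,
    )
    # Each replica receives an exclusive shard of the seed nodes.
    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
    item_sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        drop_uneven_inputs=True,  # keep per-replica batch counts equal
    )
    for minibatch in item_sampler:
        # `minibatch.seed_nodes` holds this replica's share of the batch.
        # A real pipeline would append neighbor-sampling / feature-fetching
        # datapipes here and feed a model wrapped in DistributedDataParallel.
        pass
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(4,), nprocs=4, join=True)

In real training the model would be wrapped in `torch.nn.parallel.DistributedDataParallel`; if `drop_uneven_inputs` is left False and replicas may see different numbers of batches, the docstring's suggestion of PyTorch's Join context manager (`torch.distributed.algorithms.Join`) applies.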
tests/python/pytorch/graphbolt/test_item_sampler.py (view file @ 0a87bc6a)
+import os
 import re
+from sys import platform
 
 import dgl
 import pytest
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 from dgl import graphbolt as gb
 from torch.testing import assert_close
@@ -610,3 +614,126 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
     assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
     assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
     assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
+
+
+def distributed_item_sampler_subprocess(
+    proc_id,
+    nprocs,
+    item_set,
+    num_ids,
+    batch_size,
+    shuffle,
+    drop_last,
+    drop_uneven_inputs,
+):
+    # On Windows, the init method can only be file.
+    init_method = (
+        f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
+        if platform == "win32"
+        else "tcp://127.0.0.1:12345"
+    )
+    dist.init_process_group(
+        backend="gloo",  # Use Gloo backend for CPU multiprocessing
+        init_method=init_method,
+        world_size=nprocs,
+        rank=proc_id,
+    )
+
+    # Create a DistributedItemSampler.
+    item_sampler = gb.DistributedItemSampler(
+        item_set,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        drop_uneven_inputs=drop_uneven_inputs,
+    )
+    feature_fetcher = gb.FeatureFetcher(
+        item_sampler,
+        gb.BasicFeatureStore({}),
+        [],
+    )
+    data_loader = gb.SingleProcessDataLoader(feature_fetcher)
+
+    # Count the numbers of items and batches.
+    num_items = 0
+    sampled_count = torch.zeros(num_ids, dtype=torch.int32)
+    for i in data_loader:
+        # Count how many times each item is sampled.
+        sampled_count[i.seed_nodes] += 1
+        num_items += i.seed_nodes.size(0)
+    num_batches = len(list(item_sampler))
+
+    # Calculate expected numbers of items and batches.
+    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
+    if drop_last and expected_num_items % batch_size > 0:
+        expected_num_items -= expected_num_items % batch_size
+    expected_num_batches = expected_num_items // batch_size + (
+        (not drop_last) and (expected_num_items % batch_size > 0)
+    )
+    if drop_uneven_inputs:
+        if (
+            (not drop_last)
+            and (num_ids % (nprocs * batch_size) < nprocs)
+            and (num_ids % (nprocs * batch_size) > proc_id)
+        ):
+            expected_num_batches -= 1
+            expected_num_items -= 1
+        elif (
+            drop_last
+            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
+            and (num_ids % nprocs > proc_id)
+        ):
+            expected_num_batches -= 1
+            expected_num_items -= batch_size
+
+    num_batches_tensor = torch.tensor(num_batches)
+    dist.broadcast(num_batches_tensor, 0)
+    # Test if the number of batches is the same for all processes.
+    assert num_batches_tensor == num_batches
+
+    # Add up results from all processes.
+    dist.reduce(sampled_count, 0)
+
+    try:
+        # Check if the numbers are as expected.
+        assert num_items == expected_num_items
+        assert num_batches == expected_num_batches
+        # Make sure no item is sampled more than once.
+        assert sampled_count.max() <= 1
+    finally:
+        dist.destroy_process_group()
+
+
+@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
+@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("drop_last", [False, True])
+@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
+def test_DistributedItemSampler(
+    num_ids, shuffle, drop_last, drop_uneven_inputs
+):
+    nprocs = 4
+    batch_size = 4
+    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
+
+    # On Windows, if the process group initialization file already exists,
+    # the program may hang. So we need to delete it if it exists.
+    if platform == "win32":
+        try:
+            os.remove(os.path.join(os.getcwd(), "dis_tempfile"))
+        except FileNotFoundError:
+            pass
+
+    mp.spawn(
+        distributed_item_sampler_subprocess,
+        args=(
+            nprocs,
+            item_set,
+            num_ids,
+            batch_size,
+            shuffle,
+            drop_last,
+            drop_uneven_inputs,
+        ),
+        nprocs=nprocs,
+        join=True,
+    )
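For readers checking the arithmetic in the test and the docstring examples, here is a small, self-contained sketch that is not part of the commit: it models the expected output by splitting items round-robin across replicas, batching per replica, and then trimming according to `drop_last` and `drop_uneven_inputs`. The helper name `expected_shards` is hypothetical and only reproduces the documented behavior; it is not the sampler's implementation.

def expected_shards(num_ids, num_replicas, batch_size, drop_last, drop_uneven_inputs):
    """Model the documented sharding: round-robin shard, batch, then trim."""
    shards = []
    for rank in range(num_replicas):
        # The sharding filter keeps every num_replicas-th item, offset by rank.
        items = list(range(rank, num_ids, num_replicas))
        batches = [
            items[i : i + batch_size] for i in range(0, len(items), batch_size)
        ]
        if drop_last and batches and len(batches[-1]) < batch_size:
            # Drop the trailing partial batch of this replica.
            batches = batches[:-1]
        shards.append(batches)
    if drop_uneven_inputs:
        # Truncate every replica to the smallest per-replica batch count.
        fewest = min(len(batches) for batches in shards)
        shards = [batches[:fewest] for batches in shards]
    return shards


# Matches Example 1 of the DistributedItemSampler docstring:
# [[[0, 4], [8]], [[1, 5], [9]], [[2, 6]], [[3, 7]]]
print(expected_shards(10, 4, 2, drop_last=False, drop_uneven_inputs=False))

Running the same helper with the other parameter combinations from the docstring (14 items, drop_last and drop_uneven_inputs toggled) reproduces Examples 2 through 5, which is the same invariant the test checks via its expected_num_items and expected_num_batches formulas.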