OpenDAS / dgl · Commit d79701db (Unverified)
Authored Nov 08, 2023 by Ramon Zhou, committed by GitHub on Nov 08, 2023
[GraphBolt] Use indexing in DistributedItemSampler (#6508)
Parent: 8f1b5782
Showing 2 changed files with 121 additions and 62 deletions.
python/dgl/graphbolt/item_sampler.py: +119, -37
tests/python/pytorch/graphbolt/test_item_sampler.py: +2, -25
python/dgl/graphbolt/item_sampler.py
@@ -109,6 +109,10 @@ class ItemShufflerAndBatcher:
         batch_size: int,
         drop_last: bool,
         buffer_size: Optional[int] = 10 * 1000,
+        distributed: Optional[bool] = False,
+        drop_uneven_inputs: Optional[bool] = False,
+        world_size: Optional[int] = 1,
+        rank: Optional[int] = 0,
     ):
         self._item_set = item_set
         self._shuffle = shuffle
@@ -119,6 +123,11 @@ class ItemShufflerAndBatcher:
         self._buffer_size = (
             (self._buffer_size + batch_size - 1) // batch_size * batch_size
         )
+        self._distributed = distributed
+        self._drop_uneven_inputs = drop_uneven_inputs
+        if distributed:
+            self._num_replicas = world_size
+            self._rank = rank
 
     def _collate_batch(self, buffer, indices, offsets=None):
         """Collate a batch from the buffer. For internal use only."""
@@ -167,12 +176,95 @@ class ItemShufflerAndBatcher:
         return torch.tensor(offsets)
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is not None:
             num_workers = worker_info.num_workers
             worker_id = worker_info.id
         else:
             num_workers = 1
             worker_id = 0
         buffer = None
-        num_items = len(self._item_set)
+        if not self._distributed:
+            num_items = len(self._item_set)
+            start_offset = 0
+        else:
+            total_count = len(self._item_set)
+            big_batch_size = self._num_replicas * self._batch_size
+            big_batch_count, big_batch_remain = divmod(
+                total_count, big_batch_size
+            )
+            last_batch_count, batch_remain = divmod(
+                big_batch_remain, self._batch_size
+            )
+            if self._rank < last_batch_count:
+                last_batch = self._batch_size
+            elif self._rank == last_batch_count:
+                last_batch = batch_remain
+            else:
+                last_batch = 0
+            num_items = big_batch_count * self._batch_size + last_batch
+            start_offset = (
+                big_batch_count * self._batch_size * self._rank
+                + min(self._rank * self._batch_size, big_batch_remain)
+            )
+            if not self._drop_uneven_inputs or (
+                not self._drop_last and last_batch_count == self._num_replicas
+            ):
+                # No need to drop uneven batches.
+                num_evened_items = num_items
+                if num_workers > 1:
+                    total_batch_count = (
+                        num_items + self._batch_size - 1
+                    ) // self._batch_size
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    num_items = (
+                        min(num_items, split_num_items * (worker_id + 1))
+                        - split_num_items * worker_id
+                    )
+                    num_evened_items = num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                    )
+            else:
+                # Needs to drop uneven batches. As many items as `last_batch`
+                # size will be dropped. It would be better not to let those
+                # dropped items come from the same worker.
+                num_evened_items = big_batch_count * self._batch_size
+                if num_workers > 1:
+                    total_batch_count = big_batch_count
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    split_item_remain = last_batch // num_workers + (
+                        worker_id < last_batch % num_workers
+                    )
+                    num_items = split_num_items + split_item_remain
+                    num_evened_items = split_num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                        + last_batch // num_workers * worker_id
+                        + min(worker_id, last_batch % num_workers)
+                    )
         start = 0
         while start < num_items:
             end = min(start + self._buffer_size, num_items)
-            buffer = self._item_set[start:end]
+            buffer = self._item_set[start_offset + start : start_offset + end]
             indices = torch.arange(end - start)
             if self._shuffle:
                 np.random.shuffle(indices.numpy())
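To make the partition arithmetic above concrete, here is a hedged, standalone re-derivation of the per-rank slice (no DGL dependency; rank_slice and the numbers are illustrative, using the 14-item, 4-replica, batch-size-2 setup from the docstring example further down):

def rank_slice(total_count, batch_size, num_replicas, rank):
    # Mirrors the start_offset / num_items computation in __iter__ above.
    big_batch_size = num_replicas * batch_size
    big_batch_count, big_batch_remain = divmod(total_count, big_batch_size)
    last_batch_count, batch_remain = divmod(big_batch_remain, batch_size)
    if rank < last_batch_count:
        last_batch = batch_size
    elif rank == last_batch_count:
        last_batch = batch_remain
    else:
        last_batch = 0
    num_items = big_batch_count * batch_size + last_batch
    start_offset = big_batch_count * batch_size * rank + min(
        rank * batch_size, big_batch_remain
    )
    return start_offset, num_items

for rank in range(4):
    print(rank, rank_slice(14, 2, 4, rank))
# Prints (start_offset, num_items) per rank, covering all 14 items exactly once:
# 0 (0, 4)
# 1 (4, 4)
# 2 (8, 4)
# 3 (12, 2)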
@@ -180,6 +272,12 @@ class ItemShufflerAndBatcher:
             for i in range(0, len(indices), self._batch_size):
                 if self._drop_last and i + self._batch_size > len(indices):
                     break
+                if (
+                    self._distributed
+                    and self._drop_uneven_inputs
+                    and i >= num_evened_items
+                ):
+                    break
                 batch_indices = indices[i : i + self._batch_size]
                 yield self._collate_batch(buffer, batch_indices, offsets)
             buffer = None
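A hedged, DGL-free sketch of where the new early exit fires; the counts below are made up for illustration:

import torch

batch_size, num_evened_items = 2, 2
indices = torch.arange(4)  # this replica holds 4 items, but only 2 are "evened"
batches = []
for i in range(0, len(indices), batch_size):
    if i >= num_evened_items:  # mirrors the distributed drop_uneven_inputs break
        break
    batches.append(indices[i : i + batch_size])
print(batches)  # [tensor([0, 1])] -- the uneven second batch is dropped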
@@ -416,6 +514,10 @@ class ItemSampler(IterDataPipe):
         self._drop_last = drop_last
         self._shuffle = shuffle
         self._use_indexing = use_indexing
+        self._distributed = False
+        self._drop_uneven_inputs = False
+        self._world_size = None
+        self._rank = None
 
     def _organize_items(self, data_pipe) -> None:
         # Shuffle before batch.
@@ -461,6 +563,10 @@ class ItemSampler(IterDataPipe):
                     self._shuffle,
                     self._batch_size,
                     self._drop_last,
+                    distributed=self._distributed,
+                    drop_uneven_inputs=self._drop_uneven_inputs,
+                    world_size=self._world_size,
+                    rank=self._rank,
                 )
             )
         else:
@@ -483,14 +589,14 @@ class DistributedItemSampler(ItemSampler):
     which can be used for training with PyTorch's Distributed Data Parallel
     (DDP). The items can be node IDs, node pairs with or without labels, node
     pairs with negative sources/destinations, DGLGraphs, or heterogeneous
-    counterparts. The original item set is sharded such that each replica
+    counterparts. The original item set is split such that each replica
     (process) receives an exclusive subset.
 
     Note: DistributedItemSampler may not work as expected when it is the last
     datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
     or another datapipe on it.
 
-    Note: The items will be first sharded onto each replica, then get shuffled
+    Note: The items will be first split onto each replica, then get shuffled
     (if needed) and batched. Therefore, each replica will always get a same set
     of items.
@@ -532,6 +638,7 @@ class DistributedItemSampler(ItemSampler):
     Examples
     --------
+    TODO[Kaicheng]: Modify examples here.
     0. Preparation: DistributedItemSampler needs multi-processing environment to
     work. You need to spawn subprocesses and initialize processing group before
     executing following examples. Due to randomness, the output is not always
@@ -539,7 +646,7 @@ class DistributedItemSampler(ItemSampler):
     >>> import torch
     >>> from dgl import graphbolt as gb
-    >>> item_set = gb.ItemSet(torch.arange(0, 13))
+    >>> item_set = gb.ItemSet(torch.arange(0, 14))
     >>> num_replicas = 4
     >>> batch_size = 2
     >>> mp.spawn(...)
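For reference, a minimal sketch of the multi-process setup those docstring examples assume; the worker function, backend, and address below are assumptions for illustration, not part of the commit:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb

def worker(rank, num_replicas):
    # Hypothetical process-group setup; backend and address are placeholders.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        world_size=num_replicas,
        rank=rank,
    )
    item_set = gb.ItemSet(torch.arange(0, 14))
    item_sampler = gb.DistributedItemSampler(
        item_set, batch_size=2, shuffle=False, drop_last=False
    )
    # Wrap the sampler in another datapipe/loader, as the note above advises.
    data_loader = gb.SingleProcessDataLoader(item_sampler)
    print(rank, list(data_loader))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)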
@@ -632,46 +739,21 @@ class DistributedItemSampler(ItemSampler):
         minibatcher: Optional[Callable] = minibatcher_default,
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
         num_replicas: Optional[int] = None,
         drop_uneven_inputs: Optional[bool] = False,
     ) -> None:
-        # [TODO][Rui] For now, always set use_indexing to False.
         super().__init__(
             item_set,
             batch_size,
             minibatcher,
             drop_last,
             shuffle,
-            use_indexing=False,
+            use_indexing=True,
         )
         self._distributed = True
         self._drop_uneven_inputs = drop_uneven_inputs
-        # Apply a sharding filter to distribute the items.
-        self._item_set = self._item_set.sharding_filter()
-        # Get world size.
-        if num_replicas is None:
-            assert (
-                dist.is_available()
-            ), "Requires distributed package to be available."
-            num_replicas = dist.get_world_size()
-        if self._drop_uneven_inputs:
-            # If the len() method of the item_set is not available, it will
-            # throw an exception.
-            total_len = len(item_set)
-            # Calculate the number of batches after dropping uneven batches for
-            # each replica.
-            self._num_evened_batches = total_len // (
-                num_replicas * batch_size
-            ) + (
-                (not drop_last)
-                and (total_len % (num_replicas * batch_size) >= num_replicas)
-            )
+        if not dist.is_available():
+            raise RuntimeError(
+                "Distributed item sampler requires distributed package."
+            )
+        self._world_size = dist.get_world_size()
+        self._rank = dist.get_rank()
 
-    def _organize_items(self, data_pipe) -> None:
-        data_pipe = super()._organize_items(data_pipe)
-        # If drop_uneven_inputs is True, drop the excessive inputs by limiting
-        # the length of the datapipe.
-        if self._drop_uneven_inputs:
-            data_pipe = data_pipe.header(self._num_evened_batches)
-        return data_pipe
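With this change the uneven-batch bookkeeping moves out of __init__ and into ItemShufflerAndBatcher.__iter__. For intuition, a hedged, plain-Python check of that bookkeeping using the docstring's setup (14 items, 4 replicas, batch size 2):

total_len, num_replicas, batch_size = 14, 4, 2
# Number of "big batches" that every replica can fill completely.
big_batch_count = total_len // (num_replicas * batch_size)  # 1
# With drop_uneven_inputs, each replica keeps only these evened items.
num_evened_items = big_batch_count * batch_size             # 2
print(big_batch_count, num_evened_items)  # 1 2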
tests/python/pytorch/graphbolt/test_item_sampler.py
@@ -767,31 +767,12 @@ def distributed_item_sampler_subprocess(
     for i in data_loader:
         # Count how many times each item is sampled.
         sampled_count[i.seed_nodes] += 1
         if drop_last:
             assert i.seed_nodes.size(0) == batch_size
         num_items += i.seed_nodes.size(0)
     num_batches = len(list(item_sampler))
 
-    # Calculate expected numbers of items and batches.
-    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
-    if drop_last and expected_num_items % batch_size > 0:
-        expected_num_items -= expected_num_items % batch_size
-    expected_num_batches = expected_num_items // batch_size + (
-        (not drop_last) and (expected_num_items % batch_size > 0)
-    )
-    if drop_uneven_inputs:
-        if (
-            (not drop_last)
-            and (num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % (nprocs * batch_size) > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= 1
-        elif (
-            drop_last
-            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % nprocs > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= batch_size
-
     num_batches_tensor = torch.tensor(num_batches)
     dist.broadcast(num_batches_tensor, 0)
     # Test if the number of batches are the same for all processes.
@@ -801,10 +782,6 @@ def distributed_item_sampler_subprocess(
     dist.reduce(sampled_count, 0)
     try:
-        # Check if the numbers are as expected.
-        assert num_items == expected_num_items
-        assert num_batches == expected_num_batches
         # Make sure no item is sampled more than once.
         assert sampled_count.max() <= 1
     finally: