OpenDAS / dgl · Commits · d79701db

Commit d79701db (unverified), authored Nov 08, 2023 by Ramon Zhou, committed by GitHub on Nov 08, 2023

[GraphBolt] Use indexing in DistributedItemSampler (#6508)

Parent: 8f1b5782
Showing 2 changed files with 121 additions and 62 deletions:

python/dgl/graphbolt/item_sampler.py: +119 −37
tests/python/pytorch/graphbolt/test_item_sampler.py: +2 −25
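In short: DistributedItemSampler previously distributed items across replicas with a round-robin sharding filter; after this commit it computes a per-rank start offset and slices the item set directly (`use_indexing=True`), so each replica reads one contiguous, batch-aligned range. A minimal sketch of the difference, using assumed numbers and a plain tensor rather than DGL's `ItemSet`:

```python
import torch

items = torch.arange(10)  # stand-in for the full item set
world_size, rank = 4, 1   # assumed replica layout

# Before: a round-robin sharding filter gives each rank every 4th item.
round_robin = items[rank::world_size]  # tensor([1, 5, 9])

# After: each rank derives a contiguous slice from a start offset, so a whole
# buffer can be fetched with a single indexing call. (The real code aligns the
# split to batch boundaries; this sketch uses a plain even split.)
base, extra = divmod(len(items), world_size)
start = rank * base + min(rank, extra)
stop = start + base + (rank < extra)
contiguous = items[start:stop]  # tensor([3, 4, 5])
```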
python/dgl/graphbolt/item_sampler.py (view file @ d79701db)
```diff
@@ -109,6 +109,10 @@ class ItemShufflerAndBatcher:
         batch_size: int,
         drop_last: bool,
         buffer_size: Optional[int] = 10 * 1000,
+        distributed: Optional[bool] = False,
+        drop_uneven_inputs: Optional[bool] = False,
+        world_size: Optional[int] = 1,
+        rank: Optional[int] = 0,
     ):
         self._item_set = item_set
         self._shuffle = shuffle
@@ -119,6 +123,11 @@ class ItemShufflerAndBatcher:
         self._buffer_size = (
             (self._buffer_size + batch_size - 1) // batch_size * batch_size
         )
+        self._distributed = distributed
+        self._drop_uneven_inputs = drop_uneven_inputs
+        if distributed:
+            self._num_replicas = world_size
+            self._rank = rank

     def _collate_batch(self, buffer, indices, offsets=None):
         """Collate a batch from the buffer. For internal use only."""
```
```diff
@@ -167,12 +176,95 @@ class ItemShufflerAndBatcher:
         return torch.tensor(offsets)

     def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            num_workers = worker_info.num_workers
+            worker_id = worker_info.id
+        else:
+            num_workers = 1
+            worker_id = 0
         buffer = None
-        num_items = len(self._item_set)
+        if not self._distributed:
+            num_items = len(self._item_set)
+            start_offset = 0
+        else:
+            total_count = len(self._item_set)
+            big_batch_size = self._num_replicas * self._batch_size
+            big_batch_count, big_batch_remain = divmod(
+                total_count, big_batch_size
+            )
+            last_batch_count, batch_remain = divmod(
+                big_batch_remain, self._batch_size
+            )
+            if self._rank < last_batch_count:
+                last_batch = self._batch_size
+            elif self._rank == last_batch_count:
+                last_batch = batch_remain
+            else:
+                last_batch = 0
+            num_items = big_batch_count * self._batch_size + last_batch
+            start_offset = (
+                big_batch_count * self._batch_size * self._rank
+                + min(self._rank * self._batch_size, big_batch_remain)
+            )
+            if not self._drop_uneven_inputs or (
+                not self._drop_last and last_batch_count == self._num_replicas
+            ):
+                # No need to drop uneven batches.
+                num_evened_items = num_items
+                if num_workers > 1:
+                    total_batch_count = (
+                        num_items + self._batch_size - 1
+                    ) // self._batch_size
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    num_items = (
+                        min(num_items, split_num_items * (worker_id + 1))
+                        - split_num_items * worker_id
+                    )
+                    num_evened_items = num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                    )
+            else:
+                # Needs to drop uneven batches. As many items as `last_batch`
+                # size will be dropped. It would be better not to let those
+                # dropped items come from the same worker.
+                num_evened_items = big_batch_count * self._batch_size
+                if num_workers > 1:
+                    total_batch_count = big_batch_count
+                    split_batch_count = total_batch_count // num_workers + (
+                        worker_id < total_batch_count % num_workers
+                    )
+                    split_num_items = split_batch_count * self._batch_size
+                    split_item_remain = last_batch // num_workers + (
+                        worker_id < last_batch % num_workers
+                    )
+                    num_items = split_num_items + split_item_remain
+                    num_evened_items = split_num_items
+                    start_offset = (
+                        big_batch_count * self._batch_size * self._rank
+                        + min(self._rank * self._batch_size, big_batch_remain)
+                        + self._batch_size
+                        * (
+                            total_batch_count // num_workers * worker_id
+                            + min(worker_id, total_batch_count % num_workers)
+                        )
+                        + last_batch // num_workers * worker_id
+                        + min(worker_id, last_batch % num_workers)
+                    )
         start = 0
         while start < num_items:
             end = min(start + self._buffer_size, num_items)
-            buffer = self._item_set[start:end]
+            buffer = self._item_set[start_offset + start : start_offset + end]
             indices = torch.arange(end - start)
             if self._shuffle:
                 np.random.shuffle(indices.numpy())
```
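To make the offset arithmetic above concrete, here is a worked single-worker re-derivation (with `num_workers == 1`, the worker-splitting branches are not exercised), using 14 items, 4 replicas, and a batch size of 2, the same numbers as the docstring example further down. This is a standalone sketch, not DGL code:

```python
def partition(total_count, num_replicas, batch_size, rank):
    """Return (start_offset, num_items) for one replica, per the math above."""
    big_batch_size = num_replicas * batch_size
    big_batch_count, big_batch_remain = divmod(total_count, big_batch_size)
    last_batch_count, batch_remain = divmod(big_batch_remain, batch_size)
    if rank < last_batch_count:
        last_batch = batch_size    # this rank gets a full trailing batch
    elif rank == last_batch_count:
        last_batch = batch_remain  # this rank gets the partial trailing batch
    else:
        last_batch = 0             # this rank gets no trailing batch
    num_items = big_batch_count * batch_size + last_batch
    start_offset = big_batch_count * batch_size * rank + min(
        rank * batch_size, big_batch_remain
    )
    return start_offset, num_items

# 14 items over 4 replicas with batch_size 2: ranks 0-2 take items 0-3, 4-7,
# and 8-11 (two batches each); rank 3 takes items 12-13 (one batch).
print([partition(14, 4, 2, r) for r in range(4)])
# [(0, 4), (4, 4), (8, 4), (12, 2)]
```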
```diff
@@ -180,6 +272,12 @@ class ItemShufflerAndBatcher:
             for i in range(0, len(indices), self._batch_size):
                 if self._drop_last and i + self._batch_size > len(indices):
                     break
+                if (
+                    self._distributed
+                    and self._drop_uneven_inputs
+                    and i >= num_evened_items
+                ):
+                    break
                 batch_indices = indices[i : i + self._batch_size]
                 yield self._collate_batch(buffer, batch_indices, offsets)
             buffer = None
```
```diff
@@ -416,6 +514,10 @@ class ItemSampler(IterDataPipe):
         self._drop_last = drop_last
         self._shuffle = shuffle
         self._use_indexing = use_indexing
+        self._distributed = False
+        self._drop_uneven_inputs = False
+        self._world_size = None
+        self._rank = None

     def _organize_items(self, data_pipe) -> None:
         # Shuffle before batch.
@@ -461,6 +563,10 @@ class ItemSampler(IterDataPipe):
                     self._shuffle,
                     self._batch_size,
                     self._drop_last,
+                    distributed=self._distributed,
+                    drop_uneven_inputs=self._drop_uneven_inputs,
+                    world_size=self._world_size,
+                    rank=self._rank,
                 )
             )
         else:
```
```diff
@@ -483,14 +589,14 @@ class DistributedItemSampler(ItemSampler):
     which can be used for training with PyTorch's Distributed Data Parallel
     (DDP). The items can be node IDs, node pairs with or without labels, node
     pairs with negative sources/destinations, DGLGraphs, or heterogeneous
-    counterparts. The original item set is sharded such that each replica
+    counterparts. The original item set is split such that each replica
     (process) receives an exclusive subset.

     Note: DistributedItemSampler may not work as expected when it is the last
     datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
     or another datapipe on it.

-    Note: The items will be first sharded onto each replica, then get shuffled
+    Note: The items will be first split onto each replica, then get shuffled
     (if needed) and batched. Therefore, each replica will always get a same set
     of items.
@@ -532,6 +638,7 @@ class DistributedItemSampler(ItemSampler):
     Examples
     --------
+    TODO[Kaicheng]: Modify examples here.
     0. Preparation: DistributedItemSampler needs multi-processing environment to
     work. You need to spawn subprocesses and initialize processing group before
     executing following examples. Due to randomness, the output is not always
@@ -539,7 +646,7 @@ class DistributedItemSampler(ItemSampler):
     >>> import torch
     >>> from dgl import graphbolt as gb
-    >>> item_set = gb.ItemSet(torch.arange(0, 13))
+    >>> item_set = gb.ItemSet(torch.arange(0, 14))
     >>> num_replicas = 4
     >>> batch_size = 2
     >>> mp.spawn(...)
```
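The `mp.spawn(...)` call in the docstring is deliberately elided; the setup it assumes looks roughly like the sketch below. The function name, backend, and port here are illustrative choices, not part of DGL's API:

```python
import os

import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # assumed single-machine setup
    os.environ["MASTER_PORT"] = "29500"
    # gloo works on CPU-only machines; nccl is the usual choice for GPUs.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... construct gb.DistributedItemSampler(item_set, batch_size=2, ...)
    # and iterate over it here; each rank sees an exclusive subset ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```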
```diff
@@ -632,46 +739,21 @@ class DistributedItemSampler(ItemSampler):
         minibatcher: Optional[Callable] = minibatcher_default,
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
-        num_replicas: Optional[int] = None,
         drop_uneven_inputs: Optional[bool] = False,
     ) -> None:
-        # [TODO][Rui] For now, always set use_indexing to False.
         super().__init__(
             item_set,
             batch_size,
             minibatcher,
             drop_last,
             shuffle,
-            use_indexing=False,
+            use_indexing=True,
         )
+        self._distributed = True
         self._drop_uneven_inputs = drop_uneven_inputs
-        # Apply a sharding filter to distribute the items.
-        self._item_set = self._item_set.sharding_filter()
-        # Get world size.
-        if num_replicas is None:
-            assert (
-                dist.is_available()
-            ), "Requires distributed package to be available."
-            num_replicas = dist.get_world_size()
-        if self._drop_uneven_inputs:
-            # If the len() method of the item_set is not available, it will
-            # throw an exception.
-            total_len = len(item_set)
-            # Calculate the number of batches after dropping uneven batches for
-            # each replica.
-            self._num_evened_batches = total_len // (
-                num_replicas * batch_size
-            ) + (
-                (not drop_last)
-                and (total_len % (num_replicas * batch_size) >= num_replicas)
-            )
-
-    def _organize_items(self, data_pipe) -> None:
-        data_pipe = super()._organize_items(data_pipe)
-        # If drop_uneven_inputs is True, drop the excessive inputs by limiting
-        # the length of the datapipe.
-        if self._drop_uneven_inputs:
-            data_pipe = data_pipe.header(self._num_evened_batches)
-        return data_pipe
+        if not dist.is_available():
+            raise RuntimeError(
+                "Distributed item sampler requires distributed package."
+            )
+        self._world_size = dist.get_world_size()
+        self._rank = dist.get_rank()
```
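One point worth spelling out about `drop_uneven_inputs`, which is now enforced inside `__iter__` instead of by the removed `header()` datapipe: DDP synchronizes gradients with collective operations, so every replica must run the same number of iterations or the job deadlocks. A worked illustration under assumed numbers:

```python
# Assumed: 14 items, 4 replicas, batch_size 2 (as in the examples above).
total, replicas, bs = 14, 4, 2
# Without dropping, per-rank batch counts would be [2, 2, 2, 1]: rank 3 runs
# out one iteration early and the other ranks' all-reduce waits forever.
big_batch_count = total // (replicas * bs)  # complete rounds of batches: 1
evened_items = big_batch_count * bs         # 2 items per rank -> 1 batch each
assert evened_items == 2                    # every replica stays in lockstep
```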
tests/python/pytorch/graphbolt/test_item_sampler.py (view file @ d79701db)
```diff
@@ -767,31 +767,12 @@ def distributed_item_sampler_subprocess(
     for i in data_loader:
         # Count how many times each item is sampled.
         sampled_count[i.seed_nodes] += 1
+        if drop_last:
+            assert i.seed_nodes.size(0) == batch_size
         num_items += i.seed_nodes.size(0)
     num_batches = len(list(item_sampler))
-    # Calculate expected numbers of items and batches.
-    expected_num_items = num_ids // nprocs + (num_ids % nprocs > proc_id)
-    if drop_last and expected_num_items % batch_size > 0:
-        expected_num_items -= expected_num_items % batch_size
-    expected_num_batches = expected_num_items // batch_size + (
-        (not drop_last) and (expected_num_items % batch_size > 0)
-    )

     if drop_uneven_inputs:
-        if (
-            (not drop_last)
-            and (num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % (nprocs * batch_size) > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= 1
-        elif (
-            drop_last
-            and (nprocs * batch_size - num_ids % (nprocs * batch_size) < nprocs)
-            and (num_ids % nprocs > proc_id)
-        ):
-            expected_num_batches -= 1
-            expected_num_items -= batch_size
         num_batches_tensor = torch.tensor(num_batches)
         dist.broadcast(num_batches_tensor, 0)
         # Test if the number of batches are the same for all processes.
@@ -801,10 +782,6 @@ def distributed_item_sampler_subprocess(
     dist.reduce(sampled_count, 0)

     try:
-        # Check if the numbers are as expected.
-        assert num_items == expected_num_items
-        assert num_batches == expected_num_batches
         # Make sure no item is sampled more than once.
         assert sampled_count.max() <= 1
     finally:
```
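What the surviving assertion verifies is that the shards are exclusive: after `dist.reduce` sums the per-process counters on rank 0, no item may have been seen more than once. A single-process analogue of that bookkeeping, using stand-in batches rather than real sampler output:

```python
import torch

num_ids = 14
sampled_count = torch.zeros(num_ids, dtype=torch.int32)
# Stand-in for the batches collected across all replicas.
batches = (torch.tensor([0, 1]), torch.tensor([2, 3]), torch.tensor([12, 13]))
for batch in batches:
    sampled_count[batch] += 1
assert sampled_count.max() <= 1  # no item landed in two different batches
```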