OpenDAS / dgl / Commits / 46d7b1d9

Commit 46d7b1d9 (unverified), authored Nov 20, 2023 by Ramon Zhou, committed via GitHub on Nov 20, 2023

[GraphBolt] Rewrite DistributeItemSampler logic (#6565)
parent 81c7781b

Showing 4 changed files with 260 additions and 126 deletions
python/dgl/graphbolt/item_sampler.py                   +47  -120
python/dgl/graphbolt/utils/__init__.py                 +1   -0
python/dgl/graphbolt/utils/item_sampler_utils.py       +112 -0
tests/python/pytorch/graphbolt/test_item_sampler.py    +100 -6
python/dgl/graphbolt/item_sampler.py
@@ -16,6 +16,7 @@ from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
+from .utils import calculate_range

__all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
@@ -125,7 +126,6 @@ class ItemShufflerAndBatcher:
        )
        self._distributed = distributed
        self._drop_uneven_inputs = drop_uneven_inputs
        if distributed:
            self._num_replicas = world_size
            self._rank = rank
@@ -184,101 +184,33 @@ class ItemShufflerAndBatcher:
        num_workers = 1
        worker_id = 0
        buffer = None
-       if not self._distributed:
-           num_items = len(self._item_set)
-           start_offset = 0
-       else:
-           total_count = len(self._item_set)
-           big_batch_size = self._num_replicas * self._batch_size
-           big_batch_count, big_batch_remain = divmod(total_count, big_batch_size)
-           last_batch_count, batch_remain = divmod(big_batch_remain, self._batch_size)
-           if self._rank < last_batch_count:
-               last_batch = self._batch_size
-           elif self._rank == last_batch_count:
-               last_batch = batch_remain
-           else:
-               last_batch = 0
-           num_items = big_batch_count * self._batch_size + last_batch
-           start_offset = (
-               big_batch_count * self._batch_size * self._rank
-               + min(self._rank * self._batch_size, big_batch_remain)
-           )
-           if not self._drop_uneven_inputs or (
-               not self._drop_last and last_batch_count == self._num_replicas
-           ):
-               # No need to drop uneven batches.
-               num_evened_items = num_items
-               if num_workers > 1:
-                   total_batch_count = (
-                       num_items + self._batch_size - 1
-                   ) // self._batch_size
-                   split_batch_count = total_batch_count // num_workers + (
-                       worker_id < total_batch_count % num_workers
-                   )
-                   split_num_items = split_batch_count * self._batch_size
-                   num_items = (
-                       min(num_items, split_num_items * (worker_id + 1))
-                       - split_num_items * worker_id
-                   )
-                   num_evened_items = num_items
-                   start_offset = (
-                       big_batch_count * self._batch_size * self._rank
-                       + min(self._rank * self._batch_size, big_batch_remain)
-                       + self._batch_size
-                       * (
-                           total_batch_count // num_workers * worker_id
-                           + min(worker_id, total_batch_count % num_workers)
-                       )
-                   )
-           else:
-               # Needs to drop uneven batches. As many items as `last_batch`
-               # size will be dropped. It would be better not to let those
-               # dropped items come from the same worker.
-               num_evened_items = big_batch_count * self._batch_size
-               if num_workers > 1:
-                   total_batch_count = big_batch_count
-                   split_batch_count = total_batch_count // num_workers + (
-                       worker_id < total_batch_count % num_workers
-                   )
-                   split_num_items = split_batch_count * self._batch_size
-                   split_item_remain = last_batch // num_workers + (
-                       worker_id < last_batch % num_workers
-                   )
-                   num_items = split_num_items + split_item_remain
-                   num_evened_items = split_num_items
-                   start_offset = (
-                       big_batch_count * self._batch_size * self._rank
-                       + min(self._rank * self._batch_size, big_batch_remain)
-                       + self._batch_size
-                       * (
-                           total_batch_count // num_workers * worker_id
-                           + min(worker_id, total_batch_count % num_workers)
-                       )
-                       + last_batch // num_workers * worker_id
-                       + min(worker_id, last_batch % num_workers)
-                   )
+       total = len(self._item_set)
+       start_offset, assigned_count, output_count = calculate_range(
+           self._distributed,
+           total,
+           self._num_replicas,
+           self._rank,
+           num_workers,
+           worker_id,
+           self._batch_size,
+           self._drop_last,
+           self._drop_uneven_inputs,
+       )
        start = 0
-       while start < num_items:
-           end = min(start + self._buffer_size, num_items)
+       while start < assigned_count:
+           end = min(start + self._buffer_size, assigned_count)
            buffer = self._item_set[start_offset + start : start_offset + end]
            indices = torch.arange(end - start)
            if self._shuffle:
                np.random.shuffle(indices.numpy())
            offsets = self._calculate_offsets(buffer)
            for i in range(0, len(indices), self._batch_size):
                if self._drop_last and i + self._batch_size > len(indices):
                    break
-               if (
-                   self._distributed
-                   and self._drop_uneven_inputs
-                   and i >= num_evened_items
-               ):
+               if output_count <= 0:
                    break
-               batch_indices = indices[i : i + self._batch_size]
+               batch_indices = indices[
+                   i : i + min(self._batch_size, output_count)
+               ]
+               output_count -= self._batch_size
                yield self._collate_batch(buffer, batch_indices, offsets)
            buffer = None
            start = end
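For readers skimming this hunk, here is a minimal, hedged sketch (not DGL code; the helper name `iterate` and all constants are illustrative) of how the rewritten loop consumes the `(start_offset, assigned_count, output_count)` triple returned by `calculate_range`:

import torch


def iterate(item_set, start_offset, assigned_count, output_count,
            batch_size=2, buffer_size=4):
    # Read the assigned range in buffer-sized chunks, then cut batches and
    # stop emitting items once output_count is exhausted.
    start = 0
    while start < assigned_count:
        end = min(start + buffer_size, assigned_count)
        buffer = item_set[start_offset + start : start_offset + end]
        for i in range(0, end - start, batch_size):
            if output_count <= 0:
                break
            yield buffer[i : i + min(batch_size, output_count)]
            output_count -= batch_size
        start = end


items = torch.arange(15)
# Hypothetical values for the last of 4 replicas; in the real sampler they
# come from calculate_range(...).
print(list(iterate(items, start_offset=12, assigned_count=3, output_count=3)))
# -> [tensor([12, 13]), tensor([14])]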
@@ -592,10 +524,6 @@ class DistributedItemSampler(ItemSampler):
    counterparts. The original item set is split such that each replica
    (process) receives an exclusive subset.

    Note: DistributedItemSampler may not work as expected when it is the last
    datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
    or another datapipe on it.

    Note: The items will be first split onto each replica, then get shuffled
    (if needed) and batched. Therefore, each replica will always get a same set
    of items.
@@ -638,7 +566,6 @@ class DistributedItemSampler(ItemSampler):
    Examples
    --------
-   TODO[Kaicheng]: Modify examples here.
    0. Preparation: DistributedItemSampler needs multi-processing environment to
    work. You need to spawn subprocesses and initialize processing group before
    executing following examples. Due to randomness, the output is not always
@@ -646,7 +573,7 @@ class DistributedItemSampler(ItemSampler):
    >>> import torch
    >>> from dgl import graphbolt as gb
-   >>> item_set = gb.ItemSet(torch.arange(0, 14))
+   >>> item_set = gb.ItemSet(torch.arange(15))
    >>> num_replicas = 4
    >>> batch_size = 2
    >>> mp.spawn(...)
@@ -659,10 +586,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-   Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-   Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-   Replica#2: [tensor([2, 6]), tensor([10])]
-   Replica#3: [tensor([3, 7]), tensor([11])]
+   Replica#0: [tensor([0, 1]), tensor([2, 3])]
+   Replica#1: [tensor([4, 5]), tensor([6, 7])]
+   Replica#2: [tensor([8, 9]), tensor([10, 11])]
+   Replica#3: [tensor([12, 13]), tensor([14])]

    2. shuffle = False, drop_last = True, drop_uneven_inputs = False.
@@ -672,10 +599,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-   Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-   Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-   Replica#2: [tensor([2, 6])]
-   Replica#3: [tensor([3, 7])]
+   Replica#0: [tensor([0, 1]), tensor([2, 3])]
+   Replica#1: [tensor([4, 5]), tensor([6, 7])]
+   Replica#2: [tensor([8, 9]), tensor([10, 11])]
+   Replica#3: [tensor([12, 13])]

    3. shuffle = False, drop_last = False, drop_uneven_inputs = True.
@@ -685,10 +612,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-   Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-   Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-   Replica#2: [tensor([2, 6]), tensor([10])]
-   Replica#3: [tensor([3, 7]), tensor([11])]
+   Replica#0: [tensor([0, 1]), tensor([2, 3])]
+   Replica#1: [tensor([4, 5]), tensor([6, 7])]
+   Replica#2: [tensor([8, 9]), tensor([10, 11])]
+   Replica#3: [tensor([12, 13]), tensor([14])]

    4. shuffle = False, drop_last = True, drop_uneven_inputs = True.
@@ -698,10 +625,10 @@ class DistributedItemSampler(ItemSampler):
    >>> )
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-   Replica#0: [tensor([0, 4])]
-   Replica#1: [tensor([1, 5])]
-   Replica#2: [tensor([2, 6])]
-   Replica#3: [tensor([3, 7])]
+   Replica#0: [tensor([0, 1])]
+   Replica#1: [tensor([4, 5])]
+   Replica#2: [tensor([8, 9])]
+   Replica#3: [tensor([12, 13])]

    5. shuffle = True, drop_last = True, drop_uneven_inputs = False.
@@ -712,10 +639,10 @@ class DistributedItemSampler(ItemSampler):
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
    (One possible output:)
-   Replica#0: [tensor([0, 8]), tensor([ 4, 12])]
-   Replica#1: [tensor([ 5, 13]), tensor([9, 1])]
-   Replica#2: [tensor([ 2, 10])]
-   Replica#3: [tensor([11, 7])]
+   Replica#0: [tensor([3, 2]), tensor([0, 1])]
+   Replica#1: [tensor([6, 5]), tensor([7, 4])]
+   Replica#2: [tensor([8, 10])]
+   Replica#3: [tensor([14, 12])]

    6. shuffle = True, drop_last = True, drop_uneven_inputs = True.
@@ -726,10 +653,10 @@ class DistributedItemSampler(ItemSampler):
    >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
    >>> print(f"Replica#{proc_id}: {list(data_loader)}")
    (One possible output:)
-   Replica#0: [tensor([8, 0])]
-   Replica#1: [tensor([ 1, 13])]
-   Replica#2: [tensor([10, 6])]
-   Replica#3: [tensor([ 3, 11])]
+   Replica#0: [tensor([1, 3])]
+   Replica#1: [tensor([7, 5])]
+   Replica#2: [tensor([11, 9])]
+   Replica#3: [tensor([13, 14])]
    """

    def __init__(
python/dgl/graphbolt/utils/__init__.py

@@ -2,3 +2,4 @@
from .internal import *
from .sample_utils import *
from .datapipe_utils import *
+from .item_sampler_utils import *
python/dgl/graphbolt/utils/item_sampler_utils.py (new file, mode 100644)

"""Utility functions for DistributedItemSampler."""


def count_split(total, num_workers, worker_id, batch_size=1):
    """Calculate the number of assigned items after splitting them by batch
    size evenly. It will return the number for this worker and also a sum of
    previous workers.
    """
    quotient, remainder = divmod(total, num_workers * batch_size)
    if batch_size == 1:
        assigned = quotient + (worker_id < remainder)
    else:
        batch_count, last_batch = divmod(remainder, batch_size)
        assigned = quotient * batch_size + (
            batch_size
            if worker_id < batch_count
            else (last_batch if worker_id == batch_count else 0)
        )
    prefix_sum = quotient * worker_id * batch_size + min(
        worker_id * batch_size, remainder
    )
    return (assigned, prefix_sum)


def calculate_range(
    distributed,
    total,
    num_replicas,
    rank,
    num_workers,
    worker_id,
    batch_size,
    drop_last,
    drop_uneven_inputs,
):
    """Calculates the range of items to be assigned to the current worker.

    This function evenly distributes `total` items among multiple workers,
    batching them using `batch_size`. Each replica has `num_workers` workers.
    The batches generated by workers within the same replica are combined into
    the replica's output. The `drop_last` parameter determines whether
    incomplete batches should be dropped. If `drop_last` is True, incomplete
    batches are discarded. The `drop_uneven_inputs` parameter determines if
    the number of batches assigned to each replica should be the same. If
    `drop_uneven_inputs` is True, excessive batches for some replicas will be
    dropped.

    Args:
        distributed (bool): Whether it's in distributed mode.
        total (int): The total number of items.
        num_replicas (int): The total number of replicas.
        rank (int): The rank of the current replica.
        num_workers (int): The number of workers per replica.
        worker_id (int): The ID of the current worker.
        batch_size (int): The desired batch size.
        drop_last (bool): Whether to drop incomplete batches.
        drop_uneven_inputs (bool): Whether to drop excessive batches for some
            replicas.

    Returns:
        tuple: A tuple containing three numbers:
            - start_offset (int): The starting offset of the range assigned to
              the current worker.
            - assigned_count (int): The length of the range assigned to the
              current worker.
            - output_count (int): The number of items that the current worker
              will produce after dropping.
    """
    # Check if it's distributed mode.
    if not distributed:
        if not drop_last:
            return (0, total, total)
        else:
            return (0, total, total // batch_size * batch_size)
    # First, equally distribute items into all replicas.
    assigned_count, start_offset = count_split(
        total, num_replicas, rank, batch_size
    )
    # Calculate the number of outputs when drop_uneven_inputs is True.
    # `assigned_count` is the number of items distributed to the current
    # process. `output_count` is the number of items should be output
    # by this process after dropping.
    if not drop_uneven_inputs:
        if not drop_last:
            output_count = assigned_count
        else:
            output_count = assigned_count // batch_size * batch_size
    else:
        if not drop_last:
            min_item_count, _ = count_split(
                total, num_replicas, num_replicas - 1, batch_size
            )
            min_batch_count = (min_item_count + batch_size - 1) // batch_size
            output_count = min(min_batch_count * batch_size, assigned_count)
        else:
            output_count = total // (batch_size * num_replicas) * batch_size
    # If there are multiple workers, equally distribute the batches to
    # all workers.
    if num_workers > 1:
        # Equally distribute the dropped number too.
        dropped_items, prev_dropped_items = count_split(
            assigned_count - output_count, num_workers, worker_id
        )
        output_count, prev_output_count = count_split(
            output_count, num_workers, worker_id, batch_size
        )
        assigned_count = output_count + dropped_items
        start_offset += prev_output_count + prev_dropped_items
    return (start_offset, assigned_count, output_count)
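As a quick, hedged sanity check of the new helper (the expected tuples below are hand-computed from the code above, not part of the commit), the 15-item, 4-replica, batch-size-2 setting used in the updated docstring examples can be reproduced with one worker per replica:

from dgl import graphbolt as gb

# One call per replica; each returns (start_offset, assigned_count, output_count).
for rank in range(4):
    print(rank, gb.utils.calculate_range(True, 15, 4, rank, 1, 0, 2, False, False))
# Hand-computed expectation: (0, 4, 4), (4, 4, 4), (8, 4, 4), (12, 3, 3);
# replica 3 therefore yields tensor([12, 13]) and tensor([14]), as in example 1.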
tests/python/pytorch/graphbolt/test_item_sampler.py

@@ -728,8 +728,8 @@ def distributed_item_sampler_subprocess(
    nprocs,
    item_set,
    num_ids,
+   num_workers,
    batch_size,
-   shuffle,
    drop_last,
    drop_uneven_inputs,
):
@@ -750,7 +750,7 @@ def distributed_item_sampler_subprocess(
    item_sampler = gb.DistributedItemSampler(
        item_set,
        batch_size=batch_size,
-       shuffle=shuffle,
+       shuffle=True,
        drop_last=drop_last,
        drop_uneven_inputs=drop_uneven_inputs,
    )
@@ -759,7 +759,9 @@ def distributed_item_sampler_subprocess(
        gb.BasicFeatureStore({}),
        [],
    )
-   data_loader = gb.SingleProcessDataLoader(feature_fetcher)
+   data_loader = gb.MultiProcessDataLoader(
+       feature_fetcher, num_workers=num_workers
+   )

    # Count the numbers of items and batches.
    num_items = 0
@@ -788,12 +790,104 @@ def distributed_item_sampler_subprocess(
    dist.destroy_process_group()


@pytest.mark.parametrize(
    "params",
    [
        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),
        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),
        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),
        (
            (53, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, False, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (63, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],
        ),
        (
            (63, 4, 2, 4, False, True),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, True),
            [
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (5, 4),
            ],
        ),
        (
            (65, 4, 2, 4, False, False),
            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
        (
            (65, 4, 2, 4, True, True),
            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
    ],
)
def test_RangeCalculation(params):
    (
        (
            total,
            num_replicas,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        key,
    ) = params
    answer = []
    sum = 0
    for rank in range(num_replicas):
        for worker_id in range(max(num_workers, 1)):
            result = gb.utils.calculate_range(
                True,
                total,
                num_replicas,
                rank,
                num_workers,
                worker_id,
                batch_size,
                drop_last,
                drop_uneven_inputs,
            )
            assert sum == result[0]
            sum += result[1]
            answer.append((result[1], result[2]))
    assert key == answer
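As a hedged worked check of one case above, based on the reconstructed count_split: for (30, 4, 0, 4, True, False), divmod(30, 4 * 4) gives quotient 1 and remainder 14, and divmod(14, 4) gives 3 extra full batches plus 2 leftover items, so ranks 0-2 are each assigned 8 items while rank 3 is assigned 6; with drop_last=True rank 3's trailing partial batch is dropped, so the output counts are 8, 8, 8 and 4 — the expected key [(8, 8), (8, 8), (8, 8), (6, 4)].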
@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
-@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("num_workers", [0, 2])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
def test_DistributedItemSampler(
-    num_ids, shuffle, drop_last, drop_uneven_inputs
+    num_ids, num_workers, drop_last, drop_uneven_inputs
):
    nprocs = 4
    batch_size = 4
@@ -813,8 +907,8 @@ def test_DistributedItemSampler(
            nprocs,
            item_set,
            num_ids,
+           num_workers,
            batch_size,
-           shuffle,
            drop_last,
            drop_uneven_inputs,
        ),