Commit 46d7b1d9 (unverified), authored Nov 20, 2023 by Ramon Zhou, committed via GitHub on Nov 20, 2023.

[GraphBolt] Rewrite DistributedItemSampler logic (#6565)
parent 81c7781b

Showing 4 changed files with 260 additions and 126 deletions (+260, -126):
- python/dgl/graphbolt/item_sampler.py (+47, -120)
- python/dgl/graphbolt/utils/__init__.py (+1, -0)
- python/dgl/graphbolt/utils/item_sampler_utils.py (+112, -0)
- tests/python/pytorch/graphbolt/test_item_sampler.py (+100, -6)
python/dgl/graphbolt/item_sampler.py
@@ -16,6 +16,7 @@ from ..batch import batch as dgl_batch
 from ..heterograph import DGLGraph
 from .itemset import ItemSet, ItemSetDict
 from .minibatch import MiniBatch
+from .utils import calculate_range

 __all__ = ["ItemSampler", "DistributedItemSampler", "minibatcher_default"]
@@ -125,7 +126,6 @@ class ItemShufflerAndBatcher:
         )
         self._distributed = distributed
         self._drop_uneven_inputs = drop_uneven_inputs
-        if distributed:
-            self._num_replicas = world_size
-            self._rank = rank
+        self._num_replicas = world_size
+        self._rank = rank
@@ -184,101 +184,33 @@ class ItemShufflerAndBatcher:
             num_workers = 1
             worker_id = 0
         buffer = None
-        if not self._distributed:
-            num_items = len(self._item_set)
-            start_offset = 0
-        else:
-            total_count = len(self._item_set)
-            big_batch_size = self._num_replicas * self._batch_size
-            big_batch_count, big_batch_remain = divmod(
-                total_count, big_batch_size
-            )
-            last_batch_count, batch_remain = divmod(
-                big_batch_remain, self._batch_size
-            )
-            if self._rank < last_batch_count:
-                last_batch = self._batch_size
-            elif self._rank == last_batch_count:
-                last_batch = batch_remain
-            else:
-                last_batch = 0
-            num_items = big_batch_count * self._batch_size + last_batch
-            start_offset = (
-                big_batch_count * self._batch_size * self._rank
-                + min(self._rank * self._batch_size, big_batch_remain)
-            )
-            if not self._drop_uneven_inputs or (
-                not self._drop_last and last_batch_count == self._num_replicas
-            ):
-                # No need to drop uneven batches.
-                num_evened_items = num_items
-                if num_workers > 1:
-                    total_batch_count = (
-                        num_items + self._batch_size - 1
-                    ) // self._batch_size
-                    split_batch_count = total_batch_count // num_workers + (
-                        worker_id < total_batch_count % num_workers
-                    )
-                    split_num_items = split_batch_count * self._batch_size
-                    num_items = (
-                        min(num_items, split_num_items * (worker_id + 1))
-                        - split_num_items * worker_id
-                    )
-                    num_evened_items = num_items
-                    start_offset = (
-                        big_batch_count * self._batch_size * self._rank
-                        + min(self._rank * self._batch_size, big_batch_remain)
-                        + self._batch_size
-                        * (
-                            total_batch_count // num_workers * worker_id
-                            + min(worker_id, total_batch_count % num_workers)
-                        )
-                    )
-            else:
-                # Needs to drop uneven batches. As many items as `last_batch`
-                # size will be dropped. It would be better not to let those
-                # dropped items come from the same worker.
-                num_evened_items = big_batch_count * self._batch_size
-                if num_workers > 1:
-                    total_batch_count = big_batch_count
-                    split_batch_count = total_batch_count // num_workers + (
-                        worker_id < total_batch_count % num_workers
-                    )
-                    split_num_items = split_batch_count * self._batch_size
-                    split_item_remain = last_batch // num_workers + (
-                        worker_id < last_batch % num_workers
-                    )
-                    num_items = split_num_items + split_item_remain
-                    num_evened_items = split_num_items
-                    start_offset = (
-                        big_batch_count * self._batch_size * self._rank
-                        + min(self._rank * self._batch_size, big_batch_remain)
-                        + self._batch_size
-                        * (
-                            total_batch_count // num_workers * worker_id
-                            + min(worker_id, total_batch_count % num_workers)
-                        )
-                        + last_batch // num_workers * worker_id
-                        + min(worker_id, last_batch % num_workers)
-                    )
+        total = len(self._item_set)
+        start_offset, assigned_count, output_count = calculate_range(
+            self._distributed,
+            total,
+            self._num_replicas,
+            self._rank,
+            num_workers,
+            worker_id,
+            self._batch_size,
+            self._drop_last,
+            self._drop_uneven_inputs,
+        )
         start = 0
-        while start < num_items:
-            end = min(start + self._buffer_size, num_items)
+        while start < assigned_count:
+            end = min(start + self._buffer_size, assigned_count)
             buffer = self._item_set[start_offset + start : start_offset + end]
             indices = torch.arange(end - start)
             if self._shuffle:
                 np.random.shuffle(indices.numpy())
             offsets = self._calculate_offsets(buffer)
             for i in range(0, len(indices), self._batch_size):
-                if self._drop_last and i + self._batch_size > len(indices):
-                    break
-                if (
-                    self._distributed
-                    and self._drop_uneven_inputs
-                    and i >= num_evened_items
-                ):
-                    break
-                batch_indices = indices[i : i + self._batch_size]
+                if output_count <= 0:
+                    break
+                batch_indices = indices[
+                    i : i + min(self._batch_size, output_count)
+                ]
+                output_count -= self._batch_size
                 yield self._collate_batch(buffer, batch_indices, offsets)
             buffer = None
             start = end
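
The net effect of the rewrite is that all per-replica and per-worker arithmetic moves into `calculate_range`, and the iterator body only slices and batches. As a minimal standalone sketch (not the committed code: a plain Python list stands in for the `ItemSet`, and shuffling is omitted), the loop reduces to:

    def iterate(
        items, start_offset, assigned_count, output_count, batch_size, buffer_size
    ):
        """Slice-and-batch loop mirroring the rewritten ItemShufflerAndBatcher."""
        start = 0
        while start < assigned_count:
            end = min(start + buffer_size, assigned_count)
            buffer = items[start_offset + start : start_offset + end]
            for i in range(0, len(buffer), batch_size):
                if output_count <= 0:
                    break  # Remaining items in the range are dropped, not yielded.
                yield buffer[i : i + min(batch_size, output_count)]
                output_count -= batch_size
            start = end

    # Replica 3 of 4 over 30 items (batch_size=4, drop_last=True) is assigned
    # the 6-item range starting at offset 24 but outputs only one full batch:
    print(list(iterate(list(range(30)), 24, 6, 4, 4, 1024)))  # [[24, 25, 26, 27]]

The `(24, 6, 4)` triple used here is the `calculate_range` result for that replica, matching the `(30, 4, 0, 4, True, False)` row of the new parametrized test further below.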
@@ -592,10 +524,6 @@ class DistributedItemSampler(ItemSampler):
     counterparts. The original item set is split such that each replica
     (process) receives an exclusive subset.

-    Note: DistributedItemSampler may not work as expected when it is the last
-    datapipe before the data is fetched. Please wrap a SingleProcessDataLoader
-    or another datapipe on it.
-
     Note: The items will be first split onto each replica, then get shuffled
     (if needed) and batched. Therefore, each replica will always get the same
     set of items.
@@ -638,7 +566,6 @@ class DistributedItemSampler(ItemSampler):
     Examples
     --------
-    TODO[Kaicheng]: Modify examples here.
     0. Preparation: DistributedItemSampler needs a multi-processing environment
     to work. You need to spawn subprocesses and initialize the process group
     before executing the following examples. Due to randomness, the output is
     not always
@@ -646,7 +573,7 @@ class DistributedItemSampler(ItemSampler):
     >>> import torch
     >>> from dgl import graphbolt as gb
-    >>> item_set = gb.ItemSet(torch.arange(0, 14))
+    >>> item_set = gb.ItemSet(torch.arange(15))
     >>> num_replicas = 4
     >>> batch_size = 2
     >>> mp.spawn(...)
@@ -659,10 +586,10 @@ class DistributedItemSampler(ItemSampler):
     >>> )
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6]), tensor([10])]
-    Replica#3: [tensor([3, 7]), tensor([11])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13]), tensor([14])]

     2. shuffle = False, drop_last = True, drop_uneven_inputs = False.
@@ -672,10 +599,10 @@ class DistributedItemSampler(ItemSampler):
     >>> )
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6])]
-    Replica#3: [tensor([3, 7])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13])]

     3. shuffle = False, drop_last = False, drop_uneven_inputs = True.
@@ -685,10 +612,10 @@ class DistributedItemSampler(ItemSampler):
     >>> )
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4]), tensor([ 8, 12])]
-    Replica#1: [tensor([1, 5]), tensor([ 9, 13])]
-    Replica#2: [tensor([2, 6]), tensor([10])]
-    Replica#3: [tensor([3, 7]), tensor([11])]
+    Replica#0: [tensor([0, 1]), tensor([2, 3])]
+    Replica#1: [tensor([4, 5]), tensor([6, 7])]
+    Replica#2: [tensor([8, 9]), tensor([10, 11])]
+    Replica#3: [tensor([12, 13]), tensor([14])]

     4. shuffle = False, drop_last = True, drop_uneven_inputs = True.
@@ -698,10 +625,10 @@ class DistributedItemSampler(ItemSampler):
     >>> )
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
-    Replica#0: [tensor([0, 4])]
-    Replica#1: [tensor([1, 5])]
-    Replica#2: [tensor([2, 6])]
-    Replica#3: [tensor([3, 7])]
+    Replica#0: [tensor([0, 1])]
+    Replica#1: [tensor([4, 5])]
+    Replica#2: [tensor([8, 9])]
+    Replica#3: [tensor([12, 13])]

     5. shuffle = True, drop_last = True, drop_uneven_inputs = False.
@@ -712,10 +639,10 @@ class DistributedItemSampler(ItemSampler):
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
     (One possible output:)
-    Replica#0: [tensor([0, 8]), tensor([ 4, 12])]
-    Replica#1: [tensor([ 5, 13]), tensor([9, 1])]
-    Replica#2: [tensor([2, 10])]
-    Replica#3: [tensor([11, 7])]
+    Replica#0: [tensor([3, 2]), tensor([0, 1])]
+    Replica#1: [tensor([6, 5]), tensor([7, 4])]
+    Replica#2: [tensor([8, 10])]
+    Replica#3: [tensor([14, 12])]

     6. shuffle = True, drop_last = True, drop_uneven_inputs = True.
@@ -726,10 +653,10 @@ class DistributedItemSampler(ItemSampler):
     >>> data_loader = gb.SingleProcessDataLoader(item_sampler)
     >>> print(f"Replica#{proc_id}: {list(data_loader)}")
     (One possible output:)
-    Replica#0: [tensor([8, 0])]
-    Replica#1: [tensor([1, 13])]
-    Replica#2: [tensor([10, 6])]
-    Replica#3: [tensor([3, 11])]
+    Replica#0: [tensor([1, 3])]
+    Replica#1: [tensor([7, 5])]
+    Replica#2: [tensor([11, 9])]
+    Replica#3: [tensor([13, 14])]
     """

     def __init__(
python/dgl/graphbolt/utils/__init__.py
@@ -2,3 +2,4 @@
 from .internal import *
 from .sample_utils import *
 from .datapipe_utils import *
+from .item_sampler_utils import *
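
Because `item_sampler_utils` is re-exported with a wildcard, `calculate_range` becomes reachable through the package-level `utils` namespace (the new test below calls it as `gb.utils.calculate_range`). A quick doctest-style check, assuming a dgl build that includes this commit:

    >>> from dgl import graphbolt as gb
    >>> # Non-distributed mode assigns the whole range to the single worker.
    >>> gb.utils.calculate_range(False, 10, 1, 0, 1, 0, 4, False, False)
    (0, 10, 10)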
python/dgl/graphbolt/utils/item_sampler_utils.py (new file, 0 → 100644)
"""Utility functions for DistributedItemSampler."""
def
count_split
(
total
,
num_workers
,
worker_id
,
batch_size
=
1
):
"""Calculate the number of assigned items after splitting them by batch
size evenly. It will return the number for this worker and also a sum of
previous workers.
"""
quotient
,
remainder
=
divmod
(
total
,
num_workers
*
batch_size
)
if
batch_size
==
1
:
assigned
=
quotient
+
(
worker_id
<
remainder
)
else
:
batch_count
,
last_batch
=
divmod
(
remainder
,
batch_size
)
assigned
=
quotient
*
batch_size
+
(
batch_size
if
worker_id
<
batch_count
else
(
last_batch
if
worker_id
==
batch_count
else
0
)
)
prefix_sum
=
quotient
*
worker_id
*
batch_size
+
min
(
worker_id
*
batch_size
,
remainder
)
return
(
assigned
,
prefix_sum
)
def
calculate_range
(
distributed
,
total
,
num_replicas
,
rank
,
num_workers
,
worker_id
,
batch_size
,
drop_last
,
drop_uneven_inputs
,
):
"""Calculates the range of items to be assigned to the current worker.
This function evenly distributes `total` items among multiple workers,
batching them using `batch_size`. Each replica has `num_workers` workers.
The batches generated by workers within the same replica are combined into
the replica`s output. The `drop_last` parameter determines whether
incomplete batches should be dropped. If `drop_last` is True, incomplete
batches are discarded. The `drop_uneven_inputs` parameter determines if the
number of batches assigned to each replica should be the same. If
`drop_uneven_inputs` is True, excessive batches for some replicas will be
dropped.
Args:
distributed (bool): Whether it's in distributed mode.
total (int): The total number of items.
num_replicas (int): The total number of replicas.
rank (int): The rank of the current replica.
num_workers (int): The number of workers per replica.
worker_id (int): The ID of the current worker.
batch_size (int): The desired batch size.
drop_last (bool): Whether to drop incomplete batches.
drop_uneven_inputs (bool): Whether to drop excessive batches for some
replicas.
Returns:
tuple: A tuple containing three numbers:
- start_offset (int): The starting offset of the range assigned to
the current worker.
- assigned_count (int): The length of the range assigned to the
current worker.
- output_count (int): The number of items that the current worker
will produce after dropping.
"""
# Check if it's distributed mode.
if
not
distributed
:
if
not
drop_last
:
return
(
0
,
total
,
total
)
else
:
return
(
0
,
total
,
total
//
batch_size
*
batch_size
)
# First, equally distribute items into all replicas.
assigned_count
,
start_offset
=
count_split
(
total
,
num_replicas
,
rank
,
batch_size
)
# Calculate the number of outputs when drop_uneven_inputs is True.
# `assigned_count` is the number of items distributed to the current
# process. `output_count` is the number of items should be output
# by this process after dropping.
if
not
drop_uneven_inputs
:
if
not
drop_last
:
output_count
=
assigned_count
else
:
output_count
=
assigned_count
//
batch_size
*
batch_size
else
:
if
not
drop_last
:
min_item_count
,
_
=
count_split
(
total
,
num_replicas
,
num_replicas
-
1
,
batch_size
)
min_batch_count
=
(
min_item_count
+
batch_size
-
1
)
//
batch_size
output_count
=
min
(
min_batch_count
*
batch_size
,
assigned_count
)
else
:
output_count
=
total
//
(
batch_size
*
num_replicas
)
*
batch_size
# If there are multiple workers, equally distribute the batches to
# all workers.
if
num_workers
>
1
:
# Equally distribute the dropped number too.
dropped_items
,
prev_dropped_items
=
count_split
(
assigned_count
-
output_count
,
num_workers
,
worker_id
)
output_count
,
prev_output_count
=
count_split
(
output_count
,
num_workers
,
worker_id
,
batch_size
,
)
assigned_count
=
output_count
+
dropped_items
start_offset
+=
prev_output_count
+
prev_dropped_items
return
(
start_offset
,
assigned_count
,
output_count
)
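
To illustrate the contract (the values below follow from the arithmetic above and agree with the parametrized test added in this commit), splitting 30 items across 4 replicas with a batch size of 4 leaves the last replica a short range, and `drop_last` then trims its output to whole batches:

    >>> from dgl.graphbolt.utils import calculate_range
    >>> # Rank 0 of 4, one worker per replica: 8 items starting at offset 0.
    >>> calculate_range(True, 30, 4, 0, 1, 0, 4, True, False)
    (0, 8, 8)
    >>> # Rank 3 is assigned the 6 leftover items but outputs only one batch.
    >>> calculate_range(True, 30, 4, 3, 1, 0, 4, True, False)
    (24, 6, 4)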
tests/python/pytorch/graphbolt/test_item_sampler.py
@@ -728,8 +728,8 @@ def distributed_item_sampler_subprocess(
     nprocs,
     item_set,
     num_ids,
+    num_workers,
     batch_size,
-    shuffle,
     drop_last,
     drop_uneven_inputs,
 ):
@@ -750,7 +750,7 @@ def distributed_item_sampler_subprocess(
     item_sampler = gb.DistributedItemSampler(
         item_set,
         batch_size=batch_size,
-        shuffle=shuffle,
+        shuffle=True,
         drop_last=drop_last,
         drop_uneven_inputs=drop_uneven_inputs,
     )
@@ -759,7 +759,9 @@ def distributed_item_sampler_subprocess(
         gb.BasicFeatureStore({}),
         [],
     )
-    data_loader = gb.SingleProcessDataLoader(feature_fetcher)
+    data_loader = gb.MultiProcessDataLoader(
+        feature_fetcher, num_workers=num_workers
+    )

     # Count the numbers of items and batches.
     num_items = 0
@@ -788,12 +790,104 @@ def distributed_item_sampler_subprocess(
         dist.destroy_process_group()


+@pytest.mark.parametrize(
+    "params",
+    [
+        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),
+        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),
+        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),
+        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),
+        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),
+        (
+            (53, 4, 2, 4, False, False),
+            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],
+        ),
+        (
+            (53, 4, 2, 4, True, False),
+            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
+        ),
+        (
+            (53, 4, 2, 4, False, True),
+            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
+        ),
+        (
+            (53, 4, 2, 4, True, True),
+            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
+        ),
+        (
+            (63, 4, 2, 4, False, False),
+            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
+        ),
+        (
+            (63, 4, 2, 4, True, False),
+            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],
+        ),
+        (
+            (63, 4, 2, 4, False, True),
+            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
+        ),
+        (
+            (63, 4, 2, 4, True, True),
+            [
+                (10, 8),
+                (6, 4),
+                (10, 8),
+                (6, 4),
+                (10, 8),
+                (6, 4),
+                (10, 8),
+                (5, 4),
+            ],
+        ),
+        (
+            (65, 4, 2, 4, False, False),
+            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
+        ),
+        (
+            (65, 4, 2, 4, True, True),
+            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
+        ),
+    ],
+)
+def test_RangeCalculation(params):
+    (
+        (
+            total,
+            num_replicas,
+            num_workers,
+            batch_size,
+            drop_last,
+            drop_uneven_inputs,
+        ),
+        key,
+    ) = params
+    answer = []
+    sum = 0
+    for rank in range(num_replicas):
+        for worker_id in range(max(num_workers, 1)):
+            result = gb.utils.calculate_range(
+                True,
+                total,
+                num_replicas,
+                rank,
+                num_workers,
+                worker_id,
+                batch_size,
+                drop_last,
+                drop_uneven_inputs,
+            )
+            assert sum == result[0]
+            sum += result[1]
+            answer.append((result[1], result[2]))
+    assert key == answer
+
+
 @pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
-@pytest.mark.parametrize("shuffle", [False, True])
+@pytest.mark.parametrize("num_workers", [0, 2])
 @pytest.mark.parametrize("drop_last", [False, True])
 @pytest.mark.parametrize("drop_uneven_inputs", [False, True])
 def test_DistributedItemSampler(
-    num_ids, shuffle, drop_last, drop_uneven_inputs
+    num_ids, num_workers, drop_last, drop_uneven_inputs
 ):
     nprocs = 4
     batch_size = 4
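
Each expected list in `test_RangeCalculation` holds one `(assigned_count, output_count)` pair per (rank, worker) in iteration order, and the running-sum assert checks that the ranges tile the item set exactly. Working one row by hand, `(53, 4, 2, 4, True, False)`: whole-batch splitting of 53 items gives ranks 0 to 3 ranges of 16, 13, 12, and 12 items; `drop_last` trims rank 1's output to 12 of its 13, and redistributing rank 1's range over its two workers yields `(9, 8)` and `(4, 4)`, the third and fourth pairs of that row. The same can be checked directly (doctest-style sketch, same assumptions as above):

    >>> from dgl.graphbolt.utils import calculate_range
    >>> calculate_range(True, 53, 4, 1, 2, 0, 4, True, False)  # rank 1, worker 0
    (16, 9, 8)
    >>> calculate_range(True, 53, 4, 1, 2, 1, 4, True, False)  # rank 1, worker 1
    (25, 4, 4)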
@@ -813,8 +907,8 @@ def test_DistributedItemSampler(
             nprocs,
             item_set,
             num_ids,
+            num_workers,
             batch_size,
-            shuffle,
             drop_last,
             drop_uneven_inputs,
         ),