Commit 355e9d2f
authored Oct 22, 2019 by Zhicheng Yan
committed by Francisco Massa, Oct 22, 2019

extend DistributedSampler to support group_size (#1512)

* extend DistributedSampler to support group_size
* Fix lint

parent b60cb726
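In short, the commit lets torchvision's video-specific DistributedSampler hand out indices in contiguous groups rather than one by one, which matters when consecutive indices belong together (e.g. all clips sampled from one video). A hypothetical usage sketch follows; the placeholder video_paths, the explicit num_replicas/rank values, and the choice of 3 clips per video are illustrative assumptions, not part of the commit:

    # Hypothetical usage sketch (not from this commit); VideoClips needs
    # real video files behind these placeholder paths to index clips.
    from torchvision.datasets.video_utils import VideoClips
    from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler

    video_paths = ["v0.mp4", "v1.mp4", "v2.mp4"]  # placeholder paths
    video_clips = VideoClips(video_paths, clip_length_in_frames=5, frames_between_clips=5)
    clip_sampler = UniformClipSampler(video_clips, 3)  # 3 clips per video

    # group_size=3 matches clips-per-video, so no video's clips are split
    # across replicas when the index list is sharded.
    sampler = DistributedSampler(clip_sampler, num_replicas=2, rank=0, group_size=3)
    clip_indices = list(iter(sampler))  # clip indices assigned to rank 0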
Showing 2 changed files with 71 additions and 4 deletions

  test/test_datasets_samplers.py                 +30 -1
  torchvision/datasets/samplers/clip_sampler.py  +41 -3
test/test_datasets_samplers.py

@@ -5,7 +5,11 @@ import torch
 import unittest
 
 from torchvision import io
-from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
+from torchvision.datasets.samplers import (
+    DistributedSampler,
+    RandomClipSampler,
+    UniformClipSampler,
+)
 from torchvision.datasets.video_utils import VideoClips, unfold
 from torchvision import get_video_backend
@@ -83,6 +87,31 @@ class Tester(unittest.TestCase):
             indices = torch.tensor(list(iter(sampler)))
             self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))
 
+    def test_distributed_sampler_and_uniform_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            clip_sampler = UniformClipSampler(video_clips, 3)
+
+            distributed_sampler_rank0 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=0,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
+            self.assertEqual(len(distributed_sampler_rank0), 6)
+            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 10, 12, 14])))
+
+            distributed_sampler_rank1 = DistributedSampler(
+                clip_sampler,
+                num_replicas=2,
+                rank=1,
+                group_size=3,
+            )
+            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
+            self.assertEqual(len(distributed_sampler_rank1), 6)
+            self.assertTrue(indices.equal(torch.tensor([5, 7, 9, 0, 2, 4])))
+
 
 if __name__ == '__main__':
     unittest.main()
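For reference, here is my reading of where the expected tensors come from (a trace, not part of the diff). Three 25-frame videos cut into 5-frame clips at stride 5 give 5 clips per video, 15 in total; UniformClipSampler(video_clips, 3) picks 3 evenly spaced clips per video, i.e. clip indices [0, 2, 4, 5, 7, 9, 10, 12, 14]. The distributed sampler then shards positions into that list:

    # Illustrative trace of the expected rank-0 values (assumptions as above).
    orig = [0, 2, 4, 5, 7, 9, 10, 12, 14]         # UniformClipSampler output
    positions = list(range(len(orig)))             # positions 0..8
    positions += positions[:12 - len(positions)]   # pad to total_size=12 by wrapping
    groups = [positions[i:i + 3] for i in range(0, 12, 3)]  # group_size=3 -> 4 groups
    rank0 = [p for g in groups[0::2] for p in g]   # rank 0 takes every 2nd group
    print([orig[p] for p in rank0])                # [0, 2, 4, 10, 12, 14]
    # rank 1 takes groups[1::2] -> positions [3, 4, 5, 0, 1, 2] -> [5, 7, 9, 0, 2, 4]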
torchvision/datasets/samplers/clip_sampler.py

@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
     """
     Extension of DistributedSampler, as discussed in
     https://github.com/pytorch/pytorch/issues/23430
+
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+        when group_size = 1
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 4, 8, 12]
+            rank_1  |  [1, 5, 9, 13]
+            rank_2  |  [2, 6, 10, 0]
+            rank_3  |  [3, 7, 11, 1]
+
+        when group_size = 2
+
+            RANK    |  shard_dataset
+            =========================
+            rank_0  |  [0, 1, 8, 9]
+            rank_1  |  [2, 3, 10, 11]
+            rank_2  |  [4, 5, 12, 13]
+            rank_3  |  [6, 7, 0, 1]
     """
 
-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
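The shard tables in the docstring can be reproduced with a few lines of plain Python. The sketch below is an illustration of the grouping scheme, not the implementation itself (names like groups_per_rank are mine):

    # Reproduces the group_size=2 table from the docstring (illustration only).
    dataset = list(range(14))
    num_replicas, group_size = 4, 2

    groups = [dataset[i:i + group_size] for i in range(0, len(dataset), group_size)]
    groups_per_rank = -(-len(groups) // num_replicas)   # ceil(7 / 4) = 2
    total_groups = groups_per_rank * num_replicas       # 8: pad by wrapping around
    groups += groups[:total_groups - len(groups)]

    for rank in range(num_replicas):
        shard = [x for g in groups[rank::num_replicas] for x in g]
        print("rank_%d | %s" % (rank, shard))
    # rank_0 | [0, 1, 8, 9]
    # rank_1 | [2, 3, 10, 11]
    # rank_2 | [4, 5, 12, 13]
    # rank_3 | [6, 7, 0, 1]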
@@ -20,11 +43,20 @@ class DistributedSampler(Sampler):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiple of group size, "
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
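Worked through with the new test's numbers, this bookkeeping explains the len(sampler) == 6 assertion (my arithmetic, not code from the commit):

    # Length bookkeeping from __init__ with the test's numbers (illustrative).
    import math
    len_dataset, group_size, num_replicas = 9, 3, 2
    dataset_group_length = len_dataset // group_size           # 3 groups
    num_group_samples = int(math.ceil(dataset_group_length * 1.0 / num_replicas))  # 2
    num_samples = num_group_samples * group_size               # 6 indices per rank
    total_size = num_samples * num_replicas                    # 12 after padding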
@@ -41,8 +73,14 @@ class DistributedSampler(Sampler):
             indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size
 
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
         assert len(indices) == self.num_samples
 
         if isinstance(self.dataset, Sampler):
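The hunk's last context line, `if isinstance(self.dataset, Sampler):`, is where the positions computed above are remapped when the wrapped dataset is itself a sampler (as in the new test, which wraps a UniformClipSampler). The body of that branch lies outside this diff; my understanding of it, stated as an assumption, is roughly:

    # Assumed context following the hunk (not shown in this diff): the
    # subsampled positions index into the inner sampler's own output.
    if isinstance(self.dataset, Sampler):
        orig_indices = list(iter(self.dataset))
        indices = [orig_indices[i] for i in indices]
    return iter(indices)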