Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6662b30a
Unverified
Commit
6662b30a
authored
Sep 14, 2020
by
Philip Meier
Committed by
GitHub
Sep 14, 2020
Browse files
add typehints for .datasets.samplers (#2667)
parent
f8bf06d5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
16 deletions
+23
-16
torchvision/datasets/samplers/clip_sampler.py
torchvision/datasets/samplers/clip_sampler.py
+23
-16
No files found.
torchvision/datasets/samplers/clip_sampler.py
View file @
6662b30a
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
from
torch.utils.data
import
Sampler
from
torch.utils.data
import
Sampler
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torchvision.datasets.video_utils
import
VideoClips
from
torchvision.datasets.video_utils
import
VideoClips
from
typing
import
Optional
,
List
,
Iterator
,
Sized
,
Union
,
cast
class
DistributedSampler
(
Sampler
):
class
DistributedSampler
(
Sampler
):
...
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
...
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
"""
"""
def
__init__
(
self
,
dataset
,
num_replicas
=
None
,
rank
=
None
,
shuffle
=
False
,
group_size
=
1
):
def
__init__
(
self
,
dataset
:
Sized
,
num_replicas
:
Optional
[
int
]
=
None
,
rank
:
Optional
[
int
]
=
None
,
shuffle
:
bool
=
False
,
group_size
:
int
=
1
,
)
->
None
:
if
num_replicas
is
None
:
if
num_replicas
is
None
:
if
not
dist
.
is_available
():
if
not
dist
.
is_available
():
raise
RuntimeError
(
"Requires distributed package to be available"
)
raise
RuntimeError
(
"Requires distributed package to be available"
)
...
@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
...
@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
# deterministically shuffle based on epoch
# deterministically shuffle based on epoch
g
=
torch
.
Generator
()
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
g
.
manual_seed
(
self
.
epoch
)
indices
:
Union
[
torch
.
Tensor
,
List
[
int
]]
if
self
.
shuffle
:
if
self
.
shuffle
:
indices
=
torch
.
randperm
(
len
(
self
.
dataset
),
generator
=
g
).
tolist
()
indices
=
torch
.
randperm
(
len
(
self
.
dataset
),
generator
=
g
).
tolist
()
else
:
else
:
...
@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
...
@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
return
iter
(
indices
)
return
iter
(
indices
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
self
.
num_samples
return
self
.
num_samples
def
set_epoch
(
self
,
epoch
)
:
def
set_epoch
(
self
,
epoch
:
int
)
->
None
:
self
.
epoch
=
epoch
self
.
epoch
=
epoch
...
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
...
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
video_clips (VideoClips): video clips to sample from
num_clips_per_video (int): number of clips to be sampled per video
num_clips_per_video (int): number of clips to be sampled per video
"""
"""
def
__init__
(
self
,
video_clips
,
num_clips_per_video
)
:
def
__init__
(
self
,
video_clips
:
VideoClips
,
num_clips_per_video
:
int
)
->
None
:
if
not
isinstance
(
video_clips
,
VideoClips
):
if
not
isinstance
(
video_clips
,
VideoClips
):
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
"got {}"
.
format
(
type
(
video_clips
)))
"got {}"
.
format
(
type
(
video_clips
)))
self
.
video_clips
=
video_clips
self
.
video_clips
=
video_clips
self
.
num_clips_per_video
=
num_clips_per_video
self
.
num_clips_per_video
=
num_clips_per_video
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
idxs
=
[]
idxs
=
[]
s
=
0
s
=
0
# select num_clips_per_video for each video, uniformly spaced
# select num_clips_per_video for each video, uniformly spaced
...
@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
...
@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
)
)
s
+=
length
s
+=
length
idxs
.
append
(
sampled
)
idxs
.
append
(
sampled
)
idxs
=
torch
.
cat
(
idxs
).
tolist
()
return
iter
(
cast
(
List
[
int
],
torch
.
cat
(
idxs
).
tolist
()))
return
iter
(
idxs
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
sum
(
return
sum
(
self
.
num_clips_per_video
for
c
in
self
.
video_clips
.
clips
if
len
(
c
)
>
0
self
.
num_clips_per_video
for
c
in
self
.
video_clips
.
clips
if
len
(
c
)
>
0
)
)
...
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
...
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video
max_clips_per_video (int): maximum number of clips to be sampled per video
"""
"""
def
__init__
(
self
,
video_clips
,
max_clips_per_video
)
:
def
__init__
(
self
,
video_clips
:
VideoClips
,
max_clips_per_video
:
int
)
->
None
:
if
not
isinstance
(
video_clips
,
VideoClips
):
if
not
isinstance
(
video_clips
,
VideoClips
):
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
"got {}"
.
format
(
type
(
video_clips
)))
"got {}"
.
format
(
type
(
video_clips
)))
self
.
video_clips
=
video_clips
self
.
video_clips
=
video_clips
self
.
max_clips_per_video
=
max_clips_per_video
self
.
max_clips_per_video
=
max_clips_per_video
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
idxs
=
[]
idxs
=
[]
s
=
0
s
=
0
# select at most max_clips_per_video for each video, randomly
# select at most max_clips_per_video for each video, randomly
...
@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
...
@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
sampled
=
torch
.
randperm
(
length
)[:
size
]
+
s
sampled
=
torch
.
randperm
(
length
)[:
size
]
+
s
s
+=
length
s
+=
length
idxs
.
append
(
sampled
)
idxs
.
append
(
sampled
)
idxs
=
torch
.
cat
(
idxs
)
idxs
_
=
torch
.
cat
(
idxs
)
# shuffle all clips randomly
# shuffle all clips randomly
perm
=
torch
.
randperm
(
len
(
idxs
))
perm
=
torch
.
randperm
(
len
(
idxs_
))
idxs
=
idxs
[
perm
].
tolist
()
return
iter
(
idxs_
[
perm
].
tolist
())
return
iter
(
idxs
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
sum
(
min
(
len
(
c
),
self
.
max_clips_per_video
)
for
c
in
self
.
video_clips
.
clips
)
return
sum
(
min
(
len
(
c
),
self
.
max_clips_per_video
)
for
c
in
self
.
video_clips
.
clips
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment