Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6662b30a
"torchvision/vscode:/vscode.git/clone" did not exist on "05dcf50ae490c78dc75544fb23922eb3a95b6ac6"
Unverified
Commit
6662b30a
authored
Sep 14, 2020
by
Philip Meier
Committed by
GitHub
Sep 14, 2020
Browse files
add typehints for .datasets.samplers (#2667)
parent
f8bf06d5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
16 deletions
+23
-16
torchvision/datasets/samplers/clip_sampler.py
torchvision/datasets/samplers/clip_sampler.py
+23
-16
No files found.
torchvision/datasets/samplers/clip_sampler.py
View file @
6662b30a
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
from
torch.utils.data
import
Sampler
from
torch.utils.data
import
Sampler
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torchvision.datasets.video_utils
import
VideoClips
from
torchvision.datasets.video_utils
import
VideoClips
from
typing
import
Optional
,
List
,
Iterator
,
Sized
,
Union
,
cast
class
DistributedSampler
(
Sampler
):
class
DistributedSampler
(
Sampler
):
...
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
...
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
"""
"""
def
__init__
(
self
,
dataset
,
num_replicas
=
None
,
rank
=
None
,
shuffle
=
False
,
group_size
=
1
):
def
__init__
(
self
,
dataset
:
Sized
,
num_replicas
:
Optional
[
int
]
=
None
,
rank
:
Optional
[
int
]
=
None
,
shuffle
:
bool
=
False
,
group_size
:
int
=
1
,
)
->
None
:
if
num_replicas
is
None
:
if
num_replicas
is
None
:
if
not
dist
.
is_available
():
if
not
dist
.
is_available
():
raise
RuntimeError
(
"Requires distributed package to be available"
)
raise
RuntimeError
(
"Requires distributed package to be available"
)
...
@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
...
@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
self
.
total_size
=
self
.
num_samples
*
self
.
num_replicas
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
# deterministically shuffle based on epoch
# deterministically shuffle based on epoch
g
=
torch
.
Generator
()
g
=
torch
.
Generator
()
g
.
manual_seed
(
self
.
epoch
)
g
.
manual_seed
(
self
.
epoch
)
indices
:
Union
[
torch
.
Tensor
,
List
[
int
]]
if
self
.
shuffle
:
if
self
.
shuffle
:
indices
=
torch
.
randperm
(
len
(
self
.
dataset
),
generator
=
g
).
tolist
()
indices
=
torch
.
randperm
(
len
(
self
.
dataset
),
generator
=
g
).
tolist
()
else
:
else
:
...
@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
...
@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
return
iter
(
indices
)
return
iter
(
indices
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
self
.
num_samples
return
self
.
num_samples
def
set_epoch
(
self
,
epoch
)
:
def
set_epoch
(
self
,
epoch
:
int
)
->
None
:
self
.
epoch
=
epoch
self
.
epoch
=
epoch
...
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
...
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
video_clips (VideoClips): video clips to sample from
num_clips_per_video (int): number of clips to be sampled per video
num_clips_per_video (int): number of clips to be sampled per video
"""
"""
def
__init__
(
self
,
video_clips
,
num_clips_per_video
)
:
def
__init__
(
self
,
video_clips
:
VideoClips
,
num_clips_per_video
:
int
)
->
None
:
if
not
isinstance
(
video_clips
,
VideoClips
):
if
not
isinstance
(
video_clips
,
VideoClips
):
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
"got {}"
.
format
(
type
(
video_clips
)))
"got {}"
.
format
(
type
(
video_clips
)))
self
.
video_clips
=
video_clips
self
.
video_clips
=
video_clips
self
.
num_clips_per_video
=
num_clips_per_video
self
.
num_clips_per_video
=
num_clips_per_video
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
idxs
=
[]
idxs
=
[]
s
=
0
s
=
0
# select num_clips_per_video for each video, uniformly spaced
# select num_clips_per_video for each video, uniformly spaced
...
@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
...
@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
)
)
s
+=
length
s
+=
length
idxs
.
append
(
sampled
)
idxs
.
append
(
sampled
)
idxs
=
torch
.
cat
(
idxs
).
tolist
()
return
iter
(
cast
(
List
[
int
],
torch
.
cat
(
idxs
).
tolist
()))
return
iter
(
idxs
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
sum
(
return
sum
(
self
.
num_clips_per_video
for
c
in
self
.
video_clips
.
clips
if
len
(
c
)
>
0
self
.
num_clips_per_video
for
c
in
self
.
video_clips
.
clips
if
len
(
c
)
>
0
)
)
...
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
...
@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from
video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video
max_clips_per_video (int): maximum number of clips to be sampled per video
"""
"""
def
__init__
(
self
,
video_clips
,
max_clips_per_video
)
:
def
__init__
(
self
,
video_clips
:
VideoClips
,
max_clips_per_video
:
int
)
->
None
:
if
not
isinstance
(
video_clips
,
VideoClips
):
if
not
isinstance
(
video_clips
,
VideoClips
):
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
raise
TypeError
(
"Expected video_clips to be an instance of VideoClips, "
"got {}"
.
format
(
type
(
video_clips
)))
"got {}"
.
format
(
type
(
video_clips
)))
self
.
video_clips
=
video_clips
self
.
video_clips
=
video_clips
self
.
max_clips_per_video
=
max_clips_per_video
self
.
max_clips_per_video
=
max_clips_per_video
def
__iter__
(
self
):
def
__iter__
(
self
)
->
Iterator
[
int
]
:
idxs
=
[]
idxs
=
[]
s
=
0
s
=
0
# select at most max_clips_per_video for each video, randomly
# select at most max_clips_per_video for each video, randomly
...
@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
...
@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
sampled
=
torch
.
randperm
(
length
)[:
size
]
+
s
sampled
=
torch
.
randperm
(
length
)[:
size
]
+
s
s
+=
length
s
+=
length
idxs
.
append
(
sampled
)
idxs
.
append
(
sampled
)
idxs
=
torch
.
cat
(
idxs
)
idxs
_
=
torch
.
cat
(
idxs
)
# shuffle all clips randomly
# shuffle all clips randomly
perm
=
torch
.
randperm
(
len
(
idxs
))
perm
=
torch
.
randperm
(
len
(
idxs_
))
idxs
=
idxs
[
perm
].
tolist
()
return
iter
(
idxs_
[
perm
].
tolist
())
return
iter
(
idxs
)
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
sum
(
min
(
len
(
c
),
self
.
max_clips_per_video
)
for
c
in
self
.
video_clips
.
clips
)
return
sum
(
min
(
len
(
c
),
self
.
max_clips_per_video
)
for
c
in
self
.
video_clips
.
clips
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment