Unverified Commit 6662b30a authored by Philip Meier, committed by GitHub

add typehints for .datasets.samplers (#2667)

parent f8bf06d5
@@ -3,6 +3,7 @@ import torch
 from torch.utils.data import Sampler
 import torch.distributed as dist
 from torchvision.datasets.video_utils import VideoClips
+from typing import Optional, List, Iterator, Sized, Union, cast


 class DistributedSampler(Sampler):
@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
     """

-    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1):
+    def __init__(
+        self,
+        dataset: Sized,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = False,
+        group_size: int = 1,
+    ) -> None:
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
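Note the `dataset: Sized` annotation: judging from the annotated signature, the sampler only ever takes `len(dataset)`, so any object with `__len__` qualifies. A small sketch of that contract (`ToyClipDataset` and `samples_per_replica` are hypothetical names for illustration, not part of torchvision):

from typing import Sized

class ToyClipDataset:
    """Hypothetical stand-in: implementing __len__ is enough to satisfy Sized."""
    def __len__(self) -> int:
        return 100

def samples_per_replica(dataset: Sized, num_replicas: int) -> int:
    # len() is the only operation the annotated __init__ requires of `dataset`.
    return -(-len(dataset) // num_replicas)  # ceiling division

print(samples_per_replica(ToyClipDataset(), 8))  # 13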
@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle

-    def __iter__(self):
+    def __iter__(self) -> Iterator[int]:
         # deterministically shuffle based on epoch
         g = torch.Generator()
         g.manual_seed(self.epoch)
+        indices: Union[torch.Tensor, List[int]]
         if self.shuffle:
             indices = torch.randperm(len(self.dataset), generator=g).tolist()
         else:
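The bare annotation `indices: Union[torch.Tensor, List[int]]` exists for the type checker: the name is bound to a `List[int]` in the branch shown here and to a `Tensor` in code further down, outside this hunk, and mypy infers one type per variable unless a union is declared up front. A minimal sketch of the pattern, detached from the sampler (the `torch.arange` branch is illustrative, standing in for the tensor-typed assignments not shown):

import torch
from typing import List, Union

def pick_indices(shuffle: bool, n: int) -> Union[torch.Tensor, List[int]]:
    # Declaring the union up front lets each branch bind a different
    # concrete type to the same name without mypy reporting an error.
    indices: Union[torch.Tensor, List[int]]
    if shuffle:
        indices = torch.randperm(n).tolist()  # List[int]
    else:
        indices = torch.arange(n)  # Tensor
    return indices

print(pick_indices(True, 4))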
@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
         return iter(indices)

-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_samples

-    def set_epoch(self, epoch):
+    def set_epoch(self, epoch: int) -> None:
         self.epoch = epoch
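`set_epoch` is the one mutating hook on the sampler: `__iter__` seeds its generator with `self.epoch`, so calling it once per epoch yields a different, replica-consistent shuffle each time. A sketch of the usual loop, assuming a toy dataset; passing `num_replicas`/`rank` explicitly sidesteps the need for an initialized `torch.distributed` process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets.samplers import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the per-epoch shuffle; identical on every replica
    for (batch,) in DataLoader(dataset, batch_size=2, sampler=sampler):
        print(epoch, batch.tolist())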
@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
         video_clips (VideoClips): video clips to sample from
         num_clips_per_video (int): number of clips to be sampled per video
     """
-    def __init__(self, video_clips, num_clips_per_video):
+    def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
         if not isinstance(video_clips, VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
         self.num_clips_per_video = num_clips_per_video

-    def __iter__(self):
+    def __iter__(self) -> Iterator[int]:
         idxs = []
         s = 0
         # select num_clips_per_video for each video, uniformly spaced
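For context on the "uniformly spaced" selection: the middle of this `__iter__` falls between hunks, but the closing parenthesis visible in the next hunk suggests torch's linspace-floor idiom. A standalone sketch of that idiom; the exact expression is reconstructed and may differ cosmetically from the source:

import torch

# 3 uniformly spaced clip indices from a 10-clip video starting at global offset s.
length, num_clips_per_video, s = 10, 3, 0
sampled = (
    torch.linspace(s, s + length - 1, steps=num_clips_per_video)
    .floor()
    .to(torch.int64)
)
print(sampled)  # tensor([0, 4, 9])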
@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
             )
             s += length
             idxs.append(sampled)
-        idxs = torch.cat(idxs).tolist()
-        return iter(idxs)
+        return iter(cast(List[int], torch.cat(idxs).tolist()))

-    def __len__(self):
+    def __len__(self) -> int:
         return sum(
             self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
         )
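The `cast(List[int], ...)` has no runtime effect; it is there because the stubs for `Tensor.tolist()` cannot express the element type (the result depends on the tensor's dtype and rank), while `__iter__` now promises `Iterator[int]`. A minimal sketch:

import torch
from typing import List, cast

idxs = [torch.tensor([0, 1]), torch.tensor([5, 6])]
flat = cast(List[int], torch.cat(idxs).tolist())  # no-op at runtime, informs mypy
assert flat == [0, 1, 5, 6]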
@@ -147,14 +155,14 @@
         video_clips (VideoClips): video clips to sample from
         max_clips_per_video (int): maximum number of clips to be sampled per video
     """
-    def __init__(self, video_clips, max_clips_per_video):
+    def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
         if not isinstance(video_clips, VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
         self.max_clips_per_video = max_clips_per_video

-    def __iter__(self):
+    def __iter__(self) -> Iterator[int]:
         idxs = []
         s = 0
         # select at most max_clips_per_video for each video, randomly
@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
             sampled = torch.randperm(length)[:size] + s
             s += length
             idxs.append(sampled)
-        idxs = torch.cat(idxs)
+        idxs_ = torch.cat(idxs)
         # shuffle all clips randomly
-        perm = torch.randperm(len(idxs))
-        idxs = idxs[perm].tolist()
-        return iter(idxs)
+        perm = torch.randperm(len(idxs_))
+        return iter(idxs_[perm].tolist())

-    def __len__(self):
+    def __len__(self) -> int:
         return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
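The `idxs` to `idxs_` rename in the hunk above is purely for the type checker, not a behavior change: `idxs` starts life as a list of tensors, and re-binding the same name to the concatenated `Tensor` (and previously to a `List[int]`) gives one variable several types, which mypy rejects. Giving each intermediate value its own name keeps every variable single-typed; a condensed sketch:

import torch
from typing import List

idxs: List[torch.Tensor] = [torch.tensor([3, 1]), torch.tensor([2, 0])]

# `idxs = torch.cat(idxs)` would re-bind a List[Tensor] name to a Tensor,
# which mypy flags; a fresh name keeps both types consistent.
idxs_ = torch.cat(idxs)
perm = torch.randperm(len(idxs_))
print(idxs_[perm].tolist())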