Unverified Commit 6662b30a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for .datasets.samplers (#2667)

parent f8bf06d5
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from torch.utils.data import Sampler from torch.utils.data import Sampler
import torch.distributed as dist import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips from torchvision.datasets.video_utils import VideoClips
from typing import Optional, List, Iterator, Sized, Union, cast
class DistributedSampler(Sampler): class DistributedSampler(Sampler):
...@@ -34,7 +35,14 @@ class DistributedSampler(Sampler): ...@@ -34,7 +35,14 @@ class DistributedSampler(Sampler):
""" """
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, group_size=1): def __init__(
self,
dataset: Sized,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = False,
group_size: int = 1,
) -> None:
if num_replicas is None: if num_replicas is None:
if not dist.is_available(): if not dist.is_available():
raise RuntimeError("Requires distributed package to be available") raise RuntimeError("Requires distributed package to be available")
...@@ -60,10 +68,11 @@ class DistributedSampler(Sampler): ...@@ -60,10 +68,11 @@ class DistributedSampler(Sampler):
self.total_size = self.num_samples * self.num_replicas self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle self.shuffle = shuffle
def __iter__(self): def __iter__(self) -> Iterator[int]:
# deterministically shuffle based on epoch # deterministically shuffle based on epoch
g = torch.Generator() g = torch.Generator()
g.manual_seed(self.epoch) g.manual_seed(self.epoch)
indices: Union[torch.Tensor, List[int]]
if self.shuffle: if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist() indices = torch.randperm(len(self.dataset), generator=g).tolist()
else: else:
...@@ -89,10 +98,10 @@ class DistributedSampler(Sampler): ...@@ -89,10 +98,10 @@ class DistributedSampler(Sampler):
return iter(indices) return iter(indices)
def __len__(self) -> int:
    """Return how many indices this rank yields per epoch."""
    # Per-replica sample count, computed once in __init__.
    sample_count = self.num_samples
    return sample_count
def set_epoch(self, epoch: int) -> None:
    """Record the current epoch number.

    The stored value is used as the RNG seed when shuffling in
    ``__iter__``, so calling this once per epoch makes each epoch's
    shuffle different while staying identical across replicas.
    """
    self.epoch = epoch
...@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler): ...@@ -106,14 +115,14 @@ class UniformClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from video_clips (VideoClips): video clips to sample from
num_clips_per_video (int): number of clips to be sampled per video num_clips_per_video (int): number of clips to be sampled per video
""" """
def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None:
    """Store the clip source and the per-video sampling budget.

    Args:
        video_clips (VideoClips): video clips to sample from.
        num_clips_per_video (int): number of clips sampled per video.

    Raises:
        TypeError: if ``video_clips`` is not a ``VideoClips`` instance.
    """
    # Fail fast on anything that is not a VideoClips container.
    if not isinstance(video_clips, VideoClips):
        raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
    self.video_clips = video_clips
    self.num_clips_per_video = num_clips_per_video
def __iter__(self): def __iter__(self) -> Iterator[int]:
idxs = [] idxs = []
s = 0 s = 0
# select num_clips_per_video for each video, uniformly spaced # select num_clips_per_video for each video, uniformly spaced
...@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler): ...@@ -130,10 +139,9 @@ class UniformClipSampler(Sampler):
) )
s += length s += length
idxs.append(sampled) idxs.append(sampled)
idxs = torch.cat(idxs).tolist() return iter(cast(List[int], torch.cat(idxs).tolist()))
return iter(idxs)
def __len__(self) -> int:
    """Total clips yielded: ``num_clips_per_video`` for each non-empty video."""
    # Count videos that actually contain clips, then scale by the
    # fixed per-video budget — identical to summing the budget per video.
    non_empty_videos = sum(1 for c in self.video_clips.clips if len(c) > 0)
    return self.num_clips_per_video * non_empty_videos
...@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler): ...@@ -147,14 +155,14 @@ class RandomClipSampler(Sampler):
video_clips (VideoClips): video clips to sample from video_clips (VideoClips): video clips to sample from
max_clips_per_video (int): maximum number of clips to be sampled per video max_clips_per_video (int): maximum number of clips to be sampled per video
""" """
def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None:
    """Store the clip source and the per-video sampling cap.

    Args:
        video_clips (VideoClips): video clips to sample from.
        max_clips_per_video (int): maximum number of clips sampled per video.

    Raises:
        TypeError: if ``video_clips`` is not a ``VideoClips`` instance.
    """
    # Fail fast on anything that is not a VideoClips container.
    if not isinstance(video_clips, VideoClips):
        raise TypeError(f"Expected video_clips to be an instance of VideoClips, got {type(video_clips)}")
    self.video_clips = video_clips
    self.max_clips_per_video = max_clips_per_video
def __iter__(self): def __iter__(self) -> Iterator[int]:
idxs = [] idxs = []
s = 0 s = 0
# select at most max_clips_per_video for each video, randomly # select at most max_clips_per_video for each video, randomly
...@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler): ...@@ -164,11 +172,10 @@ class RandomClipSampler(Sampler):
sampled = torch.randperm(length)[:size] + s sampled = torch.randperm(length)[:size] + s
s += length s += length
idxs.append(sampled) idxs.append(sampled)
idxs = torch.cat(idxs) idxs_ = torch.cat(idxs)
# shuffle all clips randomly # shuffle all clips randomly
perm = torch.randperm(len(idxs)) perm = torch.randperm(len(idxs_))
idxs = idxs[perm].tolist() return iter(idxs_[perm].tolist())
return iter(idxs)
def __len__(self) -> int:
    """Total clips yielded: each video contributes at most ``max_clips_per_video``."""
    # Cap every video's clip count at the configured maximum, then total.
    capped_counts = (min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
    return sum(capped_counts)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment