Commit 681c6c11 authored by Rahul Somani, committed by Francisco Massa

Refactored clip_sampler (#1562)

parent 95131de3
@@ -2,7 +2,7 @@ import math
 import torch
 from torch.utils.data import Sampler
 import torch.distributed as dist
-import torchvision.datasets.video_utils
+from torchvision.datasets.video_utils import VideoClips


 class DistributedSampler(Sampler):
@@ -96,7 +96,7 @@ class DistributedSampler(Sampler):
         self.epoch = epoch


-class UniformClipSampler(torch.utils.data.Sampler):
+class UniformClipSampler(Sampler):
     """
     Sample `num_video_clips_per_video` clips for each video, equally spaced.
     When number of unique clips in the video is fewer than num_video_clips_per_video,
@@ -107,7 +107,7 @@ class UniformClipSampler(torch.utils.data.Sampler):
         num_clips_per_video (int): number of clips to be sampled per video
     """
     def __init__(self, video_clips, num_clips_per_video):
-        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
+        if not isinstance(video_clips, VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
@@ -139,7 +139,7 @@ class UniformClipSampler(torch.utils.data.Sampler):
         )


-class RandomClipSampler(torch.utils.data.Sampler):
+class RandomClipSampler(Sampler):
     """
     Samples at most `max_video_clips_per_video` clips for each video randomly
@@ -148,7 +148,7 @@ class RandomClipSampler(torch.utils.data.Sampler):
         max_clips_per_video (int): maximum number of clips to be sampled per video
     """
     def __init__(self, video_clips, max_clips_per_video):
-        if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
+        if not isinstance(video_clips, VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
...
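For context (not part of the diff): a minimal sketch of how the refactored samplers are typically constructed from a VideoClips index. The file paths and clip parameters below are placeholders; in practice the VideoClips object usually comes from a video dataset (e.g. Kinetics400 exposes it as dataset.video_clips).

from torchvision.datasets.video_utils import VideoClips
from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler

# Placeholder video files and clip parameters -- substitute real paths.
video_paths = ["videos/a.mp4", "videos/b.mp4"]
video_clips = VideoClips(video_paths,
                         clip_length_in_frames=16,
                         frames_between_clips=16)

# Training: at most 5 clips drawn at random from each video.
train_sampler = RandomClipSampler(video_clips, max_clips_per_video=5)

# Evaluation: exactly 10 equally spaced clips from each video.
eval_sampler = UniformClipSampler(video_clips, num_clips_per_video=10)

# Both samplers yield indices into the flattened clip list, so they are
# passed to a DataLoader wrapping a clip-level dataset, e.g.
# DataLoader(dataset, batch_size=..., sampler=train_sampler)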