Commit 95131de3 authored by Henry Xia, committed by Francisco Massa
Browse files

expose audio_channels as a parameter to kinetics dataset (#1559)

parent be6f398c
...@@ -39,7 +39,7 @@ class Kinetics400(VisionDataset): ...@@ -39,7 +39,7 @@ class Kinetics400(VisionDataset):
def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None, def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
extensions=('avi',), transform=None, _precomputed_metadata=None, extensions=('avi',), transform=None, _precomputed_metadata=None,
num_workers=1, _video_width=0, _video_height=0, num_workers=1, _video_width=0, _video_height=0,
_video_min_dimension=0, _audio_samples=0): _video_min_dimension=0, _audio_samples=0, _audio_channels=0):
super(Kinetics400, self).__init__(root) super(Kinetics400, self).__init__(root)
classes = list(sorted(list_dir(root))) classes = list(sorted(list_dir(root)))
...@@ -58,6 +58,7 @@ class Kinetics400(VisionDataset): ...@@ -58,6 +58,7 @@ class Kinetics400(VisionDataset):
_video_height=_video_height, _video_height=_video_height,
_video_min_dimension=_video_min_dimension, _video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples, _audio_samples=_audio_samples,
_audio_channels=_audio_channels,
) )
self.transform = transform self.transform = transform
......
...@@ -71,7 +71,7 @@ class VideoClips(object): ...@@ -71,7 +71,7 @@ class VideoClips(object):
def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1, def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
frame_rate=None, _precomputed_metadata=None, num_workers=0, frame_rate=None, _precomputed_metadata=None, num_workers=0,
_video_width=0, _video_height=0, _video_min_dimension=0, _video_width=0, _video_height=0, _video_min_dimension=0,
_audio_samples=0): _audio_samples=0, _audio_channels=0):
self.video_paths = video_paths self.video_paths = video_paths
self.num_workers = num_workers self.num_workers = num_workers
...@@ -81,6 +81,7 @@ class VideoClips(object): ...@@ -81,6 +81,7 @@ class VideoClips(object):
self._video_height = _video_height self._video_height = _video_height
self._video_min_dimension = _video_min_dimension self._video_min_dimension = _video_min_dimension
self._audio_samples = _audio_samples self._audio_samples = _audio_samples
self._audio_channels = _audio_channels
if _precomputed_metadata is None: if _precomputed_metadata is None:
self._compute_frame_pts() self._compute_frame_pts()
...@@ -149,7 +150,8 @@ class VideoClips(object): ...@@ -149,7 +150,8 @@ class VideoClips(object):
_video_width=self._video_width, _video_width=self._video_width,
_video_height=self._video_height, _video_height=self._video_height,
_video_min_dimension=self._video_min_dimension, _video_min_dimension=self._video_min_dimension,
_audio_samples=self._audio_samples) _audio_samples=self._audio_samples,
_audio_channels=self._audio_channels)
@staticmethod @staticmethod
def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
...@@ -298,6 +300,7 @@ class VideoClips(object): ...@@ -298,6 +300,7 @@ class VideoClips(object):
video_pts_range=(video_start_pts, video_end_pts), video_pts_range=(video_start_pts, video_end_pts),
video_timebase=info["video_timebase"], video_timebase=info["video_timebase"],
audio_samples=self._audio_samples, audio_samples=self._audio_samples,
audio_channels=self._audio_channels,
audio_pts_range=(audio_start_pts, audio_end_pts), audio_pts_range=(audio_start_pts, audio_end_pts),
audio_timebase=audio_timebase, audio_timebase=audio_timebase,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment