Unverified Commit 4c668139 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add output_format to video datasets and readers (#6061)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 5486b768
...@@ -37,11 +37,13 @@ class HMDB51(VisionDataset): ...@@ -37,11 +37,13 @@ class HMDB51(VisionDataset):
otherwise from the ``test`` split. otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version. and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".
Returns: Returns:
tuple: A 3-tuple with the following entries: tuple: A 3-tuple with the following entries:
- video (Tensor[T, H, W, C]): The `T` video frames - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points and `L` is the number of points
- label (int): class of the video clip - label (int): class of the video clip
...@@ -71,6 +73,7 @@ class HMDB51(VisionDataset): ...@@ -71,6 +73,7 @@ class HMDB51(VisionDataset):
_video_height: int = 0, _video_height: int = 0,
_video_min_dimension: int = 0, _video_min_dimension: int = 0,
_audio_samples: int = 0, _audio_samples: int = 0,
output_format: str = "THWC",
) -> None: ) -> None:
super().__init__(root) super().__init__(root)
if fold not in (1, 2, 3): if fold not in (1, 2, 3):
...@@ -96,6 +99,7 @@ class HMDB51(VisionDataset): ...@@ -96,6 +99,7 @@ class HMDB51(VisionDataset):
_video_height=_video_height, _video_height=_video_height,
_video_min_dimension=_video_min_dimension, _video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples, _audio_samples=_audio_samples,
output_format=output_format,
) )
# we bookkeep the full version of video clips because we want to be able # we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of # to return the meta data of full version rather than the subset version of
......
...@@ -62,11 +62,14 @@ class Kinetics(VisionDataset): ...@@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
download (bool): Download the official version of the dataset to root folder. download (bool): Download the official version of the dataset to root folder.
num_workers (int): Use multiple workers for VideoClips creation num_workers (int): Use multiple workers for VideoClips creation
num_download_workers (int): Use multiprocessing in order to speed up download. num_download_workers (int): Use multiprocessing in order to speed up download.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" or "TCHW" (default).
Note that in most other utils and datasets, the default is actually "THWC".
Returns: Returns:
tuple: A 3-tuple with the following entries: tuple: A 3-tuple with the following entries:
- video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor - video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points in torch.float tensor and `L` is the number of points in torch.float tensor
- label (int): class of the video clip - label (int): class of the video clip
...@@ -106,6 +109,7 @@ class Kinetics(VisionDataset): ...@@ -106,6 +109,7 @@ class Kinetics(VisionDataset):
_audio_samples: int = 0, _audio_samples: int = 0,
_audio_channels: int = 0, _audio_channels: int = 0,
_legacy: bool = False, _legacy: bool = False,
output_format: str = "TCHW",
) -> None: ) -> None:
# TODO: support test # TODO: support test
...@@ -115,10 +119,12 @@ class Kinetics(VisionDataset): ...@@ -115,10 +119,12 @@ class Kinetics(VisionDataset):
self.root = root self.root = root
self._legacy = _legacy self._legacy = _legacy
if _legacy: if _legacy:
print("Using legacy structure") print("Using legacy structure")
self.split_folder = root self.split_folder = root
self.split = "unknown" self.split = "unknown"
output_format = "THWC"
if download: if download:
raise ValueError("Cannot download the videos using legacy_structure.") raise ValueError("Cannot download the videos using legacy_structure.")
else: else:
...@@ -145,6 +151,7 @@ class Kinetics(VisionDataset): ...@@ -145,6 +151,7 @@ class Kinetics(VisionDataset):
_video_min_dimension=_video_min_dimension, _video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples, _audio_samples=_audio_samples,
_audio_channels=_audio_channels, _audio_channels=_audio_channels,
output_format=output_format,
) )
self.transform = transform self.transform = transform
...@@ -233,9 +240,6 @@ class Kinetics(VisionDataset): ...@@ -233,9 +240,6 @@ class Kinetics(VisionDataset):
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]: def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, info, video_idx = self.video_clips.get_clip(idx) video, audio, info, video_idx = self.video_clips.get_clip(idx)
if not self._legacy:
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)
label = self.samples[video_idx][1] label = self.samples[video_idx][1]
if self.transform is not None: if self.transform is not None:
...@@ -308,7 +312,7 @@ class Kinetics400(Kinetics): ...@@ -308,7 +312,7 @@ class Kinetics400(Kinetics):
warnings.warn( warnings.warn(
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14." "The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
"Please use Kinetics(..., num_classes='400') instead." "Please use Kinetics(..., num_classes='400') instead."
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format." "Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
) )
if any(value is not None for value in (num_classes, split, download, num_download_workers)): if any(value is not None for value in (num_classes, split, download, num_download_workers)):
raise RuntimeError( raise RuntimeError(
......
...@@ -38,11 +38,13 @@ class UCF101(VisionDataset): ...@@ -38,11 +38,13 @@ class UCF101(VisionDataset):
otherwise from the ``test`` split. otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version. and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".
Returns: Returns:
tuple: A 3-tuple with the following entries: tuple: A 3-tuple with the following entries:
- video (Tensor[T, H, W, C]): the `T` video frames - video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels - audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points and `L` is the number of points
- label (int): class of the video clip - label (int): class of the video clip
...@@ -64,6 +66,7 @@ class UCF101(VisionDataset): ...@@ -64,6 +66,7 @@ class UCF101(VisionDataset):
_video_height: int = 0, _video_height: int = 0,
_video_min_dimension: int = 0, _video_min_dimension: int = 0,
_audio_samples: int = 0, _audio_samples: int = 0,
output_format: str = "THWC",
) -> None: ) -> None:
super().__init__(root) super().__init__(root)
if not 1 <= fold <= 3: if not 1 <= fold <= 3:
...@@ -87,6 +90,7 @@ class UCF101(VisionDataset): ...@@ -87,6 +90,7 @@ class UCF101(VisionDataset):
_video_height=_video_height, _video_height=_video_height,
_video_min_dimension=_video_min_dimension, _video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples, _audio_samples=_audio_samples,
output_format=output_format,
) )
# we bookkeep the full version of video clips because we want to be able # we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of # to return the meta data of full version rather than the subset version of
......
...@@ -99,6 +99,7 @@ class VideoClips: ...@@ -99,6 +99,7 @@ class VideoClips:
on the resampled video on the resampled video
num_workers (int): how many subprocesses to use for data loading. num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0) 0 means that the data will be loaded in the main process. (default: 0)
output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
""" """
def __init__( def __init__(
...@@ -115,6 +116,7 @@ class VideoClips: ...@@ -115,6 +116,7 @@ class VideoClips:
_video_max_dimension: int = 0, _video_max_dimension: int = 0,
_audio_samples: int = 0, _audio_samples: int = 0,
_audio_channels: int = 0, _audio_channels: int = 0,
output_format: str = "THWC",
) -> None: ) -> None:
self.video_paths = video_paths self.video_paths = video_paths
...@@ -127,6 +129,9 @@ class VideoClips: ...@@ -127,6 +129,9 @@ class VideoClips:
self._video_max_dimension = _video_max_dimension self._video_max_dimension = _video_max_dimension
self._audio_samples = _audio_samples self._audio_samples = _audio_samples
self._audio_channels = _audio_channels self._audio_channels = _audio_channels
self.output_format = output_format.upper()
if self.output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
if _precomputed_metadata is None: if _precomputed_metadata is None:
self._compute_frame_pts() self._compute_frame_pts()
...@@ -366,6 +371,11 @@ class VideoClips: ...@@ -366,6 +371,11 @@ class VideoClips:
video = video[resampling_idx] video = video[resampling_idx]
info["video_fps"] = self.frame_rate info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}" assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
if self.output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)
return video, audio, info, video_idx return video, audio, info, video_idx
def __getstate__(self) -> Dict[str, Any]: def __getstate__(self) -> Dict[str, Any]:
......
...@@ -239,6 +239,7 @@ def read_video( ...@@ -239,6 +239,7 @@ def read_video(
start_pts: Union[float, Fraction] = 0, start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None, end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts", pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
""" """
Reads a video from a file, returning both the video frames as well as Reads a video from a file, returning both the video frames as well as
...@@ -252,15 +253,20 @@ def read_video( ...@@ -252,15 +253,20 @@ def read_video(
The end presentation time The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted, pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'. either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns: Returns:
vframes (Tensor[T, H, W, C]): the `T` video frames vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int) info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video) _log_api_usage_once(read_video)
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
from torchvision import get_video_backend from torchvision import get_video_backend
if not os.path.exists(filename): if not os.path.exists(filename):
...@@ -334,6 +340,10 @@ def read_video( ...@@ -334,6 +340,10 @@ def read_video(
else: else:
aframes = torch.empty((1, 0), dtype=torch.float32) aframes = torch.empty((1, 0), dtype=torch.float32)
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info return vframes, aframes, info
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment