Unverified Commit 4c668139 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add output_format to video datasets and readers (#6061)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 5486b768
......@@ -37,11 +37,13 @@ class HMDB51(VisionDataset):
otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".
Returns:
tuple: A 3-tuple with the following entries:
- video (Tensor[T, H, W, C]): The `T` video frames
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points
- label (int): class of the video clip
......@@ -71,6 +73,7 @@ class HMDB51(VisionDataset):
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
output_format: str = "THWC",
) -> None:
super().__init__(root)
if fold not in (1, 2, 3):
......@@ -96,6 +99,7 @@ class HMDB51(VisionDataset):
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
......
......@@ -62,11 +62,14 @@ class Kinetics(VisionDataset):
download (bool): Download the official version of the dataset to root folder.
num_workers (int): Use multiple workers for VideoClips creation
num_download_workers (int): Use multiprocessing in order to speed up download.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" or "TCHW" (default).
Note that in most other utils and datasets, the default is actually "THWC".
Returns:
tuple: A 3-tuple with the following entries:
- video (Tensor[T, C, H, W]): the `T` video frames in torch.uint8 tensor
- video (Tensor[T, C, H, W] or Tensor[T, H, W, C]): the `T` video frames in torch.uint8 tensor
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points in torch.float tensor
- label (int): class of the video clip
......@@ -106,6 +109,7 @@ class Kinetics(VisionDataset):
_audio_samples: int = 0,
_audio_channels: int = 0,
_legacy: bool = False,
output_format: str = "TCHW",
) -> None:
# TODO: support test
......@@ -115,10 +119,12 @@ class Kinetics(VisionDataset):
self.root = root
self._legacy = _legacy
if _legacy:
print("Using legacy structure")
self.split_folder = root
self.split = "unknown"
output_format = "THWC"
if download:
raise ValueError("Cannot download the videos using legacy_structure.")
else:
......@@ -145,6 +151,7 @@ class Kinetics(VisionDataset):
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
_audio_channels=_audio_channels,
output_format=output_format,
)
self.transform = transform
......@@ -233,9 +240,6 @@ class Kinetics(VisionDataset):
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, int]:
video, audio, info, video_idx = self.video_clips.get_clip(idx)
if not self._legacy:
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)
label = self.samples[video_idx][1]
if self.transform is not None:
......@@ -308,7 +312,7 @@ class Kinetics400(Kinetics):
warnings.warn(
"The Kinetics400 class is deprecated since 0.12 and will be removed in 0.14."
"Please use Kinetics(..., num_classes='400') instead."
"Note that Kinetics(..., num_classes='400') returns video in a more logical Tensor[T, C, H, W] format."
"Note that Kinetics(..., num_classes='400') returns video in a Tensor[T, C, H, W] format."
)
if any(value is not None for value in (num_classes, split, download, num_download_workers)):
raise RuntimeError(
......
......@@ -38,11 +38,13 @@ class UCF101(VisionDataset):
otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version.
output_format (str, optional): The format of the output video tensors (before transforms).
Can be either "THWC" (default) or "TCHW".
Returns:
tuple: A 3-tuple with the following entries:
- video (Tensor[T, H, W, C]): the `T` video frames
- video (Tensor[T, H, W, C] or Tensor[T, C, H, W]): The `T` video frames
- audio(Tensor[K, L]): the audio frames, where `K` is the number of channels
and `L` is the number of points
- label (int): class of the video clip
......@@ -64,6 +66,7 @@ class UCF101(VisionDataset):
_video_height: int = 0,
_video_min_dimension: int = 0,
_audio_samples: int = 0,
output_format: str = "THWC",
) -> None:
super().__init__(root)
if not 1 <= fold <= 3:
......@@ -87,6 +90,7 @@ class UCF101(VisionDataset):
_video_height=_video_height,
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
......
......@@ -99,6 +99,7 @@ class VideoClips:
on the resampled video
num_workers (int): how many subprocesses to use for data loading.
0 means that the data will be loaded in the main process. (default: 0)
output_format (str): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
"""
def __init__(
......@@ -115,6 +116,7 @@ class VideoClips:
_video_max_dimension: int = 0,
_audio_samples: int = 0,
_audio_channels: int = 0,
output_format: str = "THWC",
) -> None:
self.video_paths = video_paths
......@@ -127,6 +129,9 @@ class VideoClips:
self._video_max_dimension = _video_max_dimension
self._audio_samples = _audio_samples
self._audio_channels = _audio_channels
self.output_format = output_format.upper()
if self.output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
if _precomputed_metadata is None:
self._compute_frame_pts()
......@@ -366,6 +371,11 @@ class VideoClips:
video = video[resampling_idx]
info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, f"{video.shape} x {self.num_frames}"
if self.output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
video = video.permute(0, 3, 1, 2)
return video, audio, info, video_idx
def __getstate__(self) -> Dict[str, Any]:
......
......@@ -239,6 +239,7 @@ def read_video(
start_pts: Union[float, Fraction] = 0,
end_pts: Optional[Union[float, Fraction]] = None,
pts_unit: str = "pts",
output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Reads a video from a file, returning both the video frames as well as
......@@ -252,15 +253,20 @@ def read_video(
The end presentation time
pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
either 'pts' or 'sec'. Defaults to 'pts'.
output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".
Returns:
vframes (Tensor[T, H, W, C]): the `T` video frames
vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_video)
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
from torchvision import get_video_backend
if not os.path.exists(filename):
......@@ -334,6 +340,10 @@ def read_video(
else:
aframes = torch.empty((1, 0), dtype=torch.float32)
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
return vframes, aframes, info
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment