Unverified commit 60ce5bf4, authored by Vasilis Vryniotis, committed by GitHub
Browse files

Remove `(N, T, H, W, C) => (N, T, C, H, W)` from presets (#6058)

* Remove `(N, T, H, W, C) => (N, T, C, H, W)` conversion on presets

* Update docs.

* Fix the tests

* Use `output_format` for `read_video()`

* Use `output_format` for `Kinetics()`

* Adding input descriptions on presets
parent 4c668139
...@@ -471,7 +471,7 @@ Here is an example of how to use the pre-trained video classification models: ...@@ -471,7 +471,7 @@ Here is an example of how to use the pre-trained video classification models:
from torchvision.io.video import read_video from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights from torchvision.models.video import r3d_18, R3D_18_Weights
vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi") vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
vid = vid[:32] # optionally shorten duration vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights # Step 1: Initialize model with the best available weights
......
...@@ -72,8 +72,7 @@ _ = urlretrieve(video_url, video_path) ...@@ -72,8 +72,7 @@ _ = urlretrieve(video_url, video_path)
# single model input. # single model input.
from torchvision.io import read_video from torchvision.io import read_video
frames, _, _ = read_video(str(video_path)) frames, _, _ = read_video(str(video_path), output_format="TCHW")
frames = frames.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
img1_batch = torch.stack([frames[100], frames[150]]) img1_batch = torch.stack([frames[100], frames[150]])
img2_batch = torch.stack([frames[101], frames[151]]) img2_batch = torch.stack([frames[101], frames[151]])
......
...@@ -157,6 +157,7 @@ def main(args): ...@@ -157,6 +157,7 @@ def main(args):
"avi", "avi",
"mp4", "mp4",
), ),
output_format="TCHW",
) )
if args.cache_dataset: if args.cache_dataset:
print(f"Saving dataset_train to {cache_path}") print(f"Saving dataset_train to {cache_path}")
...@@ -193,6 +194,7 @@ def main(args): ...@@ -193,6 +194,7 @@ def main(args):
"avi", "avi",
"mp4", "mp4",
), ),
output_format="TCHW",
) )
if args.cache_dataset: if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}") print(f"Saving dataset_test to {cache_path}")
......
...@@ -180,7 +180,7 @@ def test_transforms_jit(model_fn): ...@@ -180,7 +180,7 @@ def test_transforms_jit(model_fn):
"input_shape": (1, 3, 520, 520), "input_shape": (1, 3, 520, 520),
}, },
"video": { "video": {
"input_shape": (1, 4, 112, 112, 3), "input_shape": (1, 4, 3, 112, 112),
}, },
"optical_flow": { "optical_flow": {
"input_shape": (1, 3, 128, 128), "input_shape": (1, 3, 128, 128),
......
...@@ -29,7 +29,10 @@ class ObjectDetection(nn.Module): ...@@ -29,7 +29,10 @@ class ObjectDetection(nn.Module):
return self.__class__.__name__ + "()" return self.__class__.__name__ + "()"
def describe(self) -> str: def describe(self) -> str:
return "The images are rescaled to ``[0.0, 1.0]``." return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
"The images are rescaled to ``[0.0, 1.0]``."
)
class ImageClassification(nn.Module): class ImageClassification(nn.Module):
...@@ -70,6 +73,7 @@ class ImageClassification(nn.Module): ...@@ -70,6 +73,7 @@ class ImageClassification(nn.Module):
def describe(self) -> str: def describe(self) -> str:
return ( return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``." f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``."
...@@ -99,7 +103,6 @@ class VideoClassification(nn.Module): ...@@ -99,7 +103,6 @@ class VideoClassification(nn.Module):
vid = vid.unsqueeze(dim=0) vid = vid.unsqueeze(dim=0)
need_squeeze = True need_squeeze = True
vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W)
N, T, C, H, W = vid.shape N, T, C, H, W = vid.shape
vid = vid.view(-1, C, H, W) vid = vid.view(-1, C, H, W)
vid = F.resize(vid, self.resize_size, interpolation=self.interpolation) vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
...@@ -126,9 +129,11 @@ class VideoClassification(nn.Module): ...@@ -126,9 +129,11 @@ class VideoClassification(nn.Module):
def describe(self) -> str: def describe(self) -> str:
return ( return (
f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, " "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to " f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``." f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
"dimensions are permuted to ``(..., C, T, H, W)`` tensors."
) )
...@@ -167,6 +172,7 @@ class SemanticSegmentation(nn.Module): ...@@ -167,6 +172,7 @@ class SemanticSegmentation(nn.Module):
def describe(self) -> str: def describe(self) -> str:
return ( return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. " f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and " f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and "
f"``std={self.std}``." f"``std={self.std}``."
...@@ -196,4 +202,7 @@ class OpticalFlow(nn.Module): ...@@ -196,4 +202,7 @@ class OpticalFlow(nn.Module):
return self.__class__.__name__ + "()" return self.__class__.__name__ + "()"
def describe(self) -> str: def describe(self) -> str:
return "The images are rescaled to ``[-1.0, 1.0]``." return (
"Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
"The images are rescaled to ``[-1.0, 1.0]``."
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment