Unverified commit 54a2d4e8 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Video Classes Clean up (#6751)

* Removing unnecessary methods/classes.

* Unions instead of ImageOrVideo types

* Fixing JIT issue.
parent 7d36d263
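
For orientation before the diff: the second bullet is the bulk of the change. The shared `ImageOrVideoType` / `ImageOrVideoTypeJIT` aliases are deleted from `features._video`, and every signature that used them now spells the union out explicitly. A minimal sketch of the resulting pattern; the `Image` and `Video` classes below are hypothetical stand-ins, not the real prototype features:

```python
from typing import Union

import torch


class Image(torch.Tensor):
    """Hypothetical stand-in for torchvision.prototype.features.Image."""


class Video(torch.Tensor):
    """Hypothetical stand-in for torchvision.prototype.features.Video."""


# Before: a shared alias such as ImageOrVideoType = Union[ImageType, VideoType]
# was defined once and imported everywhere.
# After: each call site spells the union out explicitly.
def horizontal_flip(inpt: Union[Image, Video]) -> Union[Image, Video]:
    return inpt.flip(-1)  # flip along the last (width) dimension
```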
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._encoded import EncodedData, EncodedImage, EncodedVideo
+from ._encoded import EncodedData, EncodedImage
 from ._feature import _Feature, FillType, FillTypeJIT, InputType, InputTypeJIT, is_simple_tensor
 from ._image import (
     ColorSpace,
@@ -14,12 +14,10 @@ from ._image import (
 from ._label import Label, OneHotLabel
 from ._mask import Mask
 from ._video import (
-    ImageOrVideoType,
-    ImageOrVideoTypeJIT,
     LegacyVideoType,
     LegacyVideoTypeJIT,
-    TensorImageOrVideoType,
-    TensorImageOrVideoTypeJIT,
+    TensorVideoType,
+    TensorVideoTypeJIT,
     Video,
     VideoType,
     VideoTypeJIT,
...

@@ -55,7 +55,3 @@ class EncodedImage(EncodedData):
         self._spatial_size = image.height, image.width
         return self._spatial_size
-
-
-class EncodedVideo(EncodedData):
-    pass

@@ -6,10 +6,8 @@ from typing import Any, cast, List, Optional, Tuple, Union
 import PIL.Image
 import torch
 from torchvision._utils import StrEnum
-from torchvision.transforms.functional import InterpolationMode, to_pil_image
-from torchvision.utils import draw_bounding_boxes, make_grid
+from torchvision.transforms.functional import InterpolationMode
 
-from ._bounding_box import BoundingBox
 from ._feature import _Feature, FillTypeJIT
@@ -124,16 +122,6 @@ class Image(_Feature):
             color_space=color_space,
         )
 
-    def show(self) -> None:
-        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
-        # promote this out of the prototype state
-        to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
-
-    def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
-        # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
-        # promote this out of the prototype state
-        return Image.wrap_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
-
     def horizontal_flip(self) -> Image:
         output = self._F.horizontal_flip_image_tensor(self)
         return Image.wrap_like(self, output)
...
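
Both removed methods were thin wrappers over public torchvision utilities (note the deleted `to_pil_image`, `make_grid`, and `draw_bounding_boxes` imports above), so the same debugging output is still available outside the class. A rough eager-mode equivalent; the `img` and `boxes` tensors are made up for illustration:

```python
import torch
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes, make_grid

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)  # stand-in CHW image
boxes = torch.tensor([[8, 8, 40, 40]])  # one XYXY box, invented for the example

# Equivalent of the removed Image.show(): flatten leading dims into a grid.
to_pil_image(make_grid(img.view(-1, *img.shape[-3:]))).show()

# Equivalent of the removed Image.draw_bounding_box(), minus re-wrapping into Image.
annotated = draw_bounding_boxes(img, boxes)
```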

@@ -7,7 +7,7 @@ import torch
 from torchvision.transforms.functional import InterpolationMode
 
 from ._feature import _Feature, FillTypeJIT
-from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
+from ._image import ColorSpace
 
 
 class Video(_Feature):
@@ -236,9 +236,3 @@ LegacyVideoType = torch.Tensor
 LegacyVideoTypeJIT = torch.Tensor
 TensorVideoType = Union[torch.Tensor, Video]
 TensorVideoTypeJIT = torch.Tensor
-
-# TODO: decide if we should do definitions for both Images and Videos or use unions in the methods
-ImageOrVideoType = Union[ImageType, VideoType]
-ImageOrVideoTypeJIT = Union[ImageTypeJIT, VideoTypeJIT]
-TensorImageOrVideoType = Union[TensorImageType, TensorVideoType]
-TensorImageOrVideoTypeJIT = Union[TensorImageTypeJIT, TensorVideoTypeJIT]

 import math
 import numbers
 import warnings
-from typing import Any, cast, Dict, List, Optional, Tuple
+from typing import Any, cast, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -92,14 +92,15 @@ class RandomErasing(_RandomApplyTransform):
         return dict(i=i, j=j, h=h, w=w, v=v)
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[features.ImageType, features.VideoType]:
         if params["v"] is not None:
             inpt = F.erase(inpt, **params, inplace=self.inplace)
         return inpt
 
 
+# TODO: Add support for Video: https://github.com/pytorch/vision/issues/6731
 class _BaseMixupCutmix(_RandomApplyTransform):
     def __init__(self, alpha: float, p: float = 0.5) -> None:
         super().__init__(p=p)
...

@@ -35,7 +35,7 @@ class _AutoAugmentBase(Transform):
         self,
         sample: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[int, features.ImageOrVideoType]:
+    ) -> Tuple[int, Union[features.ImageType, features.VideoType]]:
         sample_flat, _ = tree_flatten(sample)
         image_or_videos = []
         for id, inpt in enumerate(sample_flat):
@@ -60,12 +60,12 @@ class _AutoAugmentBase(Transform):
     def _apply_image_or_video_transform(
         self,
-        image: features.ImageOrVideoType,
+        image: Union[features.ImageType, features.VideoType],
         transform_id: str,
         magnitude: float,
         interpolation: InterpolationMode,
         fill: Dict[Type, features.FillType],
-    ) -> features.ImageOrVideoType:
+    ) -> Union[features.ImageType, features.VideoType]:
         fill_ = fill[type(image)]
         fill_ = F._geometry._convert_fill_arg(fill_)
...

@@ -111,8 +111,8 @@ class RandomPhotometricDistort(Transform):
         )
 
     def _permute_channels(
-        self, inpt: features.ImageOrVideoType, permutation: torch.Tensor
-    ) -> features.ImageOrVideoType:
+        self, inpt: Union[features.ImageType, features.VideoType], permutation: torch.Tensor
+    ) -> Union[features.ImageType, features.VideoType]:
         if isinstance(inpt, PIL.Image.Image):
             inpt = F.pil_to_tensor(inpt)
@@ -126,7 +126,9 @@ class RandomPhotometricDistort(Transform):
         return output
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[features.ImageType, features.VideoType]:
         if params["brightness"]:
             inpt = F.adjust_brightness(
                 inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
...

@@ -52,7 +52,9 @@ class Grayscale(Transform):
         super().__init__()
         self.num_output_channels = num_output_channels
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[features.ImageType, features.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
         if isinstance(inpt, (features.Image, features.Video)):
             output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
@@ -81,7 +83,9 @@ class RandomGrayscale(_RandomApplyTransform):
         num_input_channels, *_ = query_chw(sample)
         return dict(num_input_channels=num_input_channels)
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[features.ImageType, features.VideoType]:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
         if isinstance(inpt, (features.Image, features.Video)):
             output = inpt.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)  # type: ignore[arg-type]
...

@@ -148,6 +148,9 @@ class RandomResizedCrop(Transform):
         )
 
 
+ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
+
+
 class FiveCrop(Transform):
     """
     Example:
@@ -177,14 +180,8 @@ class FiveCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
 
     def _transform(
-        self, inpt: features.ImageOrVideoType, params: Dict[str, Any]
-    ) -> Tuple[
-        features.ImageOrVideoType,
-        features.ImageOrVideoType,
-        features.ImageOrVideoType,
-        features.ImageOrVideoType,
-        features.ImageOrVideoType,
-    ]:
+        self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
+    ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
         return F.five_crop(inpt, self.size)
 
     def forward(self, *inputs: Any) -> Any:
@@ -205,7 +202,9 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> List[features.ImageOrVideoType]:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
     def forward(self, *inputs: Any) -> Any:
...

@@ -29,8 +29,8 @@ class ConvertImageDtype(Transform):
         self.dtype = dtype
 
     def _transform(
-        self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]
-    ) -> features.TensorImageOrVideoType:
+        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+    ) -> Union[features.TensorImageType, features.TensorVideoType]:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
         return (
             output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
@@ -58,7 +58,9 @@ class ConvertColorSpace(Transform):
         self.copy = copy
 
-    def _transform(self, inpt: features.ImageOrVideoType, params: Dict[str, Any]) -> features.ImageOrVideoType:
+    def _transform(
+        self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
+    ) -> Union[features.ImageType, features.VideoType]:
         return F.convert_color_space(
             inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
         )
...

@@ -68,7 +68,9 @@ class LinearTransformation(Transform):
         return super().forward(*inputs)
 
-    def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(
+        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+    ) -> torch.Tensor:
         # Image instance after linear transformation is not Image anymore due to unknown data range
         # Thus we will return Tensor for input Image
@@ -101,7 +103,9 @@ class Normalize(Transform):
         self.std = list(std)
         self.inplace = inplace
 
-    def _transform(self, inpt: features.TensorImageOrVideoType, params: Dict[str, Any]) -> torch.Tensor:
+    def _transform(
+        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
+    ) -> torch.Tensor:
         return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
 
     def forward(self, *inpts: Any) -> Any:
...

+from typing import Union
+
 import PIL.Image
 import torch
@@ -24,14 +26,14 @@ def erase_video(
 def erase(
-    inpt: features.ImageOrVideoTypeJIT,
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
     i: int,
     j: int,
     h: int,
     w: int,
     v: torch.Tensor,
     inplace: bool = False,
-) -> features.ImageOrVideoTypeJIT:
+) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor):
         output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
...

@@ -59,7 +59,7 @@ def to_tensor(inpt: Any) -> torch.Tensor:
     return _F.to_tensor(inpt)
 
-def get_image_size(inpt: features.ImageOrVideoTypeJIT) -> List[int]:
+def get_image_size(inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
     warnings.warn(
         "The function `get_image_size(...)` is deprecated and will be removed in a future release. "
         "Instead, please use `get_spatial_size(...)` which returns `[h, w]` instead of `[w, h]`."
...
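
As the warning text says, the replacement differs in axis order, not just in name. A usage sketch; the import path assumes the prototype namespace of this era:

```python
import torch
from torchvision.prototype.transforms import functional as F  # prototype-era path, assumed here

img = torch.rand(3, 480, 640)  # CHW

print(F.get_image_size(img))    # deprecated: [width, height] -> [640, 480]
print(F.get_spatial_size(img))  # replacement: [height, width] -> [480, 640]
```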

@@ -1382,16 +1382,13 @@ def five_crop_video(
     return five_crop_image_tensor(video, size)
 
 
+ImageOrVideoTypeJIT = Union[features.ImageTypeJIT, features.VideoTypeJIT]
+
+
 def five_crop(
-    inpt: features.ImageOrVideoTypeJIT, size: List[int]
-) -> Tuple[
-    features.ImageOrVideoTypeJIT,
-    features.ImageOrVideoTypeJIT,
-    features.ImageOrVideoTypeJIT,
-    features.ImageOrVideoTypeJIT,
-    features.ImageOrVideoTypeJIT,
-]:
-    # TODO: consider breaking BC here to return List[features.ImageOrVideoTypeJIT] to align this op with `ten_crop`
+    inpt: ImageOrVideoTypeJIT, size: List[int]
+) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+    # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
@@ -1434,8 +1431,8 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
 def ten_crop(
-    inpt: features.ImageOrVideoTypeJIT, size: List[int], vertical_flip: bool = False
-) -> List[features.ImageOrVideoTypeJIT]:
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
+) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
...
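
The diff offers two clues about the "Fixing JIT issue" bullet: the return types that TorchScript must compile are rewritten (e.g. `ten_crop` now returns `Union[List[...], List[...]]`), and a module-level `ImageOrVideoTypeJIT` alias keeps `five_crop`'s five-element `Tuple` annotation manageable. The alias trick works because annotations evaluate to ordinary `typing` objects at definition time, so TorchScript sees the same types through the alias. A self-contained sketch of the idiom, with invented names:

```python
from typing import Tuple

import torch

# Module-level type alias: keeps a long return annotation readable, and
# TorchScript resolves it because the annotation evaluates to a plain
# typing.Tuple object when the function is defined.
FiveTensors = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


@torch.jit.script
def five_chunks(x: torch.Tensor) -> FiveTensors:
    chunks = torch.chunk(x, 5)
    return chunks[0], chunks[1], chunks[2], chunks[3], chunks[4]


print([c.shape for c in five_chunks(torch.arange(10.0))])
```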

-from typing import cast, List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -11,7 +11,7 @@ get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions
 
-def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
+def get_dimensions(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> List[int]:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
@@ -32,7 +32,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
     return get_num_channels_image_tensor(video)
 
-def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
+def get_num_channels(image: Union[features.ImageTypeJIT, features.VideoTypeJIT]) -> int:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
@@ -262,11 +262,11 @@ def convert_color_space_video(
 def convert_color_space(
-    inpt: features.ImageOrVideoTypeJIT,
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
     copy: bool = True,
-) -> features.ImageOrVideoTypeJIT:
+) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
     ):
@@ -281,4 +281,4 @@ def convert_color_space(
     elif isinstance(inpt, (features.Image, features.Video)):
         return inpt.to_color_space(color_space, copy=copy)
     else:
-        return cast(features.ImageOrVideoTypeJIT, convert_color_space_image_pil(inpt, color_space, copy=copy))
+        return convert_color_space_image_pil(inpt, color_space, copy=copy)

-from typing import List, Optional
+from typing import List, Optional, Union
 
 import PIL.Image
 import torch
@@ -14,7 +14,10 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
 def normalize(
-    inpt: features.TensorImageOrVideoTypeJIT, mean: List[float], std: List[float], inplace: bool = False
+    inpt: Union[features.TensorImageTypeJIT, features.TensorVideoTypeJIT],
+    mean: List[float],
+    std: List[float],
+    inplace: bool = False,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
         correct_type = isinstance(inpt, torch.Tensor)
...
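
All of these functional dispatchers share the shape visible above: a `Union`-annotated input, a pure-tensor kernel, and feature-subclass checks guarded by `torch.jit.is_scripting()`, because scripted code only ever sees plain tensors. A minimal eager-mode sketch of that dispatch; the kernel and the `Video` class are invented stand-ins, not the real prototype API:

```python
from typing import List, Union

import torch


class Video(torch.Tensor):
    """Hypothetical stand-in for torchvision.prototype.features.Video."""


def normalize_image_tensor(t: torch.Tensor, mean: List[float], std: List[float]) -> torch.Tensor:
    # Toy kernel: channel-wise (t - mean) / std for a CHW tensor.
    m = torch.tensor(mean, dtype=t.dtype).view(-1, 1, 1)
    s = torch.tensor(std, dtype=t.dtype).view(-1, 1, 1)
    return (t - m) / s


def normalize(inpt: Union[torch.Tensor, Video], mean: List[float], std: List[float]) -> torch.Tensor:
    # Subclass handling is skipped under TorchScript, mirroring the guards above.
    if not torch.jit.is_scripting() and isinstance(inpt, Video):
        inpt = inpt.as_subclass(torch.Tensor)  # unwrap the feature class
    return normalize_image_tensor(inpt, mean=mean, std=std)


out = normalize(torch.rand(3, 4, 4), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
```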