"...text-generation-inference.git" did not exist on "9987960062e40de2deae030ab7e4ad6f57de0b20"
Unverified Commit 297e2b87 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add random apply transform base class (#5639)

parent 39772ece
......@@ -7,10 +7,11 @@ import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor
class RandomErasing(Transform):
class RandomErasing(_RandomApplyTransform):
def __init__(
self,
p: float = 0.5,
......@@ -18,7 +19,7 @@ class RandomErasing(Transform):
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
):
super().__init__()
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
......@@ -31,9 +32,6 @@ class RandomErasing(Transform):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
self.p = p
self.scale = scale
self.ratio = ratio
self.value = value
......@@ -99,8 +97,6 @@ class RandomErasing(Transform):
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif torch.rand(1) >= self.p:
return sample
return super().forward(sample)
......
from typing import Any, Optional, List
from typing import Any, Optional, List, Dict
import torch
from torchvision.prototype.transforms import Transform
from ._transform import Transform
from ._transform import _RandomApplyTransform
class Compose(Transform):
......@@ -19,18 +20,13 @@ class Compose(Transform):
return sample
class RandomApply(Transform):
class RandomApply(_RandomApplyTransform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
super().__init__()
super().__init__(p=p)
self.transform = transform
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) < self.p:
return sample
return self.transform(sample)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return self.transform(input)
def extra_repr(self) -> str:
return f"p={self.p}"
......
......@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
class RandomHorizontalFlip(Transform):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
class RandomHorizontalFlip(_RandomApplyTransform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.horizontal_flip_image_tensor(input)
......@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
return input
class RandomVerticalFlip(Transform):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) > self.p:
return sample
return super().forward(sample)
class RandomVerticalFlip(_RandomApplyTransform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.vertical_flip_image_tensor(input)
......@@ -371,11 +350,11 @@ class Pad(Transform):
return input
class RandomZoomOut(Transform):
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
) -> None:
super().__init__()
super().__init__(p=p)
if fill is None:
fill = 0.0
......@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
orig_c, orig_h, orig_w = get_image_dimensions(image)
......@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
transform = Pad(**params, padding_mode="constant")
return transform(input)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
......@@ -2,6 +2,7 @@ import enum
import functools
from typing import Any, Dict
import torch
from torch import nn
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.utils import _log_api_usage_once
......@@ -34,3 +35,20 @@ class Transform(nn.Module):
extra.append(f"{name}={value}")
return ", ".join(extra)
class _RandomApplyTransform(Transform):
def __init__(self, *, p: float = 0.5) -> None:
if not (0.0 <= p <= 1.0):
raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment