"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3e4c5707c3e6e0e363ef93a6a60bad7245f05e46"
Unverified commit 297e2b87, authored by Philip Meier and committed by GitHub
Browse files

add random apply transform base class (#5639)

parent 39772ece
...@@ -7,10 +7,11 @@ import torch ...@@ -7,10 +7,11 @@ import torch
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.transforms import Transform, functional as F
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor from ._utils import query_image, get_image_dimensions, has_all, has_any, is_simple_tensor
class RandomErasing(Transform): class RandomErasing(_RandomApplyTransform):
def __init__( def __init__(
self, self,
p: float = 0.5, p: float = 0.5,
...@@ -18,7 +19,7 @@ class RandomErasing(Transform): ...@@ -18,7 +19,7 @@ class RandomErasing(Transform):
ratio: Tuple[float, float] = (0.3, 3.3), ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0, value: float = 0,
): ):
super().__init__() super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)): if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence") raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random": if isinstance(value, str) and value != "random":
...@@ -31,9 +32,6 @@ class RandomErasing(Transform): ...@@ -31,9 +32,6 @@ class RandomErasing(Transform):
warnings.warn("Scale and ratio should be of kind (min, max)") warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1: if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1") raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
self.p = p
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.value = value self.value = value
...@@ -99,8 +97,6 @@ class RandomErasing(Transform): ...@@ -99,8 +97,6 @@ class RandomErasing(Transform):
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask): if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif torch.rand(1) >= self.p:
return sample
return super().forward(sample) return super().forward(sample)
......
from typing import Any, Optional, List from typing import Any, Optional, List, Dict
import torch import torch
from torchvision.prototype.transforms import Transform
from ._transform import Transform from ._transform import _RandomApplyTransform
class Compose(Transform): class Compose(Transform):
...@@ -19,18 +20,13 @@ class Compose(Transform): ...@@ -19,18 +20,13 @@ class Compose(Transform):
return sample return sample
class RandomApply(Transform): class RandomApply(_RandomApplyTransform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None: def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
super().__init__() super().__init__(p=p)
self.transform = transform self.transform = transform
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) < self.p:
return sample
return self.transform(sample) def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return self.transform(input)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f"p={self.p}" return f"p={self.p}"
......
...@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor ...@@ -12,21 +12,11 @@ from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from typing_extensions import Literal from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
class RandomHorizontalFlip(Transform): class RandomHorizontalFlip(_RandomApplyTransform):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image): if isinstance(input, features.Image):
output = F.horizontal_flip_image_tensor(input) output = F.horizontal_flip_image_tensor(input)
...@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform): ...@@ -45,18 +35,7 @@ class RandomHorizontalFlip(Transform):
return input return input
class RandomVerticalFlip(Transform): class RandomVerticalFlip(_RandomApplyTransform):
def __init__(self, p: float = 0.5) -> None:
super().__init__()
self.p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) > self.p:
return sample
return super().forward(sample)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image): if isinstance(input, features.Image):
output = F.vertical_flip_image_tensor(input) output = F.vertical_flip_image_tensor(input)
...@@ -371,11 +350,11 @@ class Pad(Transform): ...@@ -371,11 +350,11 @@ class Pad(Transform):
return input return input
class RandomZoomOut(Transform): class RandomZoomOut(_RandomApplyTransform):
def __init__( def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
) -> None: ) -> None:
super().__init__() super().__init__(p=p)
if fill is None: if fill is None:
fill = 0.0 fill = 0.0
...@@ -385,8 +364,6 @@ class RandomZoomOut(Transform): ...@@ -385,8 +364,6 @@ class RandomZoomOut(Transform):
if side_range[0] < 1.0 or side_range[0] > side_range[1]: if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.") raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample) image = query_image(sample)
orig_c, orig_h, orig_w = get_image_dimensions(image) orig_c, orig_h, orig_w = get_image_dimensions(image)
...@@ -411,10 +388,3 @@ class RandomZoomOut(Transform): ...@@ -411,10 +388,3 @@ class RandomZoomOut(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
transform = Pad(**params, padding_mode="constant") transform = Pad(**params, padding_mode="constant")
return transform(input) return transform(input)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample
return super().forward(sample)
...@@ -2,6 +2,7 @@ import enum ...@@ -2,6 +2,7 @@ import enum
import functools import functools
from typing import Any, Dict from typing import Any, Dict
import torch
from torch import nn from torch import nn
from torchvision.prototype.utils._internal import apply_recursively from torchvision.prototype.utils._internal import apply_recursively
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
...@@ -34,3 +35,20 @@ class Transform(nn.Module): ...@@ -34,3 +35,20 @@ class Transform(nn.Module):
extra.append(f"{name}={value}") extra.append(f"{name}={value}")
return ", ".join(extra) return ", ".join(extra)
class _RandomApplyTransform(Transform):
    """Base class for transforms that are only applied with probability ``p``.

    On every call a single random number is drawn; depending on the outcome
    the sample is either returned untouched or handed on to the parent
    class' ``forward``.

    Args:
        p: Probability of applying the transform. Keyword-only.

    Raises:
        ValueError: If ``p`` does not lie in the closed interval [0.0, 1.0].
    """

    def __init__(self, *, p: float = 0.5) -> None:
        # Validate before initializing the parent so that an invalid `p`
        # fails fast. The chained comparison also rejects NaN.
        is_valid_probability = 0.0 <= p <= 1.0
        if not is_valid_probability:
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        """Return the sample unchanged or delegate to the parent ``forward``.

        A single positional argument is unwrapped and used as the sample;
        several positional arguments are kept together as one tuple sample.
        """
        sample = inputs[0] if len(inputs) <= 1 else inputs

        # One draw decides the fate of the whole sample.
        skip = torch.rand(1) >= self.p
        return sample if skip else super().forward(sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment