"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c8547ecddd8ebf5095b8ee3b825166b5cf94ad89"
Unverified Commit d8025b9a authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

cleanup prototype auto augment transforms (#6463)

* cleanup prototype auto augment transforms

* remove custom fill parsing from auto augment
parent 38cd6b39
import math import math
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union import numbers
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import get_chw, is_simple_tensor from ._utils import is_simple_tensor, query_chw
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
class _AutoAugmentBase(Transform): class _AutoAugmentBase(Transform):
def __init__( def __init__(
self, self,
...@@ -31,6 +25,9 @@ class _AutoAugmentBase(Transform): ...@@ -31,6 +25,9 @@ class _AutoAugmentBase(Transform):
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = interpolation self.interpolation = interpolation
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
self.fill = fill self.fill = fill
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
...@@ -38,41 +35,9 @@ class _AutoAugmentBase(Transform): ...@@ -38,41 +35,9 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))] key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key] return key, dct[key]
def _extract_image( def _get_params(self, sample: Any) -> Dict[str, Any]:
self, _, height, width = query_chw(sample)
sample: Any, return dict(height=height, width=width)
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
)
return images[0]
def _parse_fill(
self, image: Union[PIL.Image.Image, torch.Tensor, features.Image], num_channels: int
) -> Union[int, float, Sequence[int], Sequence[float]]:
fill = self.fill
if isinstance(image, PIL.Image.Image) or fill is None:
return fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
else:
fill = [float(f) for f in fill]
return fill
def _apply_image_transform( def _apply_image_transform(
self, self,
...@@ -277,22 +242,22 @@ class AutoAugment(_AutoAugmentBase): ...@@ -277,22 +242,22 @@ class AutoAugment(_AutoAugmentBase):
else: else:
raise ValueError(f"The provided policy {policy} is not recognized.") raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any: def _get_params(self, sample: Any) -> Dict[str, Any]:
sample = inputs if len(inputs) > 1 else inputs[0] params = super(AutoAugment, self)._get_params(sample)
params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
return params
id, image = self._extract_image(sample) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
num_channels, height, width = get_chw(image) if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
fill = self._parse_fill(image, num_channels) return inpt
policy = self._policies[int(torch.randint(len(self._policies), ()))] for transform_id, probability, magnitude_idx in params["policy"]:
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability: if not torch.rand(()) <= probability:
continue continue
magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
magnitudes = magnitudes_fn(10, height, width) magnitudes = magnitudes_fn(10, params["height"], params["width"])
if magnitudes is not None: if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx]) magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5: if signed and torch.rand(()) <= 0.5:
...@@ -300,11 +265,11 @@ class AutoAugment(_AutoAugmentBase): ...@@ -300,11 +265,11 @@ class AutoAugment(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
image = self._apply_image_transform( inpt = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
) )
return _put_into_sample(sample, id, image) return inpt
class RandAugment(_AutoAugmentBase): class RandAugment(_AutoAugmentBase):
...@@ -350,17 +315,14 @@ class RandAugment(_AutoAugmentBase): ...@@ -350,17 +315,14 @@ class RandAugment(_AutoAugmentBase):
self.magnitude = magnitude self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
fill = self._parse_fill(image, num_channels)
for _ in range(self.num_ops): for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
if magnitudes is not None: if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5: if signed and torch.rand(()) <= 0.5:
...@@ -368,11 +330,11 @@ class RandAugment(_AutoAugmentBase): ...@@ -368,11 +330,11 @@ class RandAugment(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
image = self._apply_image_transform( inpt = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=fill inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
) )
return _put_into_sample(sample, id, image) return inpt
class TrivialAugmentWide(_AutoAugmentBase): class TrivialAugmentWide(_AutoAugmentBase):
...@@ -408,16 +370,13 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -408,16 +370,13 @@ class TrivialAugmentWide(_AutoAugmentBase):
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
fill = self._parse_fill(image, num_channels)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width) magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
if magnitudes is not None: if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5: if signed and torch.rand(()) <= 0.5:
...@@ -425,8 +384,9 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -425,8 +384,9 @@ class TrivialAugmentWide(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
image = self._apply_image_transform(image, transform_id, magnitude, interpolation=self.interpolation, fill=fill) return self._apply_image_transform(
return _put_into_sample(sample, id, image) inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
class AugMix(_AutoAugmentBase): class AugMix(_AutoAugmentBase):
...@@ -478,16 +438,13 @@ class AugMix(_AutoAugmentBase): ...@@ -478,16 +438,13 @@ class AugMix(_AutoAugmentBase):
# Must be on a separate method so that we can overwrite it in tests. # Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params) return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
id, orig_image = self._extract_image(sample) image = inpt
num_channels, height, width = get_chw(orig_image) elif isinstance(inpt, PIL.Image.Image):
fill = self._parse_fill(orig_image, num_channels) image = pil_to_tensor(inpt)
else:
if isinstance(orig_image, torch.Tensor): return inpt
image = orig_image
else: # isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(orig_image)
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE
...@@ -513,7 +470,7 @@ class AugMix(_AutoAugmentBase): ...@@ -513,7 +470,7 @@ class AugMix(_AutoAugmentBase):
for _ in range(depth): for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space) transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width) magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
if magnitudes is not None: if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))]) magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5: if signed and torch.rand(()) <= 0.5:
...@@ -522,14 +479,14 @@ class AugMix(_AutoAugmentBase): ...@@ -522,14 +479,14 @@ class AugMix(_AutoAugmentBase):
magnitude = 0.0 magnitude = 0.0
aug = self._apply_image_transform( aug = self._apply_image_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=fill aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
) )
mix.add_(combined_weights[:, i].view(batch_dims) * aug) mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype) mix = mix.view(orig_dims).to(dtype=image.dtype)
if isinstance(orig_image, features.Image): if isinstance(inpt, features.Image):
mix = features.Image.new_like(orig_image, mix) mix = features.Image.new_like(inpt, mix)
elif isinstance(orig_image, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
mix = to_pil_image(mix) mix = to_pil_image(mix)
return _put_into_sample(sample, id, mix) return mix
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment