"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0fbfcf3c9c7bfdbf4616238595eafd7eca2a916c"
Unverified commit 7cf0f4cc, authored by Philip Meier, committed by GitHub

make transforms v2 JIT scriptable (#7135)

parent 170160a5
@@ -34,6 +34,15 @@ from torchvision.transforms import functional as legacy_F

 DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])


+class NotScriptableArgsKwargs(ArgsKwargs):
+    """
+    This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
+    thus will be tested there, but will be skipped by the JIT tests.
+    """
+
+    pass
+
+
 class ConsistencyConfig:
     def __init__(
         self,
@@ -73,7 +82,7 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Resize,
         legacy_transforms.Resize,
         [
-            ArgsKwargs(32),
+            NotScriptableArgsKwargs(32),
             ArgsKwargs([32]),
             ArgsKwargs((32, 29)),
             ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
@@ -84,8 +93,10 @@ CONSISTENCY_CONFIGS = [
             # ArgsKwargs((30, 27), interpolation=0),
             # ArgsKwargs((35, 29), interpolation=2),
             # ArgsKwargs((34, 25), interpolation=3),
-            ArgsKwargs(31, max_size=32),
-            ArgsKwargs(30, max_size=100),
+            NotScriptableArgsKwargs(31, max_size=32),
+            ArgsKwargs([31], max_size=32),
+            NotScriptableArgsKwargs(30, max_size=100),
+            ArgsKwargs([30], max_size=100),
             ArgsKwargs((29, 32), antialias=False),
             ArgsKwargs((28, 31), antialias=True),
         ],
@@ -121,14 +132,15 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Pad,
         legacy_transforms.Pad,
         [
-            ArgsKwargs(3),
+            NotScriptableArgsKwargs(3),
             ArgsKwargs([3]),
             ArgsKwargs([2, 3]),
             ArgsKwargs([3, 2, 1, 4]),
-            ArgsKwargs(5, fill=1, padding_mode="constant"),
-            ArgsKwargs(5, padding_mode="edge"),
-            ArgsKwargs(5, padding_mode="reflect"),
-            ArgsKwargs(5, padding_mode="symmetric"),
+            NotScriptableArgsKwargs(5, fill=1, padding_mode="constant"),
+            ArgsKwargs([5], fill=1, padding_mode="constant"),
+            NotScriptableArgsKwargs(5, padding_mode="edge"),
+            NotScriptableArgsKwargs(5, padding_mode="reflect"),
+            NotScriptableArgsKwargs(5, padding_mode="symmetric"),
         ],
     ),
     ConsistencyConfig(
@@ -170,7 +182,7 @@ CONSISTENCY_CONFIGS = [
     ConsistencyConfig(
         prototype_transforms.ToPILImage,
         legacy_transforms.ToPILImage,
-        [ArgsKwargs()],
+        [NotScriptableArgsKwargs()],
         make_images_kwargs=dict(
             color_spaces=[
                 "GRAY",
@@ -186,7 +198,7 @@ CONSISTENCY_CONFIGS = [
         prototype_transforms.Lambda,
         legacy_transforms.Lambda,
         [
-            ArgsKwargs(lambda image: image / 2),
+            NotScriptableArgsKwargs(lambda image: image / 2),
         ],
         # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
         # images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ CONSISTENCY_CONFIGS = [
         [
             ArgsKwargs(12),
             ArgsKwargs((15, 17)),
-            ArgsKwargs(11, padding=1),
+            NotScriptableArgsKwargs(11, padding=1),
+            ArgsKwargs(11, padding=[1]),
             ArgsKwargs((8, 13), padding=(2, 3)),
             ArgsKwargs((14, 9), padding=(0, 2, 1, 0)),
             ArgsKwargs(36, pad_if_needed=True),
             ArgsKwargs((7, 8), fill=1),
-            ArgsKwargs(5, fill=(1, 2, 3)),
+            NotScriptableArgsKwargs(5, fill=(1, 2, 3)),
             ArgsKwargs(12),
-            ArgsKwargs(15, padding=2, padding_mode="edge"),
+            NotScriptableArgsKwargs(15, padding=2, padding_mode="edge"),
             ArgsKwargs(17, padding=(1, 0), padding_mode="reflect"),
             ArgsKwargs(8, padding=(3, 0, 0, 1), padding_mode="symmetric"),
         ],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
     )


+@pytest.mark.parametrize(
+    ("config", "args_kwargs"),
+    [
+        pytest.param(
+            config, args_kwargs, id=f"{config.legacy_cls.__name__}-{idx:0{len(str(len(config.args_kwargs)))}d}"
+        )
+        for config in CONSISTENCY_CONFIGS
+        for idx, args_kwargs in enumerate(config.args_kwargs)
+        if not isinstance(args_kwargs, NotScriptableArgsKwargs)
+    ],
+)
+def test_jit_consistency(config, args_kwargs):
+    args, kwargs = args_kwargs
+
+    prototype_transform_eager = config.prototype_cls(*args, **kwargs)
+    legacy_transform_eager = config.legacy_cls(*args, **kwargs)
+
+    legacy_transform_scripted = torch.jit.script(legacy_transform_eager)
+    prototype_transform_scripted = torch.jit.script(prototype_transform_eager)
+
+    for image in make_images(**config.make_images_kwargs):
+        image = image.as_subclass(torch.Tensor)
+
+        torch.manual_seed(0)
+        output_legacy_scripted = legacy_transform_scripted(image)
+
+        torch.manual_seed(0)
+        output_prototype_scripted = prototype_transform_scripted(image)
+
+        assert_close(output_prototype_scripted, output_legacy_scripted, **config.closeness_kwargs)
+
+
 class TestContainerTransforms:
     """
     Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
......
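For a single configuration, the parametrized test above boils down to the following hand-written check. This is a sketch, not part of the diff: it assumes the v2 namespace is importable as `torchvision.prototype.transforms` and uses `Resize([32])`, a case the new `NotScriptableArgsKwargs` marker does not exclude.

import torch
from torchvision import transforms as legacy_transforms
from torchvision.prototype import transforms as prototype_transforms

# Script both variants. For the v2 transform this goes through __prepare_scriptable__,
# so the scripted module is really the v1 equivalent under the hood.
v1_scripted = torch.jit.script(legacy_transforms.Resize([32]))
v2_scripted = torch.jit.script(prototype_transforms.Resize([32]))

image = torch.randint(0, 256, (3, 64, 48), dtype=torch.uint8)

torch.manual_seed(0)
expected = v1_scripted(image)
torch.manual_seed(0)
actual = v2_scripted(image)

torch.testing.assert_close(actual, expected)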
@@ -6,7 +6,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision import transforms as _transforms
 from torchvision.ops import masks_to_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -16,6 +16,14 @@ from .utils import has_any, is_simple_tensor, query_chw, query_spatial_size

 class RandomErasing(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomErasing
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return dict(
+            super()._extract_params_for_v1_transform(),
+            value="random" if self.value is None else self.value,
+        )
+
     _transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)

     def __init__(
......
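The `_extract_params_for_v1_transform` override above exists because the two APIs spell "pick random noise" differently: the `None` check suggests that v2 `RandomErasing` stores this case as `value=None` internally, while the v1 constructor expects the string "random". A minimal sketch of the round trip, under that assumption:

from torchvision.prototype import transforms as prototype_transforms

t = prototype_transforms.RandomErasing(value="random")  # assumed to be stored as None internally
params = t._extract_params_for_v1_transform()
assert params["value"] == "random"  # mapped back to the spelling the v1 constructor understands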
@@ -5,7 +5,7 @@ import PIL.Image

 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -161,6 +161,8 @@ class _AutoAugmentBase(Transform):

 class AutoAugment(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.AutoAugment
+
     _AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -315,6 +317,7 @@ class AutoAugment(_AutoAugmentBase):

 class RandAugment(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.RandAugment
     _AUGMENTATION_SPACE = {
         "Identity": (lambda num_bins, height, width: None, False),
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
@@ -375,6 +378,7 @@ class RandAugment(_AutoAugmentBase):

 class TrivialAugmentWide(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.TrivialAugmentWide
     _AUGMENTATION_SPACE = {
         "Identity": (lambda num_bins, height, width: None, False),
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.99, num_bins), True),
@@ -425,6 +429,8 @@ class TrivialAugmentWide(_AutoAugmentBase):

 class AugMix(_AutoAugmentBase):
+    _v1_transform_cls = _transforms.AugMix
+
     _PARTIAL_AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),
......
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -12,6 +12,8 @@ from .utils import is_simple_tensor, query_chw

 class Grayscale(Transform):
+    _v1_transform_cls = _transforms.Grayscale
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -28,6 +30,8 @@ class Grayscale(Transform):

 class RandomGrayscale(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomGrayscale
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -47,6 +51,11 @@ class RandomGrayscale(_RandomApplyTransform):

 class ColorJitter(Transform):
+    _v1_transform_cls = _transforms.ColorJitter
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return {attr: value or 0 for attr, value in super()._extract_params_for_v1_transform().items()}
+
     def __init__(
         self,
         brightness: Optional[Union[float, Sequence[float]]] = None,
@@ -194,16 +203,22 @@ class RandomPhotometricDistort(Transform):

 class RandomEqualize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomEqualize
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.equalize(inpt)


 class RandomInvert(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomInvert
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.invert(inpt)


 class RandomPosterize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomPosterize
+
     def __init__(self, bits: int, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.bits = bits
@@ -213,6 +228,8 @@ class RandomPosterize(_RandomApplyTransform):

 class RandomSolarize(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomSolarize
+
     def __init__(self, threshold: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.threshold = threshold
@@ -222,11 +239,15 @@ class RandomSolarize(_RandomApplyTransform):

 class RandomAutocontrast(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomAutocontrast
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.autocontrast(inpt)


 class RandomAdjustSharpness(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomAdjustSharpness
+
     def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         super().__init__(p=p)
         self.sharpness_factor = sharpness_factor
......
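`ColorJitter` needs a similar shim: the `value or 0` expression above coerces the `None` that v2 uses for a disabled jitter component into the `0` that the v1 constructor expects. A sketch, with illustrative parameter values:

from torchvision.prototype import transforms as prototype_transforms

t = prototype_transforms.ColorJitter(brightness=0.5)  # contrast, saturation, and hue default to None
params = t._extract_params_for_v1_transform()
assert params["contrast"] == 0  # None coerced to 0, which v1 accepts as "disabled"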
@@ -6,6 +6,7 @@ from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Ty

 import PIL.Image
 import torch
+from torchvision import transforms as _transforms
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, InterpolationMode, Transform
@@ -25,16 +26,22 @@ from .utils import has_all, has_any, is_simple_tensor, query_bounding_box, query

 class RandomHorizontalFlip(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomHorizontalFlip
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.horizontal_flip(inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomVerticalFlip
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return F.vertical_flip(inpt)


 class Resize(Transform):
+    _v1_transform_cls = _transforms.Resize
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -69,6 +76,8 @@ class Resize(Transform):

 class CenterCrop(Transform):
+    _v1_transform_cls = _transforms.CenterCrop
+
     def __init__(self, size: Union[int, Sequence[int]]):
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -78,6 +87,8 @@ class CenterCrop(Transform):

 class RandomResizedCrop(Transform):
+    _v1_transform_cls = _transforms.RandomResizedCrop
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -174,6 +185,8 @@ class FiveCrop(Transform):
         torch.Size([5])
     """

+    _v1_transform_cls = _transforms.FiveCrop
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -200,6 +213,8 @@ class TenCrop(Transform):
     See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
     """

+    _v1_transform_cls = _transforms.TenCrop
+
     _transformed_types = (
         datapoints.Image,
         PIL.Image.Image,
@@ -223,6 +238,18 @@ class TenCrop(Transform):

 class Pad(Transform):
+    _v1_transform_cls = _transforms.Pad
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(
+                f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+            )
+
+        return params
+
     def __init__(
         self,
         padding: Union[int, Sequence[int]],
@@ -285,6 +312,8 @@ class RandomZoomOut(_RandomApplyTransform):

 class RandomRotation(Transform):
+    _v1_transform_cls = _transforms.RandomRotation
+
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
@@ -322,6 +351,8 @@ class RandomRotation(Transform):

 class RandomAffine(Transform):
+    _v1_transform_cls = _transforms.RandomAffine
+
     def __init__(
         self,
         degrees: Union[numbers.Number, Sequence],
@@ -399,6 +430,24 @@ class RandomAffine(Transform):

 class RandomCrop(Transform):
+    _v1_transform_cls = _transforms.RandomCrop
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        params = super()._extract_params_for_v1_transform()
+
+        if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
+            raise ValueError(
+                f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
+            )
+
+        padding = self.padding
+        if padding is not None:
+            # Reorder from the (left, right, top, bottom) layout used here to the
+            # (left, top, right, bottom) layout the v1 transform expects.
+            pad_left, pad_right, pad_top, pad_bottom = padding
+            padding = [pad_left, pad_top, pad_right, pad_bottom]
+        params["padding"] = padding
+
+        return params
+
     def __init__(
         self,
         size: Union[int, Sequence[int]],
@@ -491,6 +540,8 @@ class RandomCrop(Transform):

 class RandomPerspective(_RandomApplyTransform):
+    _v1_transform_cls = _transforms.RandomPerspective
+
     def __init__(
         self,
         distortion_scale: float = 0.5,
@@ -550,6 +601,8 @@ class RandomPerspective(_RandomApplyTransform):

 class ElasticTransform(Transform):
+    _v1_transform_cls = _transforms.ElasticTransform
+
     def __init__(
         self,
         alpha: Union[float, Sequence[float]] = 50.0,
......
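The practical effect of the `Pad` and `RandomCrop` overrides above: only a scalar (or absent) `fill` survives the trip to v1, so only those instances can be scripted, while everything else keeps working in eager mode. A sketch of the expected behavior, inferred from the checks above:

import torch
from torchvision.prototype import transforms as prototype_transforms

scalar_fill = prototype_transforms.Pad([2], fill=1)
torch.jit.script(scalar_fill)  # OK: falls back to the v1 Pad

per_channel_fill = prototype_transforms.Pad([2], fill=(1, 2, 3))
per_channel_fill(torch.zeros(3, 8, 8))  # still fine in eager mode
torch.jit.script(per_channel_fill)  # ValueError: Pad() can only be scripted for a scalar `fill` ...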
@@ -2,6 +2,7 @@ from typing import Any, Dict, Union

 import torch
+from torchvision import transforms as _transforms
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -27,6 +28,8 @@ class ConvertBoundingBoxFormat(Transform):

 class ConvertDtype(Transform):
+    _v1_transform_cls = _transforms.ConvertImageDtype
+
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
......
@@ -4,6 +4,7 @@ import PIL.Image

 import torch
+from torchvision import transforms as _transforms
 from torchvision.ops import remove_small_boxes
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F, Transform
@@ -39,6 +40,8 @@ class Lambda(Transform):

 class LinearTransformation(Transform):
+    _v1_transform_cls = _transforms.LinearTransformation
+
     _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
@@ -94,6 +97,7 @@ class LinearTransformation(Transform):

 class Normalize(Transform):
+    _v1_transform_cls = _transforms.Normalize
     _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)

     def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
@@ -113,6 +117,8 @@ class Normalize(Transform):

 class GaussianBlur(Transform):
+    _v1_transform_cls = _transforms.GaussianBlur
+
     def __init__(
         self, kernel_size: Union[int, Sequence[int]], sigma: Union[int, float, Sequence[float]] = (0.1, 2.0)
     ) -> None:
......
 from __future__ import annotations

 import enum
-from typing import Any, Callable, Dict, List, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import PIL.Image
 import torch
@@ -54,6 +56,51 @@ class Transform(nn.Module):
         return ", ".join(extra)

+    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables the v2 transformation
+    # to be scriptable. See `_extract_params_for_v1_transform()` and `__prepare_scriptable__` for details.
+    _v1_transform_cls: Optional[Type[nn.Module]] = None
+
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
+        # v2 transform instance. It does two things:
+        # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
+        # 2. If available, handle the `fill` attribute for v1 compatibility (see below for details)
+        # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
+        # if the v2 transform introduced new parameters that are not supported by the v1 transform.
+
+        common_attrs = nn.Module().__dict__.keys()
+        params = {
+            attr: value
+            for attr, value in self.__dict__.items()
+            if not attr.startswith("_") and attr not in common_attrs
+        }
+
+        # Transforms v2 handles the `fill` parameter in a more complex way than v1. By default, the input is parsed
+        # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
+        # for the different datapoint types. Below we extract the value for tensors and return that together with the
+        # other params.
+        # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
+        # `RandomRotation`.
+        if "fill" in params:
+            fill_type_defaultdict = params.pop("fill")
+            params["fill"] = fill_type_defaultdict[torch.Tensor]
+
+        return params
+
+    def __prepare_scriptable__(self) -> nn.Module:
+        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the
+        # return value is scripted instead of the original object. Since the v1 transforms are JIT scriptable, and we
+        # made sure that for single image inputs v1 and v2 are equivalent, we just return the equivalent v1 transform
+        # here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 is around.
+        if self._v1_transform_cls is None:
+            raise RuntimeError(
+                f"Transform {type(self).__name__} cannot be JIT scripted. "
+                "JIT scripting is only supported for backward compatibility with transforms that already exist in v1. "
+                "For torchscript support (on tensors only), you can use the functional API instead."
+            )
+
+        return self._v1_transform_cls(**self._extract_params_for_v1_transform())
+
+
 class _RandomApplyTransform(Transform):
     def __init__(self, p: float = 0.5) -> None:
......
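Putting the two hooks together: any `nn.Module` can opt into this scheme by pointing `_v1_transform_cls` at a scriptable equivalent and returning it from `__prepare_scriptable__`. A self-contained sketch, with a hypothetical `MyHorizontalFlip` (not part of this diff) standing in for a v2 transform:

import torch
from torch import nn
from torchvision import transforms as _transforms


class MyHorizontalFlip(nn.Module):
    # Hypothetical stand-in for a v2 transform that has a v1 equivalent.
    _v1_transform_cls = _transforms.RandomHorizontalFlip

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Rich v2 logic would live here; it never needs to be scriptable itself.
        return image

    def __prepare_scriptable__(self) -> nn.Module:
        # torch.jit.script picks up this hook and scripts the returned module
        # instead of `self`, exactly as the comments above describe.
        return self._v1_transform_cls(p=self.p)


scripted = torch.jit.script(MyHorizontalFlip(p=1.0))
print(type(scripted))  # a scripted v1 RandomHorizontalFlip under the hood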