Unverified Commit efd6bc06 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make fill defaultdict an implementation detail (#7258)

parent b7892d3a
......@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
......@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
)
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
......@@ -255,9 +255,7 @@ class Pad(Transform):
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
return params
......@@ -276,11 +274,12 @@ class Pad(Transform):
if not isinstance(padding, int):
padding = list(padding)
self.padding = padding
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
......@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
) -> None:
super().__init__(p=p)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
......@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)
......@@ -338,7 +338,8 @@ class RandomRotation(Transform):
self.interpolation = _check_interpolation(interpolation)
self.expand = expand
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
......@@ -350,7 +351,7 @@ class RandomRotation(Transform):
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.rotate(
inpt,
**params,
......@@ -395,7 +396,8 @@ class RandomAffine(Transform):
self.shear = shear
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
......@@ -430,7 +432,7 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.affine(
inpt,
**params,
......@@ -447,9 +449,7 @@ class RandomCrop(Transform):
params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
padding = self.padding
if padding is not None:
......@@ -478,7 +478,8 @@ class RandomCrop(Transform):
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
......@@ -541,7 +542,7 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
......@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
......@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.perspective(
inpt,
None,
......@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill)
self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
......@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
fill = self._fill[type(inpt)]
return F.elastic(
inpt,
**params,
......
......@@ -108,30 +108,17 @@ class Transform(nn.Module):
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things:
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# v2 transform instance. It extracts all available public attributes that are specific to that transform and
# not `nn.Module` in general.
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs = nn.Module().__dict__.keys()
params = {
return {
attr: value
for attr, value in self.__dict__.items()
if not attr.startswith("_") and attr not in common_attrs
}
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if "fill" in params:
fill_type_defaultdict = params.pop("fill")
params["fill"] = fill_type_defaultdict[torch.Tensor]
return params
def __prepare_scriptable__(self) -> nn.Module:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment