You need to sign in or sign up before continuing.
Unverified Commit efd6bc06 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make fill defaultdict an implementation detail (#7258)

parent b7892d3a
...@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform): ...@@ -22,7 +22,8 @@ class FixedSizeCrop(Transform):
self.crop_height = size[0] self.crop_height = size[0]
self.crop_width = size[1] self.crop_width = size[1]
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
...@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform): ...@@ -118,7 +119,7 @@ class FixedSizeCrop(Transform):
) )
if params["needs_pad"]: if params["needs_pad"]:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt return inpt
...@@ -255,9 +255,7 @@ class Pad(Transform): ...@@ -255,9 +255,7 @@ class Pad(Transform):
params = super()._extract_params_for_v1_transform() params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))): if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError( raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
return params return params
...@@ -276,11 +274,12 @@ class Pad(Transform): ...@@ -276,11 +274,12 @@ class Pad(Transform):
if not isinstance(padding, int): if not isinstance(padding, int):
padding = list(padding) padding = list(padding)
self.padding = padding self.padding = padding
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
...@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -293,7 +292,8 @@ class RandomZoomOut(_RandomApplyTransform):
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,)) _check_sequence_input(side_range, "side_range", req_sizes=(2,))
...@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -318,7 +318,7 @@ class RandomZoomOut(_RandomApplyTransform):
return dict(padding=padding) return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.pad(inpt, **params, fill=fill) return F.pad(inpt, **params, fill=fill)
...@@ -338,7 +338,8 @@ class RandomRotation(Transform): ...@@ -338,7 +338,8 @@ class RandomRotation(Transform):
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.expand = expand self.expand = expand
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None: if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,)) _check_sequence_input(center, "center", req_sizes=(2,))
...@@ -350,7 +351,7 @@ class RandomRotation(Transform): ...@@ -350,7 +351,7 @@ class RandomRotation(Transform):
return dict(angle=angle) return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.rotate( return F.rotate(
inpt, inpt,
**params, **params,
...@@ -395,7 +396,8 @@ class RandomAffine(Transform): ...@@ -395,7 +396,8 @@ class RandomAffine(Transform):
self.shear = shear self.shear = shear
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
if center is not None: if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,)) _check_sequence_input(center, "center", req_sizes=(2,))
...@@ -430,7 +432,7 @@ class RandomAffine(Transform): ...@@ -430,7 +432,7 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear) return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.affine( return F.affine(
inpt, inpt,
**params, **params,
...@@ -447,9 +449,7 @@ class RandomCrop(Transform): ...@@ -447,9 +449,7 @@ class RandomCrop(Transform):
params = super()._extract_params_for_v1_transform() params = super()._extract_params_for_v1_transform()
if not (params["fill"] is None or isinstance(params["fill"], (int, float))): if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError( raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
f"{type(self.__name__)}() can only be scripted for a scalar `fill`, but got {self.fill} for images."
)
padding = self.padding padding = self.padding
if padding is not None: if padding is not None:
...@@ -478,7 +478,8 @@ class RandomCrop(Transform): ...@@ -478,7 +478,8 @@ class RandomCrop(Transform):
self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type] self.padding = F._geometry._parse_pad_padding(padding) if padding else None # type: ignore[arg-type]
self.pad_if_needed = pad_if_needed self.pad_if_needed = pad_if_needed
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
...@@ -541,7 +542,7 @@ class RandomCrop(Transform): ...@@ -541,7 +542,7 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]: if params["needs_pad"]:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]: if params["needs_crop"]:
...@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -567,7 +568,8 @@ class RandomPerspective(_RandomApplyTransform):
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs) height, width = query_spatial_size(flat_inputs)
...@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -600,7 +602,7 @@ class RandomPerspective(_RandomApplyTransform):
return dict(coefficients=perspective_coeffs) return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.perspective( return F.perspective(
inpt, inpt,
None, None,
...@@ -626,7 +628,8 @@ class ElasticTransform(Transform): ...@@ -626,7 +628,8 @@ class ElasticTransform(Transform):
self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
self.fill = _setup_fill_arg(fill) self.fill = fill
self._fill = _setup_fill_arg(fill)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs)) size = list(query_spatial_size(flat_inputs))
...@@ -652,7 +655,7 @@ class ElasticTransform(Transform): ...@@ -652,7 +655,7 @@ class ElasticTransform(Transform):
return dict(displacement=displacement) return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)] fill = self._fill[type(inpt)]
return F.elastic( return F.elastic(
inpt, inpt,
**params, **params,
......
...@@ -108,30 +108,17 @@ class Transform(nn.Module): ...@@ -108,30 +108,17 @@ class Transform(nn.Module):
def _extract_params_for_v1_transform(self) -> Dict[str, Any]: def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
# This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
# v2 transform instance. It does two things: # v2 transform instance. It extracts all available public attributes that are specific to that transform and
# 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general # not `nn.Module` in general.
# 2. If available handle the `fill` attribute for v1 compatibility (see below for details)
# Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
# if the v2 transform introduced new parameters that are not support by the v1 transform. # if the v2 transform introduced new parameters that are not support by the v1 transform.
common_attrs = nn.Module().__dict__.keys() common_attrs = nn.Module().__dict__.keys()
params = { return {
attr: value attr: value
for attr, value in self.__dict__.items() for attr, value in self.__dict__.items()
if not attr.startswith("_") and attr not in common_attrs if not attr.startswith("_") and attr not in common_attrs
} }
# transforms v2 has a more complex handling for the `fill` parameter than v1. By default, the input is parsed
# with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
# for the different datapoint types. Below we extract the value for tensors and return that together with the
# other params.
# This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
# `RandomRotation`
if "fill" in params:
fill_type_defaultdict = params.pop("fill")
params["fill"] = fill_type_defaultdict[torch.Tensor]
return params
def __prepare_scriptable__(self) -> nn.Module: def __prepare_scriptable__(self) -> nn.Module:
# This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
# value is used for scripting over the original object that should have been scripted. Since the v1 transforms # value is used for scripting over the original object that should have been scripted. Since the v1 transforms
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment