Unverified Commit 15b97d4b authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix BC for inplace arg in Normalize and RandomErasing (#6530)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 11304cb7
...@@ -1004,7 +1004,13 @@ class TestRandomErasing: ...@@ -1004,7 +1004,13 @@ class TestRandomErasing:
if p: if p:
mock.assert_called_once_with( mock.assert_called_once_with(
inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel inpt_sentinel,
i=i_sentinel,
j=j_sentinel,
h=h_sentinel,
w=w_sentinel,
v=v_sentinel,
inplace=transform.inplace,
) )
else: else:
mock.assert_not_called() mock.assert_not_called()
......
...@@ -88,7 +88,6 @@ CONSISTENCY_CONFIGS = [ ...@@ -88,7 +88,6 @@ CONSISTENCY_CONFIGS = [
], ],
supports_pil=False, supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]), make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
removed_params=["inplace"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.Resize, prototype_transforms.Resize,
...@@ -315,7 +314,6 @@ CONSISTENCY_CONFIGS = [ ...@@ -315,7 +314,6 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, value="random"), ArgsKwargs(p=1, value="random"),
], ],
supports_pil=False, supports_pil=False,
removed_params=["inplace"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ColorJitter, prototype_transforms.ColorJitter,
......
...@@ -23,6 +23,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -23,6 +23,7 @@ class RandomErasing(_RandomApplyTransform):
scale: Tuple[float, float] = (0.02, 0.33), scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3), ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0, value: float = 0,
inplace: bool = False,
): ):
super().__init__(p=p) super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)): if not isinstance(value, (numbers.Number, str, tuple, list)):
...@@ -40,6 +41,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -40,6 +41,7 @@ class RandomErasing(_RandomApplyTransform):
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.value = value self.value = value
self.inplace = inplace
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(sample) img_c, img_h, img_w = query_chw(sample)
...@@ -92,7 +94,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -92,7 +94,7 @@ class RandomErasing(_RandomApplyTransform):
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any] self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]: ) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
if params["v"] is not None: if params["v"] is not None:
inpt = F.erase(inpt, **params) inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt return inpt
......
...@@ -95,13 +95,14 @@ class LinearTransformation(Transform): ...@@ -95,13 +95,14 @@ class LinearTransformation(Transform):
class Normalize(Transform): class Normalize(Transform):
_transformed_types = (features.Image, features.is_simple_tensor) _transformed_types = (features.Image, features.is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]): def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__() super().__init__()
self.mean = list(mean) self.mean = list(mean)
self.std = list(std) self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor: def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std) return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any: def forward(self, *inpts: Any) -> Any:
if has_any(inpts, PIL.Image.Image): if has_any(inpts, PIL.Image.Image):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment