Unverified Commit 15b97d4b authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fix BC for inplace arg in Normalize and RandomErasing (#6530)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 11304cb7
......@@ -1004,7 +1004,13 @@ class TestRandomErasing:
if p:
mock.assert_called_once_with(
inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel
inpt_sentinel,
i=i_sentinel,
j=j_sentinel,
h=h_sentinel,
w=w_sentinel,
v=v_sentinel,
inplace=transform.inplace,
)
else:
mock.assert_not_called()
......
......@@ -88,7 +88,6 @@ CONSISTENCY_CONFIGS = [
],
supports_pil=False,
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
removed_params=["inplace"],
),
ConsistencyConfig(
prototype_transforms.Resize,
......@@ -315,7 +314,6 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, value="random"),
],
supports_pil=False,
removed_params=["inplace"],
),
ConsistencyConfig(
prototype_transforms.ColorJitter,
......
......@@ -23,6 +23,7 @@ class RandomErasing(_RandomApplyTransform):
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
......@@ -40,6 +41,7 @@ class RandomErasing(_RandomApplyTransform):
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace
def _get_params(self, sample: Any) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(sample)
......@@ -92,7 +94,7 @@ class RandomErasing(_RandomApplyTransform):
self, inpt: Union[torch.Tensor, features.Image, PIL.Image.Image], params: Dict[str, Any]
) -> Union[torch.Tensor, features.Image, PIL.Image.Image]:
if params["v"] is not None:
inpt = F.erase(inpt, **params)
inpt = F.erase(inpt, **params, inplace=self.inplace)
return inpt
......
......@@ -95,13 +95,14 @@ class LinearTransformation(Transform):
class Normalize(Transform):
_transformed_types = (features.Image, features.is_simple_tensor)
def __init__(self, mean: Sequence[float], std: Sequence[float]):
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
self.std = list(std)
self.inplace = inplace
def _transform(self, inpt: Union[torch.Tensor, features._Feature], params: Dict[str, Any]) -> torch.Tensor:
return F.normalize(inpt, mean=self.mean, std=self.std)
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
def forward(self, *inpts: Any) -> Any:
if has_any(inpts, PIL.Image.Image):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment