Unverified Commit c35e0543 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Adjust solarize threshold on input type (#6874)

* Adjust solarize threshold on input type

* Handle PIL images

* minor refactor

* Fix linter
parent 226a56b9
......@@ -137,7 +137,8 @@ class _AutoAugmentBase(Transform):
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
return F.solarize(image, threshold=magnitude)
bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0
return F.solarize(image, threshold=bound * magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
elif transform_id == "Equalize":
......@@ -169,7 +170,7 @@ class AutoAugment(_AutoAugmentBase):
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
"Invert": (lambda num_bins, height, width: None, False),
......@@ -324,7 +325,7 @@ class RandAugment(_AutoAugmentBase):
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
......@@ -378,7 +379,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
lambda num_bins, height, width: (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
......@@ -423,7 +424,7 @@ class AugMix(_AutoAugmentBase):
lambda num_bins, height, width: (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4))).round().int(),
False,
),
"Solarize": (lambda num_bins, height, width: torch.linspace(255.0, 0.0, num_bins), False),
"Solarize": (lambda num_bins, height, width: torch.linspace(1.0, 0.0, num_bins), False),
"AutoContrast": (lambda num_bins, height, width: None, False),
"Equalize": (lambda num_bins, height, width: None, False),
}
......@@ -505,13 +506,7 @@ class AugMix(_AutoAugmentBase):
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
mix.add_(
# The multiplication below could become in-place provided `aug is not batch and aug.is_floating_point()`
# Currently we can't do this because `aug` has to be `unint8` to support ops like `equalize`.
# TODO: change this once all ops in `F` support floats. https://github.com/pytorch/vision/issues/6840
combined_weights[:, i].reshape(batch_dims)
* aug
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
if isinstance(orig_image_or_video, (features.Image, features.Video)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment