Unverified Commit d9e6d60f authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix RandAugment and TrivialAugment bugs (#4370)

* Fix RA bugs.

* Fix bins for TA.
parent 446b2ca5
......@@ -263,7 +263,7 @@ class RandAugment(torch.nn.Module):
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
......@@ -276,6 +276,7 @@ class RandAugment(torch.nn.Module):
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
......@@ -289,7 +290,6 @@ class RandAugment(torch.nn.Module):
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
"Invert": (torch.tensor(0.0), False),
}
def forward(self, img: Tensor) -> Tensor:
......@@ -345,7 +345,7 @@ class TrivialAugmentWide(torch.nn.Module):
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_magnitude_bins = num_magnitude_bins
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment