Unverified Commit 76ebe92d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[BC-breaking] ColorJitter gets its random params by calling get_params() (#3001)

* ColorJitter gets its random params by calling get_params().

* Update arguments.

* Styles.

* Add description for Nones.

* Chainging Nones to optional.
parent afd9d4d8
...@@ -1051,38 +1051,35 @@ class ColorJitter(torch.nn.Module): ...@@ -1051,38 +1051,35 @@ class ColorJitter(torch.nn.Module):
return value return value
@staticmethod @staticmethod
@torch.jit.unused def get_params(brightness: Optional[List[float]],
def get_params(brightness, contrast, saturation, hue): contrast: Optional[List[float]],
"""Get a randomized transform to be applied on image. saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Arguments are same as that of __init__. Args:
brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
Returns: Returns:
Transform which randomly adjusts brightness, contrast and tuple: The parameters used to apply the randomized transform
saturation in a random order. along with their random order.
""" """
transforms = [] fn_idx = torch.randperm(4)
if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
random.shuffle(transforms) b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
transform = Compose(transforms) c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
return transform return fn_idx, b, c, s, h
def forward(self, img): def forward(self, img):
""" """
...@@ -1092,26 +1089,17 @@ class ColorJitter(torch.nn.Module): ...@@ -1092,26 +1089,17 @@ class ColorJitter(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Color jittered image. PIL Image or Tensor: Color jittered image.
""" """
fn_idx = torch.randperm(4) fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
for fn_id in fn_idx: for fn_id in fn_idx:
if fn_id == 0 and self.brightness is not None: if fn_id == 0 and brightness_factor is not None:
brightness = self.brightness
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = F.adjust_brightness(img, brightness_factor) img = F.adjust_brightness(img, brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
if fn_id == 1 and self.contrast is not None:
contrast = self.contrast
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = F.adjust_contrast(img, contrast_factor) img = F.adjust_contrast(img, contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
if fn_id == 2 and self.saturation is not None:
saturation = self.saturation
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = F.adjust_saturation(img, saturation_factor) img = F.adjust_saturation(img, saturation_factor)
elif fn_id == 3 and hue_factor is not None:
if fn_id == 3 and self.hue is not None:
hue = self.hue
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = F.adjust_hue(img, hue_factor) img = F.adjust_hue(img, hue_factor)
return img return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment