Unverified Commit 76ebe92d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[BC-breaking] ColorJitter gets its random params by calling get_params() (#3001)

* ColorJitter gets its random params by calling get_params().

* Update arguments.

* Styles.

* Add description for Nones.

* Chainging Nones to optional.
parent afd9d4d8
......@@ -1051,38 +1051,35 @@ class ColorJitter(torch.nn.Module):
return value
@staticmethod
@torch.jit.unused
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
def get_params(brightness: Optional[List[float]],
contrast: Optional[List[float]],
saturation: Optional[List[float]],
hue: Optional[List[float]]
) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]:
"""Get the parameters for the randomized transform to be applied on image.
Arguments are same as that of __init__.
Args:
brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen
uniformly. Pass None to turn off the transformation.
contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen
uniformly. Pass None to turn off the transformation.
saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen
uniformly. Pass None to turn off the transformation.
hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly.
Pass None to turn off the transformation.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
tuple: The parameters used to apply the randomized transform
along with their random order.
"""
transforms = []
if brightness is not None:
brightness_factor = random.uniform(brightness[0], brightness[1])
transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
if contrast is not None:
contrast_factor = random.uniform(contrast[0], contrast[1])
transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
if saturation is not None:
saturation_factor = random.uniform(saturation[0], saturation[1])
transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
if hue is not None:
hue_factor = random.uniform(hue[0], hue[1])
transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
fn_idx = torch.randperm(4)
random.shuffle(transforms)
transform = Compose(transforms)
b = None if brightness is None else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
c = None if contrast is None else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
s = None if saturation is None else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
h = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
return transform
return fn_idx, b, c, s, h
def forward(self, img):
"""
......@@ -1092,26 +1089,17 @@ class ColorJitter(torch.nn.Module):
Returns:
PIL Image or Tensor: Color jittered image.
"""
fn_idx = torch.randperm(4)
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
for fn_id in fn_idx:
if fn_id == 0 and self.brightness is not None:
brightness = self.brightness
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
if fn_id == 0 and brightness_factor is not None:
img = F.adjust_brightness(img, brightness_factor)
if fn_id == 1 and self.contrast is not None:
contrast = self.contrast
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
elif fn_id == 1 and contrast_factor is not None:
img = F.adjust_contrast(img, contrast_factor)
if fn_id == 2 and self.saturation is not None:
saturation = self.saturation
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
elif fn_id == 2 and saturation_factor is not None:
img = F.adjust_saturation(img, saturation_factor)
if fn_id == 3 and self.hue is not None:
hue = self.hue
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
elif fn_id == 3 and hue_factor is not None:
img = F.adjust_hue(img, hue_factor)
return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment