Unverified Commit 18a2e8eb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

move parameter sampling of RandomPhotometricDistort into _get_params (#7442)

parent b1d16c9c
......@@ -228,19 +228,22 @@ class RandomPhotometricDistort(Transform):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"],
(torch.rand(5) < self.p).tolist(),
),
contrast_before=bool(torch.rand(()) < 0.5),
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
)
params: Dict[str, Any] = {
key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
for key, range in [
("brightness_factor", self.brightness),
("contrast_factor", self.contrast),
("saturation_factor", self.saturation),
("hue_factor", self.hue),
]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params
def _permute_channels(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]:
orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt)
......@@ -256,24 +259,16 @@ class RandomPhotometricDistort(Transform):
def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["brightness"]:
inpt = F.adjust_brightness(
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1])
)
if params["contrast1"] and params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["saturation"]:
inpt = F.adjust_saturation(
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
)
if params["hue"]:
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
if params["contrast2"] and not params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["brightness_factor"] is not None:
inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
if params["contrast_factor"] is not None and params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["saturation_factor"] is not None:
inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
if params["hue_factor"] is not None:
inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment