Unverified Commit 18a2e8eb authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

move parameter sampling of RandomPhotometricDistort into _get_params (#7442)

parent b1d16c9c
...@@ -228,19 +228,22 @@ class RandomPhotometricDistort(Transform): ...@@ -228,19 +228,22 @@ class RandomPhotometricDistort(Transform):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs) num_channels, *_ = query_chw(flat_inputs)
return dict( params: Dict[str, Any] = {
zip( key: ColorJitter._generate_value(range[0], range[1]) if torch.rand(1) < self.p else None
["brightness", "contrast1", "saturation", "hue", "contrast2"], for key, range in [
(torch.rand(5) < self.p).tolist(), ("brightness_factor", self.brightness),
), ("contrast_factor", self.contrast),
contrast_before=bool(torch.rand(()) < 0.5), ("saturation_factor", self.saturation),
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, ("hue_factor", self.hue),
) ]
}
params["contrast_before"] = bool(torch.rand(()) < 0.5)
params["channel_permutation"] = torch.randperm(num_channels) if torch.rand(1) < self.p else None
return params
def _permute_channels( def _permute_channels(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor self, inpt: Union[datapoints._ImageType, datapoints._VideoType], permutation: torch.Tensor
) -> Union[datapoints._ImageType, datapoints._VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
orig_inpt = inpt orig_inpt = inpt
if isinstance(orig_inpt, PIL.Image.Image): if isinstance(orig_inpt, PIL.Image.Image):
inpt = F.pil_to_tensor(inpt) inpt = F.pil_to_tensor(inpt)
...@@ -256,24 +259,16 @@ class RandomPhotometricDistort(Transform): ...@@ -256,24 +259,16 @@ class RandomPhotometricDistort(Transform):
def _transform( def _transform(
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["brightness"]: if params["brightness_factor"] is not None:
inpt = F.adjust_brightness( inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
inpt, brightness_factor=ColorJitter._generate_value(self.brightness[0], self.brightness[1]) if params["contrast_factor"] is not None and params["contrast_before"]:
) inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
if params["contrast1"] and params["contrast_before"]: if params["saturation_factor"] is not None:
inpt = F.adjust_contrast( inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) if params["hue_factor"] is not None:
) inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
if params["saturation"]: if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_saturation( inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
inpt, saturation_factor=ColorJitter._generate_value(self.saturation[0], self.saturation[1])
)
if params["hue"]:
inpt = F.adjust_hue(inpt, hue_factor=ColorJitter._generate_value(self.hue[0], self.hue[1]))
if params["contrast2"] and not params["contrast_before"]:
inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
)
if params["channel_permutation"] is not None: if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt return inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment