"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c375903db58826494d858e02b44d21b42669ff5e"
Unverified Commit b6feccbc authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix RandomPhotometricDistort (#6482)



* fix RandomPhotometricDistort

* Update torchvision/prototype/transforms/_color.py
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent c3ec6948
...@@ -105,9 +105,9 @@ class RandomPhotometricDistort(Transform): ...@@ -105,9 +105,9 @@ class RandomPhotometricDistort(Transform):
return dict( return dict(
zip( zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"], ["brightness", "contrast1", "saturation", "hue", "contrast2"],
torch.rand(6) < self.p, (torch.rand(5) < self.p).tolist(),
), ),
contrast_before=torch.rand(()) < 0.5, contrast_before=bool(torch.rand(()) < 0.5),
channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None, channel_permutation=torch.randperm(num_channels) if torch.rand(()) < self.p else None,
) )
...@@ -147,7 +147,7 @@ class RandomPhotometricDistort(Transform): ...@@ -147,7 +147,7 @@ class RandomPhotometricDistort(Transform):
inpt = F.adjust_contrast( inpt = F.adjust_contrast(
inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1]) inpt, contrast_factor=ColorJitter._generate_value(self.contrast[0], self.contrast[1])
) )
if params["channel_permutation"]: if params["channel_permutation"] is not None:
inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) inpt = self._permute_channels(inpt, permutation=params["channel_permutation"])
return inpt return inpt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment