"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f4fa3beee7f49b80ce7a58f9c8002f43299175c9"
Unverified Commit 9b0da0c3 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

replace tensor division with scalar division and tensor multiplication (#6903)

* replace tensor division with scalar division and tensor multiplication

* fix consistency test tolerances
parent 4508c84e
...@@ -163,6 +163,8 @@ CONSISTENCY_CONFIGS = [ ...@@ -163,6 +163,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(torch.uint8), ArgsKwargs(torch.uint8),
], ],
supports_pil=False, supports_pil=False,
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ToPILImage, prototype_transforms.ToPILImage,
......
...@@ -180,7 +180,7 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor: ...@@ -180,7 +180,7 @@ def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r) hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
h = hr.add_(hg).add_(hb) h = hr.add_(hg).add_(hb)
h = h.div_(6.0).add_(1.0).fmod_(1.0) h = h.mul_(1.0 / 6.0).add_(1.0).fmod_(1.0)
return torch.stack((h, s, maxc), dim=-3) return torch.stack((h, s, maxc), dim=-3)
...@@ -287,7 +287,7 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> ...@@ -287,7 +287,7 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
levels = 1 << bits levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels) return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
else: else:
num_value_bits = _num_value_bits(image.dtype) num_value_bits = _num_value_bits(image.dtype)
if bits >= num_value_bits: if bits >= num_value_bits:
......
...@@ -367,7 +367,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f ...@@ -367,7 +367,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
else: else:
# int to float # int to float
if float_output: if float_output:
return image.to(dtype).div_(_FT._max_value(image.dtype)) return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
# int to int # int to int
num_value_bits_input = _num_value_bits(image.dtype) num_value_bits_input = _num_value_bits(image.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment