"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "0dd59e0dda22eabf54fc95ad8050094df239bd39"
Unverified Commit 346f6dd9 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use bitshifts for int to int in convert_dtype (#6978)

parent 51e8dace
...@@ -379,15 +379,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f ...@@ -379,15 +379,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
if num_value_bits_input > num_value_bits_output: if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
else: else:
# The bitshift kernel is not vectorized return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
# https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
# This results in the multiplication actually being faster.
# TODO: If the bitshift kernel is optimized in core, replace the computation below with
# `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
max_value_input = float(_FT._max_value(dtype))
max_value_output = float(_FT._max_value(image.dtype))
factor = int((max_value_input + 1) // (max_value_output + 1))
return image.to(dtype).mul_(factor)
# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is # We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment