"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "067c4a310dd36d0472d4a587145e94d20bf64964"
Unverified Commit d27e4c18 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Move rescale dtype recasting to match torchvision ToTensor (#25229)

Move dtype recasting to match torchvision ToTensor
parent 3170af71
...@@ -110,11 +110,12 @@ def rescale( ...@@ -110,11 +110,12 @@ def rescale(
if not isinstance(image, np.ndarray): if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
image = image.astype(dtype)
rescaled_image = image * scale rescaled_image = image * scale
if data_format is not None: if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format) rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image return rescaled_image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment