Unverified Commit d27e4c18 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Move rescale dtype recasting to match torchvision ToTensor (#25229)

Move dtype recasting to match torchvision ToTensor
parent 3170af71
......@@ -110,11 +110,12 @@ def rescale(
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
image = image.astype(dtype)
rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment