Unverified Commit d805aeae authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Fixed issues in elastic transform (#7257)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 3080082d
......@@ -1538,10 +1538,21 @@ def elastic_image_tensor(
device = image.device
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
# Patch: elastic transform should support (cpu,f16) input
is_cpu_half = device.type == "cpu" and dtype == torch.float16
if is_cpu_half:
image = image.to(torch.float32)
dtype = torch.float32
# We are aware that if input image dtype is uint8 and displacement is float64 then
# displacement will be casted to float32 and all computations will be done with float32
# We can fix this later if needed
expected_shape = (1,) + shape[-2:] + (2,)
if expected_shape != displacement.shape:
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
......@@ -1561,6 +1572,9 @@ def elastic_image_tensor(
if needs_unsquash:
output = output.reshape(shape)
if is_cpu_half:
output = output.to(torch.float16)
return output
......@@ -1676,6 +1690,9 @@ def elastic(
if not torch.jit.is_scripting():
_log_api_usage_once(elastic)
if not isinstance(displacement, torch.Tensor):
raise TypeError("Argument displacement should be a Tensor")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment