Unverified Commit 9dedc445 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Fix the cast_tensor_type (#1639)

* [Fix] Fix the cast_tensor_type when the type of inputs is not the same as src_type

* Add note

* improve comments
parent 580e374e
......@@ -24,6 +24,15 @@ except ImportError:
def cast_tensor_type(inputs, src_type, dst_type):
"""Recursively convert Tensor in inputs from src_type to dst_type.
Note:
In v1.4.4 and later, ``cast_tersor_type`` will only convert the
torch.Tensor which is consistent with ``src_type`` to the ``dst_type``.
Before v1.4.4, it ignores the ``src_type`` argument, leading to some
potential problems. For example,
``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all
tensors in inputs to ``torch.half`` including those originally in
``torch.Int`` or other types, which is not expected.
Args:
inputs: Inputs that to be casted.
src_type (torch.dtype): Source type..
......@@ -35,7 +44,9 @@ def cast_tensor_type(inputs, src_type, dst_type):
if isinstance(inputs, nn.Module):
return inputs
elif isinstance(inputs, torch.Tensor):
return inputs.to(dst_type)
# we need to ensure that the type of inputs to be casted are the same
# as the argument `src_type`.
return inputs.to(dst_type) if inputs.dtype == src_type else inputs
elif isinstance(inputs, str):
return inputs
elif isinstance(inputs, np.ndarray):
......
......@@ -14,6 +14,22 @@ def test_cast_tensor_type():
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == dst_type
# convert torch.float to torch.half
inputs = torch.FloatTensor([5.])
src_type = torch.float
dst_type = torch.half
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == dst_type
# skip the conversion when the type of input is not the same as src_type
inputs = torch.IntTensor([5])
src_type = torch.float
dst_type = torch.half
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == inputs.dtype
inputs = 'tensor'
src_type = str
dst_type = str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment