"vscode:/vscode.git/clone" did not exist on "fad970aa810d78fadfa032db105b62a6572f99aa"
Unverified Commit 9dedc445 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Fix] Fix the cast_tensor_type (#1639)

* [Fix] Fix the cast_tensor_type when the type of inputs is not the same as src_type

* Add note

* improve comments
parent 580e374e
...@@ -24,6 +24,15 @@ except ImportError: ...@@ -24,6 +24,15 @@ except ImportError:
def cast_tensor_type(inputs, src_type, dst_type): def cast_tensor_type(inputs, src_type, dst_type):
"""Recursively convert Tensor in inputs from src_type to dst_type. """Recursively convert Tensor in inputs from src_type to dst_type.
Note:
In v1.4.4 and later, ``cast_tersor_type`` will only convert the
torch.Tensor which is consistent with ``src_type`` to the ``dst_type``.
Before v1.4.4, it ignores the ``src_type`` argument, leading to some
potential problems. For example,
``cast_tensor_type(inputs, torch.float, torch.half)`` will convert all
tensors in inputs to ``torch.half`` including those originally in
``torch.Int`` or other types, which is not expected.
Args: Args:
inputs: Inputs that to be casted. inputs: Inputs that to be casted.
src_type (torch.dtype): Source type.. src_type (torch.dtype): Source type..
...@@ -35,7 +44,9 @@ def cast_tensor_type(inputs, src_type, dst_type): ...@@ -35,7 +44,9 @@ def cast_tensor_type(inputs, src_type, dst_type):
if isinstance(inputs, nn.Module): if isinstance(inputs, nn.Module):
return inputs return inputs
elif isinstance(inputs, torch.Tensor): elif isinstance(inputs, torch.Tensor):
return inputs.to(dst_type) # we need to ensure that the type of inputs to be casted are the same
# as the argument `src_type`.
return inputs.to(dst_type) if inputs.dtype == src_type else inputs
elif isinstance(inputs, str): elif isinstance(inputs, str):
return inputs return inputs
elif isinstance(inputs, np.ndarray): elif isinstance(inputs, np.ndarray):
......
...@@ -14,6 +14,22 @@ def test_cast_tensor_type(): ...@@ -14,6 +14,22 @@ def test_cast_tensor_type():
assert isinstance(outputs, torch.Tensor) assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == dst_type assert outputs.dtype == dst_type
# convert torch.float to torch.half
inputs = torch.FloatTensor([5.])
src_type = torch.float
dst_type = torch.half
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == dst_type
# skip the conversion when the type of input is not the same as src_type
inputs = torch.IntTensor([5])
src_type = torch.float
dst_type = torch.half
outputs = cast_tensor_type(inputs, src_type, dst_type)
assert isinstance(outputs, torch.Tensor)
assert outputs.dtype == inputs.dtype
inputs = 'tensor' inputs = 'tensor'
src_type = str src_type = str
dst_type = str dst_type = str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment