Unverified Commit 547b5582 authored by Daemyung Jang's avatar Daemyung Jang Committed by GitHub
Browse files

Consider inheritance in type checking for tensors (#31378)

* Consider inheritance in type checking for tensors

Add an additional check to bypass type assertion when both tensors are
torch.Tensor instances.

* Fix the quality issue
parent 3b5fa14f
......@@ -131,9 +131,10 @@ def nested_concat(tensors, new_tensors, padding_index=-100):
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
nested list/tuples/dict of tensors.
"""
assert type(tensors) == type(
new_tensors
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
assert (
type(tensors) == type(new_tensors)
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
elif isinstance(tensors, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment