Unverified Commit 490f17d0 authored by Khairul Kabir's avatar Khairul Kabir Committed by GitHub
Browse files

[Multimodal] Fix nested_tensors_equal: add length check for lists and tuple support (#38388)


Signed-off-by: default avatarkhairulkabir1661 <khairulkabir1661@users.noreply.github.com>
Co-authored-by: default avatarkhairulkabir1661 <khairulkabir1661@users.noreply.github.com>
parent 2e984060
......@@ -238,12 +238,29 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return isinstance(a, torch.Tensor) and torch.equal(b, a)
if isinstance(a, list):
return isinstance(b, list) and all(
nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
return (
isinstance(b, list)
and len(a) == len(b)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
)
if isinstance(b, list):
return isinstance(a, list) and all(
nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
return (
isinstance(a, list)
and len(b) == len(a)
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
)
if isinstance(a, tuple):
return (
isinstance(b, tuple)
and len(a) == len(b)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
)
if isinstance(b, tuple):
return (
isinstance(a, tuple)
and len(b) == len(a)
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
)
# Both a and b are scalars
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment