Unverified commit 74d5543e, authored by Peter Salas and committed by GitHub
Browse files

[VLM][Core] Fix exceptions on ragged NestedTensors (#7974)

parent a7f65c2b
...@@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists(): ...@@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
result, result,
{"image": torch.stack([torch.stack([a, b]), {"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])}) torch.stack([c, d])])})
def test_multimodal_input_batch_mixed_stacking_depths():
    """Batch inputs whose "image" lists differ in length and tensor shape.

    ``a``, ``b`` and ``c`` differ in their second dimension, so a list
    holding two of them is expected to come back as a plain list, while a
    list holding a single tensor is expected to come back as that tensor
    with an extra leading dimension (``unsqueeze(0)``).
    """
    a = torch.rand([1, 2, 3])
    b = torch.rand([1, 3, 3])
    c = torch.rand([1, 4, 3])

    # First item holds two differently shaped tensors, second item holds one.
    result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
    assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})

    # Same scenario with the item order reversed.
    result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
    assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload) Union, overload)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.func import functional_call from torch.func import functional_call
...@@ -96,12 +95,13 @@ def flatten_bn( ...@@ -96,12 +95,13 @@ def flatten_bn(
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
""" """
Recursively concatenates NestedTensors along any heterogeneously sized Recursively flattens and concatenates NestedTensors on all but the last
dimensions. dimension.
""" """
if isinstance(embeddings, torch.Tensor): if isinstance(embeddings, torch.Tensor):
return embeddings # Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
...@@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, ...@@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
assert isinstance(num_expected_tokens, int) assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings) flattened = _flatten_embeddings(multimodal_embeddings)
*dims, embed_dim = flattened.shape if flattened.shape[0] != num_expected_tokens:
num_multimodal_embeddings = np.prod(dims)
if num_multimodal_embeddings != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings) expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError( raise ValueError(
f"Attempted to assign {expr} = {num_multimodal_embeddings} " f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders") f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) inputs_embeds[mask] = flattened
return inputs_embeds return inputs_embeds
......
...@@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase):
return nested_tensors return nested_tensors
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if is_list_of(stacked, list): if not is_list_of(stacked, torch.Tensor, check="all"):
# Do not stack nested lists # Only tensors (not lists) can be stacked.
return stacked return stacked
tensors_ = cast(List[torch.Tensor], stacked) tensors_ = cast(List[torch.Tensor], stacked)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment