Unverified commit 036ca94c, authored by Travis Johnson and committed by GitHub
Browse files

[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense (#12347)


Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Wallas Santos <wallashss@ibm.com>
Co-authored-by: Wallas Santos <wallashss@ibm.com>
parent ef001d98
from typing import List, Optional, Tuple, Type, overload from typing import List, Optional, Tuple, Type, overload
import pytest import pytest
import torch
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding) BatchEncoding)
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
MllamaForConditionalGeneration)
from vllm.multimodal.image import rescale_image_size from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
...@@ -33,6 +37,29 @@ models = [ ...@@ -33,6 +37,29 @@ models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision-Instruct",
] ]
# Indices for inputs.
# Each value is the string form of a small integer. Besides being used as keys
# into prompt_data, test_get_full_text_row_masked_out_mask converts them with
# int() to build encoder_seq_lens, so TEXT_ONLY has to remain '0' (an encoder
# length of zero) for that test to keep its meaning.
TEXT_ONLY = '0'
IMAGE_AT_BEG = '1'
IMAGE_AT_MIDDLE = '2'
TWO_IMAGES = '3'

# Input tokenized: maps each input index above to a list of token ids.
# The comment above every entry gives the text the ids are meant to encode;
# MLLAMA_IMAGE_TOKEN_ID marks the position of an <|image|> placeholder.
prompt_data = {
    # Tell me a story
    TEXT_ONLY: [41551, 757, 264, 3446],
    # <|image|> What's the content of this image
    IMAGE_AT_BEG:
    [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220],
    # Hello <|image|>What' the content of this image
    IMAGE_AT_MIDDLE:
    [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217],
    #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
    TWO_IMAGES: [
        MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30,
        MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30
    ]
}
def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]], Optional[SampleLogprobs]],
...@@ -365,3 +392,184 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, ...@@ -365,3 +392,184 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                    num_logprobs, attn_backend: _Backend) -> None:
    """Regression checks for https://github.com/vllm-project/vllm/issues/10648.

    Three scenarios are run against one model instance:

    1. A prompt with more ``<|image|>`` tags than supplied images must raise
       ``ValueError``.
    2. A batch mixing a text-only request with an image request must
       generate without raising.
    3. The same kind of batch in the opposite order must also generate
       without raising.

    The generated outputs themselves are not inspected.
    """
    stop_sign = image_assets[0].pil_image

    # Engine settings shared by every scenario below.
    engine_options = {
        "dtype": dtype,
        "max_model_len": 4096,
        "max_num_seqs": 2,
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        "limit_mm_per_prompt": {
            "image": _LIMIT_IMAGE_PER_PROMPT
        },
    }

    with global_force_attn_backend_context_manager(attn_backend):
        with vllm_runner(model, **engine_options) as vllm_model:
            # Number of image tags is greater than the number of images
            # provided
            too_many_tags = "<|begin_of_text|><|image|><|image|> Compare the two images"  # noqa: E501
            with pytest.raises(ValueError):
                vllm_model.generate_greedy_logprobs([too_many_tags],
                                                    max_tokens,
                                                    num_logprobs,
                                                    images=[stop_sign])

            mixed_batches = [
                # Batch of a text-only and image request that requires
                # cross-attention
                (
                    [
                        "What is the capital of spain?",
                        "Text before the image...<|image|>What is in the image?",  # noqa: E501
                    ],
                    [None, [stop_sign]],
                ),
                # Test the reverse order too for good measure
                (
                    [
                        "<|begin_of_text|>Text before the image...<|image|>What is in the image?",  # noqa: E501
                        "<|begin_of_text|>Hello!",
                    ],
                    [[stop_sign], None],
                ),
            ]
            for batch_prompts, batch_images in mixed_batches:
                vllm_model.generate_greedy_logprobs(batch_prompts,
                                                    max_tokens,
                                                    num_logprobs,
                                                    images=batch_images)
@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices_and_output",
    # inputs, (expected cross_attention_mask shape, kv_range_for_decode)
    [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)),
     ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)),
     ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])),
     ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])),
     ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
      ((23, 24), [[0, 6], [6, 12]])),
     ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])),
     ([TWO_IMAGES], ((18, 12), [[6, 12]])),
     ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))])
def test_get_cross_attention_mask(input_indices_and_output) -> None:
    """Check the result of ``get_cross_attention_mask`` for batched prompts.

    Args:
        input_indices_and_output: A pair ``(input_indices, expected)``.
            ``input_indices`` lists ``prompt_data`` keys whose token ids are
            concatenated into one flat input. ``expected`` is
            ``(mask_shape, kv_range_for_decode)``; when ``mask_shape`` is
            ``None`` the returned mask must be ``None`` as well, otherwise
            only the shape of the returned mask is compared.
    """
    input_indices, expected_output = input_indices_and_output
    expected_cross_attention_mask, expected_kv_range_for_decode = \
        expected_output

    sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
    # One [2, 2] entry per sequence that is not TEXT_ONLY; text-only
    # sequences contribute no entry at all, so num_tiles can be shorter than
    # the list of sequences.
    num_tiles = [[2, 2] for i in input_indices if i != TEXT_ONLY]
    input_ids = torch.cat(sequences)
    seq_lens = [len(s) for s in sequences]

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    # The method is called through the class, with this empty dict as its
    # first positional argument (presumably in place of `self` — confirm
    # against the model definition).
    dummy: dict[str, str] = {}

    cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
        .get_cross_attention_mask(dummy,
                                  input_ids,
                                  attn_data,
                                  num_tiles=num_tiles,
                                  num_tokens_per_tile=3,
                                  dtype=torch.bfloat16)

    assert kv_range_for_decode == expected_kv_range_for_decode
    if expected_cross_attention_mask is not None:
        assert cross_attention_mask is not None
        assert cross_attention_mask.shape == expected_cross_attention_mask
    else:
        assert cross_attention_mask is None
@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices",
    [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE],
     [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
     [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]])
def test_get_full_text_row_masked_out_mask(input_indices) -> None:
    """Check ``get_full_text_row_masked_out_mask`` token by token.

    Args:
        input_indices: ``prompt_data`` keys making up the batch. Each key is
            also converted with ``int()`` to serve as that sequence's encoder
            length, which makes the encoder length 0 for ``TEXT_ONLY`` and
            non-zero for every other key.

    The flattened mask must have one entry per prompt token. Entries that
    belong to a ``TEXT_ONLY`` sequence must compare equal to ``False`` and
    all other entries must compare equal to ``True``.
    """
    # len() of the token-id list equals the length of the 1-D tensor that
    # would be built from it, so the lists are measured directly.
    seq_lens = [len(prompt_data[key]) for key in input_indices]
    total_tokens = sum(seq_lens)

    # TEXT_ONLY is zero, so it will be masked out,
    # other instances should not be.
    encoder_seq_lens = list(map(int, input_indices))

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        encoder_seq_lens=encoder_seq_lens,
        num_prefill_tokens=total_tokens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    placeholder: dict[str, str] = {}
    raw_mask = MllamaForConditionalGeneration.get_full_text_row_masked_out_mask(
        placeholder, attn_data, torch.get_default_device())
    mask_values = raw_mask.squeeze().tolist()

    assert len(mask_values) == total_tokens

    seq_offset = 0
    for key, seq_len in zip(input_indices, seq_lens):
        expected_flag = key != TEXT_ONLY
        for idx in range(seq_offset, seq_offset + seq_len):
            assert mask_values[idx] == expected_flag, \
                f"full_text_row_masked_out_mask[{idx}] must be " \
                f"'{expected_flag}' "
        seq_offset += seq_len
...@@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense( ...@@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense(
total_length = sum(lengths) total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles]) total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
# A list of ranges, range[i] = [start, end] means # A list of ranges, range[i] = [start, end] means that the i-th image will
# if the i-th sample has N tiles in total, the tiles[start, end] # use tiles[start, end] for cross-attention decoding.
# will be used for cross-attention decoding.
tile_range_for_decode = [] tile_range_for_decode = []
seq_start = 0 seq_start = 0
tile_start = 0 tile_start = 0
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
# sparse_mask has an [] entry for each sequence that does not have images,
# but num_tiles does not have these entries...
num_tiles_idx = 0
for masks, length in zip(sparse_mask, lengths):
if len(masks) == 0:
# Text only
continue
tiles = num_tiles[num_tiles_idx]
num_tiles_idx += 1
ts, td = -1, 0 ts, td = -1, 0
for mask, tile in zip(masks, tiles): for mask, tile in zip(masks, tiles):
if len(mask) != 2: if len(mask) != 2:
...@@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense( ...@@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense(
assert td != 0 assert td != 0
tile_range_for_decode.append((ts, ts + td)) tile_range_for_decode.append((ts, ts + td))
seq_start += length seq_start += length
assert num_tiles_idx == len(num_tiles)
return dense_mask, tile_range_for_decode return dense_mask, tile_range_for_decode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment