Unverified commit 6c0b7f54, authored by Peter Salas, committed by GitHub
Browse files

[Core][VLM] Add precise multi-modal placeholder tracking (#8346)


Signed-off-by: Peter Salas <peter@fixie.ai>
parent d151fde8
...@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int): ...@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
tokenize=False, tokenize=False,
add_generation_prompt=True) add_generation_prompt=True)
llm = LLM(model=model_name, llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
enforce_eager=True,
enable_chunked_prefill=False,
max_model_len=8192,
limit_mm_per_prompt={"audio": audio_count})
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
......
...@@ -869,6 +869,7 @@ def make_test_metadata( ...@@ -869,6 +869,7 @@ def make_test_metadata(
return attn_backend.make_metadata( return attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -914,6 +915,7 @@ def make_test_metadata( ...@@ -914,6 +915,7 @@ def make_test_metadata(
return attn_backend.make_metadata( return attn_backend.make_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping, slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
......
...@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type ...@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding from transformers import AutoModel, AutoTokenizer, BatchEncoding
from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
...@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int] ...@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
HF_PLACEHOLDER = "<|audio|>" HF_PLACEHOLDER = "<|audio|>"
# Engine keyword overrides used by the chunked-prefill variants of the tests
# and of the ``server`` fixture in this file.
CHUNKED_PREFILL_KWARGS = dict(
    enable_chunked_prefill=True,
    max_num_seqs=2,
    # Use a very small limit to exercise chunked prefill.
    max_num_batched_tokens=16,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def audio_assets(): def audio_assets():
...@@ -30,6 +39,26 @@ def audio(request): ...@@ -30,6 +39,26 @@ def audio(request):
return AudioAsset(request.param) return AudioAsset(request.param)
@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
def server(request, audio_assets):
    """Yield a running ``RemoteOpenAIServer`` for ``MODEL_NAME``.

    The fixture is parametrized over ``{}`` and ``CHUNKED_PREFILL_KWARGS``.
    Every entry of the active parameter dict is appended to the base
    arguments as a ``--key=value`` flag, with underscores in the key
    replaced by dashes.
    """
    cli_args = [
        "--dtype=bfloat16",
        "--max-model-len=4096",
        "--enforce-eager",
        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
    ]
    # Turn keyword-style names such as ``max_num_seqs`` into
    # ``--max-num-seqs=<value>`` command-line flags.
    for option, setting in request.param.items():
        flag = option.replace("_", "-")
        cli_args.append(f"--{flag}={setting}")

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server):
    """Yield the async client obtained from the ``server`` fixture.

    The client is entered as an async context manager, so it is exited
    once the consuming test finishes.
    """
    async with server.get_async_client() as api_client:
        yield api_client
def _get_prompt(audio_count, question, placeholder): def _get_prompt(audio_count, question, placeholder):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
placeholder = f"{placeholder}\n" * audio_count placeholder = f"{placeholder}\n" * audio_count
...@@ -68,8 +97,7 @@ def run_test( ...@@ -68,8 +97,7 @@ def run_test(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, **kwargs,
distributed_executor_backend: Optional[str] = None,
): ):
"""Inference result should be the same between hf and vllm.""" """Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
...@@ -79,11 +107,8 @@ def run_test( ...@@ -79,11 +107,8 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method). # will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model, with vllm_runner(model, dtype=dtype, enforce_eager=True,
dtype=dtype, **kwargs) as vllm_model:
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_audio = [ vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt], vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens, max_tokens,
...@@ -135,18 +160,16 @@ def run_multi_audio_test( ...@@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
tensor_parallel_size: int, **kwargs,
distributed_executor_backend: Optional[str] = None,
): ):
with vllm_runner(model, with vllm_runner(model,
dtype=dtype, dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True, enforce_eager=True,
limit_mm_per_prompt={ limit_mm_per_prompt={
"audio": "audio":
max((len(audio) for _, audio in prompts_and_audios)) max((len(audio) for _, audio in prompts_and_audios))
}) as vllm_model: },
**kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
[prompt for prompt, _ in prompts_and_audios], [prompt for prompt, _ in prompts_and_audios],
max_tokens, max_tokens,
...@@ -162,8 +185,9 @@ def run_multi_audio_test( ...@@ -162,8 +185,9 @@ def run_multi_audio_test(
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int) -> None: num_logprobs: int, vllm_kwargs: dict) -> None:
vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
...@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, ...@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype=dtype, dtype=dtype,
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, **vllm_kwargs,
) )
...@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, ...@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int, max_tokens: int, num_logprobs: int,
num_logprobs: int) -> None: vllm_kwargs: dict) -> None:
vllm_prompt = _get_prompt(len(audio_assets), vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.", "Describe each of the audios above.",
...@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, ...@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype=dtype, dtype=dtype,
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
tensor_parallel_size=1, **vllm_kwargs,
) )
@pytest.mark.asyncio
async def test_online_inference(client, audio_assets):
    """Exercises online inference with/without chunked prefill enabled.

    Sends a single user message holding one ``audio_url`` part per audio
    asset followed by a text question, then checks the shape of the reply.
    """
    # One audio part per asset, in asset order.
    audio_parts = [
        {"type": "audio_url", "audio_url": {"url": asset.url}}
        for asset in audio_assets
    ]
    question = {
        "type": "text",
        "text": f"What's happening in these {len(audio_assets)} audio clips?",
    }
    messages = [{"role": "user", "content": audio_parts + [question]}]

    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )

    choices = chat_completion.choices
    assert len(choices) == 1
    # The request caps generation at max_tokens=10; the test expects the
    # completion to stop because of that cap.
    assert choices[0].finish_reason == "length"
...@@ -5,8 +5,8 @@ from unittest.mock import patch ...@@ -5,8 +5,8 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
from vllm.inputs.registry import InputRegistry InputRegistry, token_inputs)
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
...@@ -56,7 +56,7 @@ def use_dummy_data_mock(): ...@@ -56,7 +56,7 @@ def use_dummy_data_mock():
num_crops=DEFAULT_NUM_CROPS): num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData( seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None return DummyData(seq_data, None)
with patch( with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory", "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
...@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): ...@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the # NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq # default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs. # len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling( dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry) ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, ...@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the # NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq # default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs. # len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling( dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry) ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance ### Test overrides for the max token count per multimodal instance
......
...@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model): ...@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer = AutoTokenizer.from_pretrained(model) tokenizer = AutoTokenizer.from_pretrained(model)
test_cases = [ test_cases = [
("<image>", 2, "<image><image>", [32000, 32000]), (
("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]), "<image>",
("<image><image>", [3, 2], "<image><image><image><image><image>", 2,
[32000, 32000, 32000, 32000, 32000]), "<image><image>",
("Image:<image>Image:<image>!", [3, 2], [32000, 32000],
"Image:<image><image><image>Image:<image><image>!", [{ "offset": 0, "length": 2 }],
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), ),
("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]), (
] "<image><image>",
2,
for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: "<image><image><image>",
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( [32000, 32000, 32000],
[{ "offset": 0, "length": 2 }]),
(
"<image><image>",
[3, 2],
"<image><image><image><image><image>",
[32000, 32000, 32000, 32000, 32000],
[{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
),
(
"Image:<image>Image:<image>!",
[3, 2],
"Image:<image><image><image>Image:<image><image>!",
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
[{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
),
(
"<image>",
[3, 2],
"<image><image><image>",
[32000, 32000, 32000],
[{ "offset": 0, "length": 3 }],
),
] # yapf: disable
for (
prompt,
repeat_count,
expected_prompt,
expected_token_ids,
expected_ranges,
) in test_cases:
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer=tokenizer, tokenizer=tokenizer,
prompt=prompt, prompt=prompt,
prompt_token_ids=tokenizer.encode(prompt, prompt_token_ids=tokenizer.encode(prompt,
...@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model): ...@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
) )
assert new_prompt == expected_prompt assert new_prompt == expected_prompt
assert new_token_ids == expected_token_ids assert new_token_ids == expected_token_ids
assert ranges == expected_ranges
...@@ -73,6 +73,7 @@ def test_model_runner_input(): ...@@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens=2, num_prefill_tokens=2,
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
) )
model_input = ModelInputForGPUWithSamplingMetadata( model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
...@@ -124,6 +125,7 @@ def test_embedding_model_runner_input(): ...@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens=2, num_prefill_tokens=2,
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
) )
model_input = ModelInputForGPUWithPoolingMetadata( model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
...@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input(): ...@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens=2, num_prefill_tokens=2,
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None,
) )
frozen_model_input = ModelInputForGPUWithSamplingMetadata( frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
......
...@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, ...@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch import torch
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner_base import (ModelRunnerBase, from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase, ModelRunnerInputBase,
...@@ -108,6 +110,15 @@ class AttentionMetadata: ...@@ -108,6 +110,15 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively. # in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]
@property @property
@abstractmethod @abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]: def prefill_metadata(self) -> Optional["AttentionMetadata"]:
......
...@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, max_query_len=None,
......
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
...@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, ...@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import (async_tensor_h2d, direct_register_custom_op, from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
make_tensor_with_pad) make_tensor_with_pad)
...@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
...@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder( ...@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder(
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.block_tables: List[List[int]] = [] self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
...@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder( ...@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder(
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
...@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder( ...@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder(
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder( ...@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder(
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len, max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len, max_decode_query_len=max_decode_query_len,
......
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
try: try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
...@@ -215,6 +218,7 @@ class FlashInferState(AttentionState): ...@@ -215,6 +218,7 @@ class FlashInferState(AttentionState):
attn_metadata = self.runner.attn_backend.make_metadata( attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0, num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
max_prefill_seq_len=0, max_prefill_seq_len=0,
...@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.block_tables: List[List[int]] = [] self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
...@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
inter_data.curr_sliding_window_blocks): inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
...@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
decode_query_len=decode_query_len, decode_query_len=decode_query_len,
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
......
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
import torch import torch
...@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder) AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import ModelInputForGPUBuilder
...@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0, max_decode_query_len=0,
...@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
...@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder(
self.prefill_seq_lens: List[int] = [] self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
...@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder(
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
...@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder(
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder(
return PlaceholderAttentionMetadata( return PlaceholderAttentionMetadata(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
......
...@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, max_query_len=None,
......
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
...@@ -7,6 +8,7 @@ import torch ...@@ -7,6 +8,7 @@ import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.context_lens: List[int] = [] self.context_lens: List[int] = []
self.block_tables: List[List[int]] = [] self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = [] self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0 self.num_prefills = 0
self.num_prefill_tokens = 0 self.num_prefill_tokens = 0
self.num_decode_tokens = 0 self.num_decode_tokens = 0
...@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
inter_data.curr_sliding_window_blocks): inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len) self.context_lens.append(context_len)
if is_prompt: if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1 self.num_prefills += 1
self.num_prefill_tokens += token_len self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len) self.prefill_seq_lens.append(seq_len)
...@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
torch.cumsum(seq_lens_tensor, torch.cumsum(seq_lens_tensor,
dim=0, dim=0,
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
...@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
return self._metadata_cls( # type: ignore return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState): ...@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size], seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1, max_query_len=1,
......
...@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
......
...@@ -1308,6 +1308,8 @@ class Scheduler: ...@@ -1308,6 +1308,8 @@ class Scheduler:
# `multi_modal_data` will be None. # `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None, if scheduler_outputs.num_prefill_groups > 0 else None,
multi_modal_placeholders=seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs, mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request, prompt_adapter_request=seq_group.prompt_adapter_request,
) )
......
...@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ...@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
token_inputs, zip_enc_dec_prompts) token_inputs, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import DummyData, InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
""" """
...@@ -29,6 +29,7 @@ __all__ = [ ...@@ -29,6 +29,7 @@ __all__ = [
"to_enc_dec_tuple_list", "to_enc_dec_tuple_list",
"zip_enc_dec_prompts", "zip_enc_dec_prompts",
"INPUT_REGISTRY", "INPUT_REGISTRY",
"DummyData",
"InputContext", "InputContext",
"InputRegistry", "InputRegistry",
] ]
......
...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, ...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
from typing_extensions import NotRequired, TypedDict, TypeVar from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
class TextPrompt(TypedDict): class TextPrompt(TypedDict):
...@@ -136,6 +136,12 @@ class TokenInputs(TypedDict): ...@@ -136,6 +136,12 @@ class TokenInputs(TypedDict):
if the model supports it. if the model supports it.
""" """
multi_modal_placeholders: NotRequired[
Optional["MultiModalPlaceholderDict"]]
"""
Placeholder ranges for the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]]
""" """
Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
...@@ -149,6 +155,7 @@ def token_inputs( ...@@ -149,6 +155,7 @@ def token_inputs(
prompt_token_ids: List[int], prompt_token_ids: List[int],
prompt: Optional[str] = None, prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs: ) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values.""" """Construct :class:`TokenInputs` from optional values."""
...@@ -158,6 +165,8 @@ def token_inputs( ...@@ -158,6 +165,8 @@ def token_inputs(
inputs["prompt"] = prompt inputs["prompt"] = prompt
if multi_modal_data is not None: if multi_modal_data is not None:
inputs["multi_modal_data"] = multi_modal_data inputs["multi_modal_data"] = multi_modal_data
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
inputs["mm_processor_kwargs"] = mm_processor_kwargs inputs["mm_processor_kwargs"] = mm_processor_kwargs
......
import functools import functools
from collections import UserDict from collections import UserDict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Protocol, Tuple, Type) Optional, Protocol, Type)
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs ...@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
MultiModalRegistry)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -63,6 +64,14 @@ class InputContext: ...@@ -63,6 +64,14 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
class DummyData(NamedTuple):
"""Dummy data used for profiling."""
seq_data: "SequenceData"
multi_modal_data: Optional["MultiModalDataDict"] = None
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None
class DummyDataFactory(Protocol): class DummyDataFactory(Protocol):
def __call__( def __call__(
...@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol): ...@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
**mm_processor_kwargs: Any, **mm_processor_kwargs: Any,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: ) -> DummyData:
""" """
Create dummy data to be inputted into the model. Create dummy data to be inputted into the model.
...@@ -123,7 +132,7 @@ class InputRegistry: ...@@ -123,7 +132,7 @@ class InputRegistry:
ctx: InputContext, ctx: InputContext,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: ) -> DummyData:
""" """
The default dummy data factory represents the longest possible text The default dummy data factory represents the longest possible text
that can be inputted to the model. that can be inputted to the model.
...@@ -134,10 +143,7 @@ class InputRegistry: ...@@ -134,10 +143,7 @@ class InputRegistry:
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))
dummy_multi_modal_data = None
return dummy_seq_data, dummy_multi_modal_data
def register_dummy_data(self, factory: DummyDataFactory): def register_dummy_data(self, factory: DummyDataFactory):
""" """
...@@ -195,7 +201,7 @@ class InputRegistry: ...@@ -195,7 +201,7 @@ class InputRegistry:
seq_len: int, seq_len: int,
mm_registry: "MultiModalRegistry", mm_registry: "MultiModalRegistry",
is_encoder_data: bool = False, is_encoder_data: bool = False,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: ) -> DummyData:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
...@@ -220,12 +226,12 @@ class InputRegistry: ...@@ -220,12 +226,12 @@ class InputRegistry:
mm_processor_kwargs = get_allowed_kwarg_only_overrides( mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs) dummy_factory, overrides=model_config.mm_processor_kwargs)
seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, dummy_data = dummy_factory(InputContext(model_config), seq_len,
_MultiModalCounts(mm_counts), _MultiModalCounts(mm_counts),
**mm_processor_kwargs) **mm_processor_kwargs)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids num_tokens = dummy_data.seq_data.prompt_token_ids
if len(num_tokens) < seq_len: if len(num_tokens) < seq_len:
if is_encoder_data: if is_encoder_data:
print_warning_once( print_warning_once(
...@@ -235,15 +241,15 @@ class InputRegistry: ...@@ -235,15 +241,15 @@ class InputRegistry:
raise AssertionError( raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, " f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.") f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None: if dummy_data.multi_modal_data is not None:
for k, v in mm_data.items(): for k, v in dummy_data.multi_modal_data.items():
num_items = len(v) if isinstance(v, list) else 1 num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k] num_expected = mm_counts[k]
assert num_items >= num_expected, ( assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances " f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.") f"for profiling, but found {num_items} instances instead.")
return seq_data, mm_data return dummy_data
def _default_input_processor( def _default_input_processor(
self, self,
......
...@@ -98,6 +98,11 @@ def input_processor_for_blip( ...@@ -98,6 +98,11 @@ def input_processor_for_blip(
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
return inputs return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None: if image_feature_size_override is None:
...@@ -105,7 +110,7 @@ def input_processor_for_blip( ...@@ -105,7 +110,7 @@ def input_processor_for_blip(
else: else:
image_feature_size = image_feature_size_override image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer, tokenizer,
inputs.get("prompt"), inputs.get("prompt"),
inputs["prompt_token_ids"], inputs["prompt_token_ids"],
...@@ -116,7 +121,8 @@ def input_processor_for_blip( ...@@ -116,7 +121,8 @@ def input_processor_for_blip(
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids, return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt, prompt=new_prompt,
multi_modal_data=multi_modal_data) multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
......
...@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, ...@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip, from .blip import (BlipVisionModel, dummy_image_for_blip,
...@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2( ...@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
return SequenceData.from_prompt_token_counts( return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images), (image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images), (0, seq_len - image_feature_size * num_images),
) ), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_data_for_blip2(ctx: InputContext, seq_len: int, def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
...@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, ...@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
num_images = mm_counts["image"] num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_blip2( seq_data, ranges = dummy_seq_data_for_blip2(
hf_config, hf_config,
seq_len, seq_len,
num_images, num_images,
...@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int, ...@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
if isinstance(vision_config, Blip2VisionConfig): if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config, num_images) mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment