Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c0b7f54
Unverified
Commit
6c0b7f54
authored
Nov 01, 2024
by
Peter Salas
Committed by
GitHub
Nov 01, 2024
Browse files
[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by:
Peter Salas
<
peter@fixie.ai
>
parent
d151fde8
Changes
53
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
284 additions
and
68 deletions
+284
-68
examples/offline_inference_audio_language.py
examples/offline_inference_audio_language.py
+1
-5
tests/kernels/utils.py
tests/kernels/utils.py
+2
-0
tests/models/decoder_only/audio_language/test_ultravox.py
tests/models/decoder_only/audio_language/test_ultravox.py
+74
-17
tests/multimodal/test_processor_kwargs.py
tests/multimodal/test_processor_kwargs.py
+7
-7
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+45
-12
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+3
-0
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+11
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+20
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+18
-0
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+21
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+3
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+18
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+3
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+2
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+2
-1
vllm/inputs/data.py
vllm/inputs/data.py
+10
-1
vllm/inputs/registry.py
vllm/inputs/registry.py
+23
-17
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+8
-2
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+10
-5
No files found.
examples/offline_inference_audio_language.py
View file @
6c0b7f54
...
@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
...
@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
tokenize
=
False
,
tokenize
=
False
,
add_generation_prompt
=
True
)
add_generation_prompt
=
True
)
llm
=
LLM
(
model
=
model_name
,
llm
=
LLM
(
model
=
model_name
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
enforce_eager
=
True
,
enable_chunked_prefill
=
False
,
max_model_len
=
8192
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
stop_token_ids
=
None
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
...
...
tests/kernels/utils.py
View file @
6c0b7f54
...
@@ -869,6 +869,7 @@ def make_test_metadata(
...
@@ -869,6 +869,7 @@ def make_test_metadata(
return
attn_backend
.
make_metadata
(
return
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -914,6 +915,7 @@ def make_test_metadata(
...
@@ -914,6 +915,7 @@ def make_test_metadata(
return
attn_backend
.
make_metadata
(
return
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
...
tests/models/decoder_only/audio_language/test_ultravox.py
View file @
6c0b7f54
...
@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
...
@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
pytest_asyncio
from
transformers
import
AutoModel
,
AutoTokenizer
,
BatchEncoding
from
transformers
import
AutoModel
,
AutoTokenizer
,
BatchEncoding
from
tests.utils
import
RemoteOpenAIServer
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
...
@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
...
@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER
=
"<|reserved_special_token_0|>"
VLLM_PLACEHOLDER
=
"<|reserved_special_token_0|>"
HF_PLACEHOLDER
=
"<|audio|>"
HF_PLACEHOLDER
=
"<|audio|>"
CHUNKED_PREFILL_KWARGS
=
{
"enable_chunked_prefill"
:
True
,
"max_num_seqs"
:
2
,
# Use a very small limit to exercise chunked prefill.
"max_num_batched_tokens"
:
16
}
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
audio_assets
():
def
audio_assets
():
...
@@ -30,6 +39,26 @@ def audio(request):
...
@@ -30,6 +39,26 @@ def audio(request):
return
AudioAsset
(
request
.
param
)
return
AudioAsset
(
request
.
param
)
@
pytest
.
fixture
(
params
=
({},
CHUNKED_PREFILL_KWARGS
))
def
server
(
request
,
audio_assets
):
args
=
[
"--dtype=bfloat16"
,
"--max-model-len=4096"
,
"--enforce-eager"
,
f
"--limit-mm-per-prompt=audio=
{
len
(
audio_assets
)
}
"
]
+
[
f
"--
{
key
.
replace
(
'_'
,
'-'
)
}
=
{
value
}
"
for
key
,
value
in
request
.
param
.
items
()
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
def
_get_prompt
(
audio_count
,
question
,
placeholder
):
def
_get_prompt
(
audio_count
,
question
,
placeholder
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
placeholder
=
f
"
{
placeholder
}
\n
"
*
audio_count
placeholder
=
f
"
{
placeholder
}
\n
"
*
audio_count
...
@@ -68,8 +97,7 @@ def run_test(
...
@@ -68,8 +97,7 @@ def run_test(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
**
kwargs
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
):
"""Inference result should be the same between hf and vllm."""
"""Inference result should be the same between hf and vllm."""
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
...
@@ -79,11 +107,8 @@ def run_test(
...
@@ -79,11 +107,8 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# will hurt multiprocessing backend with fork method (the default method).
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
True
,
dtype
=
dtype
,
**
kwargs
)
as
vllm_model
:
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_audio
=
[
vllm_outputs_per_audio
=
[
vllm_model
.
generate_greedy_logprobs
([
vllm_prompt
],
vllm_model
.
generate_greedy_logprobs
([
vllm_prompt
],
max_tokens
,
max_tokens
,
...
@@ -135,18 +160,16 @@ def run_multi_audio_test(
...
@@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
**
kwargs
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
):
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
dtype
=
dtype
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"audio"
:
"audio"
:
max
((
len
(
audio
)
for
_
,
audio
in
prompts_and_audios
))
max
((
len
(
audio
)
for
_
,
audio
in
prompts_and_audios
))
})
as
vllm_model
:
},
**
kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
[
prompt
for
prompt
,
_
in
prompts_and_audios
],
[
prompt
for
prompt
,
_
in
prompts_and_audios
],
max_tokens
,
max_tokens
,
...
@@ -162,8 +185,9 @@ def run_multi_audio_test(
...
@@ -162,8 +185,9 @@ def run_multi_audio_test(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[{},
CHUNKED_PREFILL_KWARGS
])
def
test_models
(
hf_runner
,
vllm_runner
,
audio
,
dtype
:
str
,
max_tokens
:
int
,
def
test_models
(
hf_runner
,
vllm_runner
,
audio
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
num_logprobs
:
int
,
vllm_kwargs
:
dict
)
->
None
:
vllm_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
VLLM_PLACEHOLDER
)
vllm_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
VLLM_PLACEHOLDER
)
hf_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
HF_PLACEHOLDER
)
hf_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
HF_PLACEHOLDER
)
...
@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
...
@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
**
vllm_kwargs
,
)
)
...
@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
...
@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[{},
CHUNKED_PREFILL_KWARGS
])
def
test_models_with_multiple_audios
(
vllm_runner
,
audio_assets
,
dtype
:
str
,
def
test_models_with_multiple_audios
(
vllm_runner
,
audio_assets
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
in
t
)
->
None
:
vllm_kwargs
:
dic
t
)
->
None
:
vllm_prompt
=
_get_prompt
(
len
(
audio_assets
),
vllm_prompt
=
_get_prompt
(
len
(
audio_assets
),
"Describe each of the audios above."
,
"Describe each of the audios above."
,
...
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
...
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
**
vllm_kwargs
,
)
)
@
pytest
.
mark
.
asyncio
async
def
test_online_inference
(
client
,
audio_assets
):
"""Exercises online inference with/without chunked prefill enabled."""
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
[{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
audio
.
url
}
}
for
audio
in
audio_assets
],
{
"type"
:
"text"
,
"text"
:
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
},
],
}]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
)
assert
len
(
chat_completion
.
choices
)
==
1
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
tests/multimodal/test_processor_kwargs.py
View file @
6c0b7f54
...
@@ -5,8 +5,8 @@ from unittest.mock import patch
...
@@ -5,8 +5,8 @@ from unittest.mock import patch
import
pytest
import
pytest
import
torch
import
torch
from
vllm.inputs
import
DecoderOnlyInputs
,
InputContext
,
token_inputs
from
vllm.inputs
import
(
DecoderOnlyInputs
,
DummyData
,
InputContext
,
from
vllm.i
nput
s.r
egistry
import
InputRegistry
I
nput
R
egistry
,
token_inputs
)
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
...
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
...
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
num_crops
=
DEFAULT_NUM_CROPS
):
num_crops
=
DEFAULT_NUM_CROPS
):
seq_data
=
SequenceData
(
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
*
num_crops
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
*
num_crops
))
return
seq_data
,
None
return
DummyData
(
seq_data
,
None
)
with
patch
(
with
patch
(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory"
,
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory"
,
...
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
...
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
# len is solely dependent on the value of the mm_processor_kwargs.
seq
_data
,
_
=
dummy_registry
.
dummy_data_for_profiling
(
dummy
_data
=
dummy_registry
.
dummy_data_for_profiling
(
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
assert
len
(
seq_data
.
prompt_token_ids
)
==
expected_seq_count
assert
len
(
dummy_data
.
seq_data
.
prompt_token_ids
)
==
expected_seq_count
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
...
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
# len is solely dependent on the value of the mm_processor_kwargs.
seq
_data
,
_
=
dummy_registry
.
dummy_data_for_profiling
(
dummy
_data
=
dummy_registry
.
dummy_data_for_profiling
(
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
assert
len
(
seq_data
.
prompt_token_ids
)
==
DEFAULT_NUM_CROPS
assert
len
(
dummy_data
.
seq_data
.
prompt_token_ids
)
==
DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
### Test overrides for the max token count per multimodal instance
...
...
tests/multimodal/test_utils.py
View file @
6c0b7f54
...
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
...
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
test_cases
=
[
test_cases
=
[
(
"<image>"
,
2
,
"<image><image>"
,
[
32000
,
32000
]),
(
(
"<image><image>"
,
2
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
"<image>"
,
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
2
,
[
32000
,
32000
,
32000
,
32000
,
32000
]),
"<image><image>"
,
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
[
32000
,
32000
],
"Image:<image><image><image>Image:<image><image>!"
,
[{
"offset"
:
0
,
"length"
:
2
}],
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
]),
),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
(
]
"<image><image>"
,
2
,
for
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
in
test_cases
:
"<image><image><image>"
,
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
[
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
2
}]),
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
[
32000
,
32000
,
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
3
},
{
"offset"
:
3
,
"length"
:
2
}],
),
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
"Image:<image><image><image>Image:<image><image>!"
,
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[{
"offset"
:
2
,
"length"
:
3
},
{
"offset"
:
7
,
"length"
:
2
}],
),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
3
}],
),
]
# yapf: disable
for
(
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
,
expected_ranges
,
)
in
test_cases
:
new_prompt
,
new_token_ids
,
ranges
=
repeat_and_pad_placeholder_tokens
(
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
prompt
=
prompt
,
prompt
=
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
prompt
,
...
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
...
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
)
)
assert
new_prompt
==
expected_prompt
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
assert
new_token_ids
==
expected_token_ids
assert
ranges
==
expected_ranges
tests/worker/test_model_input.py
View file @
6c0b7f54
...
@@ -73,6 +73,7 @@ def test_model_runner_input():
...
@@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens
=
2
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
)
model_input
=
ModelInputForGPUWithSamplingMetadata
(
model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_tokens
=
torch
.
ones
(
10
),
...
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
...
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens
=
2
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
)
model_input
=
ModelInputForGPUWithPoolingMetadata
(
model_input
=
ModelInputForGPUWithPoolingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_tokens
=
torch
.
ones
(
10
),
...
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
...
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens
=
2
,
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
)
frozen_model_input
=
ModelInputForGPUWithSamplingMetadata
(
frozen_model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
input_tokens
=
torch
.
ones
(
10
),
...
...
vllm/attention/backends/abstract.py
View file @
6c0b7f54
...
@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
...
@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
import
torch
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBase
,
...
@@ -108,6 +110,15 @@ class AttentionMetadata:
...
@@ -108,6 +110,15 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps
:
Optional
[
Dict
[
str
,
MultiModalPlaceholderMap
.
IndexMap
]]
@
property
@
property
@
abstractmethod
@
abstractmethod
def
prefill_metadata
(
self
)
->
Optional
[
"AttentionMetadata"
]:
def
prefill_metadata
(
self
)
->
Optional
[
"AttentionMetadata"
]:
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
6c0b7f54
...
@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
...
@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_query_len
=
None
,
...
...
vllm/attention/backends/flash_attn.py
View file @
6c0b7f54
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
...
@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
...
@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
...
@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_decode_query_len
=
self
.
max_decode_query_len
,
max_decode_query_len
=
self
.
max_decode_query_len
,
...
@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder(
...
@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder(
self
.
context_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
...
@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder(
...
@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder(
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder(
...
@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder(
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
...
@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder(
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_decode_query_len
=
max_decode_query_len
,
max_decode_query_len
=
max_decode_query_len
,
...
...
vllm/attention/backends/flashinfer.py
View file @
6c0b7f54
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
vllm.multimodal
import
MultiModalPlaceholderMap
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
...
@@ -215,6 +218,7 @@ class FlashInferState(AttentionState):
...
@@ -215,6 +218,7 @@ class FlashInferState(AttentionState):
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
num_prefills
=
0
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
num_decode_tokens
=
batch_size
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
...
@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
context_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
...
@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
inter_data
.
curr_sliding_window_blocks
):
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
...
@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
decode_query_len
=
decode_query_len
,
decode_query_len
=
decode_query_len
,
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
...
...
vllm/attention/backends/placeholder_attn.py
View file @
6c0b7f54
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
)
AttentionMetadataBuilder
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
...
@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_decode_query_len
=
0
,
max_decode_query_len
=
0
,
...
@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
...
@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_decode_query_len
=
self
.
max_decode_query_len
,
max_decode_query_len
=
self
.
max_decode_query_len
,
...
@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder(
...
@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder(
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
...
@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder(
...
@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder(
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder(
...
@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder(
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
...
@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder(
...
@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder(
return
PlaceholderAttentionMetadata
(
return
PlaceholderAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
6c0b7f54
...
@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
...
@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_query_len
=
None
,
...
...
vllm/attention/backends/utils.py
View file @
6c0b7f54
"""Attention backend utils"""
"""Attention backend utils"""
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
...
@@ -7,6 +8,7 @@ import torch
...
@@ -7,6 +8,7 @@ import torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
AttentionState
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
context_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
...
@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
inter_data
.
curr_sliding_window_blocks
):
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
device
)
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
dtype
=
seq_start_loc
.
dtype
,
...
@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
...
@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
return
self
.
_metadata_cls
(
# type: ignore
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState):
...
@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
1
,
max_query_len
=
1
,
...
...
vllm/attention/backends/xformers.py
View file @
6c0b7f54
...
@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
...
@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
...
...
vllm/core/scheduler.py
View file @
6c0b7f54
...
@@ -1308,6 +1308,8 @@ class Scheduler:
...
@@ -1308,6 +1308,8 @@ class Scheduler:
# `multi_modal_data` will be None.
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
multi_modal_placeholders
=
seq_group
.
multi_modal_placeholders
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
mm_processor_kwargs
=
seq_group
.
mm_processor_kwargs
,
mm_processor_kwargs
=
seq_group
.
mm_processor_kwargs
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
)
...
...
vllm/inputs/__init__.py
View file @
6c0b7f54
...
@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
...
@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
token_inputs
,
zip_enc_dec_prompts
)
token_inputs
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
DummyData
,
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
"""
"""
...
@@ -29,6 +29,7 @@ __all__ = [
...
@@ -29,6 +29,7 @@ __all__ = [
"to_enc_dec_tuple_list"
,
"to_enc_dec_tuple_list"
,
"zip_enc_dec_prompts"
,
"zip_enc_dec_prompts"
,
"INPUT_REGISTRY"
,
"INPUT_REGISTRY"
,
"DummyData"
,
"InputContext"
,
"InputContext"
,
"InputRegistry"
,
"InputRegistry"
,
]
]
...
...
vllm/inputs/data.py
View file @
6c0b7f54
...
@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
...
@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
,
MultiModalPlaceholderDict
class
TextPrompt
(
TypedDict
):
class
TextPrompt
(
TypedDict
):
...
@@ -136,6 +136,12 @@ class TokenInputs(TypedDict):
...
@@ -136,6 +136,12 @@ class TokenInputs(TypedDict):
if the model supports it.
if the model supports it.
"""
"""
multi_modal_placeholders
:
NotRequired
[
Optional
[
"MultiModalPlaceholderDict"
]]
"""
Placeholder ranges for the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
Optional
[
Dict
[
str
,
Any
]]]
mm_processor_kwargs
:
NotRequired
[
Optional
[
Dict
[
str
,
Any
]]]
"""
"""
Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
...
@@ -149,6 +155,7 @@ def token_inputs(
...
@@ -149,6 +155,7 @@ def token_inputs(
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
)
->
TokenInputs
:
"""Construct :class:`TokenInputs` from optional values."""
"""Construct :class:`TokenInputs` from optional values."""
...
@@ -158,6 +165,8 @@ def token_inputs(
...
@@ -158,6 +165,8 @@ def token_inputs(
inputs
[
"prompt"
]
=
prompt
inputs
[
"prompt"
]
=
prompt
if
multi_modal_data
is
not
None
:
if
multi_modal_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
multi_modal_data
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
if
mm_processor_kwargs
is
not
None
:
inputs
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
inputs
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
...
...
vllm/inputs/registry.py
View file @
6c0b7f54
import
functools
import
functools
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
NamedTuple
,
Protocol
,
Tuple
,
Type
)
Optional
,
Protocol
,
Type
)
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs
...
@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MultiModalDataDict
,
MultiModalRegistry
from
vllm.multimodal
import
(
MultiModalDataDict
,
MultiModalPlaceholderDict
,
MultiModalRegistry
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -63,6 +64,14 @@ class InputContext:
...
@@ -63,6 +64,14 @@ class InputContext:
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
class
DummyData
(
NamedTuple
):
"""Dummy data used for profiling."""
seq_data
:
"SequenceData"
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
class
DummyDataFactory
(
Protocol
):
class
DummyDataFactory
(
Protocol
):
def
__call__
(
def
__call__
(
...
@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol):
...
@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol):
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
**
mm_processor_kwargs
:
Any
,
**
mm_processor_kwargs
:
Any
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
"""
Create dummy data to be inputted into the model.
Create dummy data to be inputted into the model.
...
@@ -123,7 +132,7 @@ class InputRegistry:
...
@@ -123,7 +132,7 @@ class InputRegistry:
ctx
:
InputContext
,
ctx
:
InputContext
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
"""
The default dummy data factory represents the longest possible text
The default dummy data factory represents the longest possible text
that can be inputted to the model.
that can be inputted to the model.
...
@@ -134,10 +143,7 @@ class InputRegistry:
...
@@ -134,10 +143,7 @@ class InputRegistry:
# Avoid circular import
# Avoid circular import
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
))
return
DummyData
(
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)))
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
def
register_dummy_data
(
self
,
factory
:
DummyDataFactory
):
def
register_dummy_data
(
self
,
factory
:
DummyDataFactory
):
"""
"""
...
@@ -195,7 +201,7 @@ class InputRegistry:
...
@@ -195,7 +201,7 @@ class InputRegistry:
seq_len
:
int
,
seq_len
:
int
,
mm_registry
:
"MultiModalRegistry"
,
mm_registry
:
"MultiModalRegistry"
,
is_encoder_data
:
bool
=
False
,
is_encoder_data
:
bool
=
False
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
"""
Create dummy data for profiling the memory usage of a model.
Create dummy data for profiling the memory usage of a model.
...
@@ -220,12 +226,12 @@ class InputRegistry:
...
@@ -220,12 +226,12 @@ class InputRegistry:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
du
mm
y
_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
_MultiModalCounts
(
mm_counts
),
**
mm_processor_kwargs
)
**
mm_processor_kwargs
)
# Having more tokens is over-conservative but otherwise fine
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
seq_data
.
prompt_token_ids
num_tokens
=
dummy_data
.
seq_data
.
prompt_token_ids
if
len
(
num_tokens
)
<
seq_len
:
if
len
(
num_tokens
)
<
seq_len
:
if
is_encoder_data
:
if
is_encoder_data
:
print_warning_once
(
print_warning_once
(
...
@@ -235,15 +241,15 @@ class InputRegistry:
...
@@ -235,15 +241,15 @@ class InputRegistry:
raise
AssertionError
(
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
mm
_data
is
not
None
:
if
dummy_data
.
multi_modal
_data
is
not
None
:
for
k
,
v
in
mm
_data
.
items
():
for
k
,
v
in
dummy_data
.
multi_modal
_data
.
items
():
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
num_expected
=
mm_counts
[
k
]
num_expected
=
mm_counts
[
k
]
assert
num_items
>=
num_expected
,
(
assert
num_items
>=
num_expected
,
(
f
"Expected at least
{
num_expected
}
dummy '
{
k
}
' instances "
f
"Expected at least
{
num_expected
}
dummy '
{
k
}
' instances "
f
"for profiling, but found
{
num_items
}
instances instead."
)
f
"for profiling, but found
{
num_items
}
instances instead."
)
return
seq_data
,
mm_data
return
du
mm
y
_data
def
_default_input_processor
(
def
_default_input_processor
(
self
,
self
,
...
...
vllm/model_executor/models/blip.py
View file @
6c0b7f54
...
@@ -98,6 +98,11 @@ def input_processor_for_blip(
...
@@ -98,6 +98,11 @@ def input_processor_for_blip(
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
inputs
return
inputs
if
"multi_modal_placeholders"
in
inputs
and
"image"
in
inputs
[
"multi_modal_placeholders"
]:
# The inputs already have placeholders.
return
inputs
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
if
image_feature_size_override
is
None
:
if
image_feature_size_override
is
None
:
...
@@ -105,7 +110,7 @@ def input_processor_for_blip(
...
@@ -105,7 +110,7 @@ def input_processor_for_blip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
new_prompt
,
new_token_ids
,
ranges
=
repeat_and_pad_placeholder_tokens
(
tokenizer
,
tokenizer
,
inputs
.
get
(
"prompt"
),
inputs
.
get
(
"prompt"
),
inputs
[
"prompt_token_ids"
],
inputs
[
"prompt_token_ids"
],
...
@@ -116,7 +121,8 @@ def input_processor_for_blip(
...
@@ -116,7 +121,8 @@ def input_processor_for_blip(
# NOTE: Create a defensive copy of the original inputs
# NOTE: Create a defensive copy of the original inputs
return
token_inputs
(
prompt_token_ids
=
new_token_ids
,
return
token_inputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
ranges
})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
...
...
vllm/model_executor/models/blip2.py
View file @
6c0b7f54
...
@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
...
@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
consecutive_placeholder_ranges
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
...
@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
...
@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
return
SequenceData
.
from_prompt_token_counts
(
return
SequenceData
.
from_prompt_token_counts
(
(
image_token_id
,
image_feature_size
*
num_images
),
(
image_token_id
,
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
)
),
{
"image"
:
consecutive_placeholder_ranges
(
num_items
=
num_images
,
item_size
=
image_feature_size
)
}
def
dummy_data_for_blip2
(
ctx
:
InputContext
,
seq_len
:
int
,
def
dummy_data_for_blip2
(
ctx
:
InputContext
,
seq_len
:
int
,
...
@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
...
@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_blip2
(
seq_data
,
ranges
=
dummy_seq_data_for_blip2
(
hf_config
,
hf_config
,
seq_len
,
seq_len
,
num_images
,
num_images
,
...
@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
...
@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
if
isinstance
(
vision_config
,
Blip2VisionConfig
):
if
isinstance
(
vision_config
,
Blip2VisionConfig
):
mm_data
=
dummy_image_for_blip
(
vision_config
,
num_images
)
mm_data
=
dummy_image_for_blip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
raise
NotImplementedError
(
msg
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment