Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c0b7f54
Unverified
Commit
6c0b7f54
authored
Nov 01, 2024
by
Peter Salas
Committed by
GitHub
Nov 01, 2024
Browse files
[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by:
Peter Salas
<
peter@fixie.ai
>
parent
d151fde8
Changes
53
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
284 additions
and
68 deletions
+284
-68
examples/offline_inference_audio_language.py
examples/offline_inference_audio_language.py
+1
-5
tests/kernels/utils.py
tests/kernels/utils.py
+2
-0
tests/models/decoder_only/audio_language/test_ultravox.py
tests/models/decoder_only/audio_language/test_ultravox.py
+74
-17
tests/multimodal/test_processor_kwargs.py
tests/multimodal/test_processor_kwargs.py
+7
-7
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+45
-12
tests/worker/test_model_input.py
tests/worker/test_model_input.py
+3
-0
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+11
-0
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+3
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+20
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+18
-0
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+21
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+3
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+18
-0
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+3
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+2
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+2
-1
vllm/inputs/data.py
vllm/inputs/data.py
+10
-1
vllm/inputs/registry.py
vllm/inputs/registry.py
+23
-17
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+8
-2
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+10
-5
No files found.
examples/offline_inference_audio_language.py
View file @
6c0b7f54
...
...
@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
tokenize
=
False
,
add_generation_prompt
=
True
)
llm
=
LLM
(
model
=
model_name
,
enforce_eager
=
True
,
enable_chunked_prefill
=
False
,
max_model_len
=
8192
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
llm
=
LLM
(
model
=
model_name
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
...
...
tests/kernels/utils.py
View file @
6c0b7f54
...
...
@@ -869,6 +869,7 @@ def make_test_metadata(
return
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
(
None
if
kv_mmap
is
None
else
kv_mmap
.
slot_mapping
),
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
@@ -914,6 +915,7 @@ def make_test_metadata(
return
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
kv_mmap
.
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
tests/models/decoder_only/audio_language/test_ultravox.py
View file @
6c0b7f54
...
...
@@ -2,8 +2,10 @@ from typing import List, Optional, Tuple, Type
import
numpy
as
np
import
pytest
import
pytest_asyncio
from
transformers
import
AutoModel
,
AutoTokenizer
,
BatchEncoding
from
tests.utils
import
RemoteOpenAIServer
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
...
...
@@ -17,6 +19,13 @@ AudioTuple = Tuple[np.ndarray, int]
VLLM_PLACEHOLDER
=
"<|reserved_special_token_0|>"
HF_PLACEHOLDER
=
"<|audio|>"
CHUNKED_PREFILL_KWARGS
=
{
"enable_chunked_prefill"
:
True
,
"max_num_seqs"
:
2
,
# Use a very small limit to exercise chunked prefill.
"max_num_batched_tokens"
:
16
}
@
pytest
.
fixture
(
scope
=
"session"
)
def
audio_assets
():
...
...
@@ -30,6 +39,26 @@ def audio(request):
return
AudioAsset
(
request
.
param
)
@
pytest
.
fixture
(
params
=
({},
CHUNKED_PREFILL_KWARGS
))
def
server
(
request
,
audio_assets
):
args
=
[
"--dtype=bfloat16"
,
"--max-model-len=4096"
,
"--enforce-eager"
,
f
"--limit-mm-per-prompt=audio=
{
len
(
audio_assets
)
}
"
]
+
[
f
"--
{
key
.
replace
(
'_'
,
'-'
)
}
=
{
value
}
"
for
key
,
value
in
request
.
param
.
items
()
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
yield
remote_server
@
pytest_asyncio
.
fixture
async
def
client
(
server
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
def
_get_prompt
(
audio_count
,
question
,
placeholder
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
placeholder
=
f
"
{
placeholder
}
\n
"
*
audio_count
...
...
@@ -68,8 +97,7 @@ def run_test(
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
**
kwargs
,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
...
...
@@ -79,11 +107,8 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
True
,
**
kwargs
)
as
vllm_model
:
vllm_outputs_per_audio
=
[
vllm_model
.
generate_greedy_logprobs
([
vllm_prompt
],
max_tokens
,
...
...
@@ -135,18 +160,16 @@ def run_multi_audio_test(
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
**
kwargs
,
):
with
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"audio"
:
max
((
len
(
audio
)
for
_
,
audio
in
prompts_and_audios
))
})
as
vllm_model
:
},
**
kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
[
prompt
for
prompt
,
_
in
prompts_and_audios
],
max_tokens
,
...
...
@@ -162,8 +185,9 @@ def run_multi_audio_test(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[{},
CHUNKED_PREFILL_KWARGS
])
def
test_models
(
hf_runner
,
vllm_runner
,
audio
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
num_logprobs
:
int
,
vllm_kwargs
:
dict
)
->
None
:
vllm_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
VLLM_PLACEHOLDER
)
hf_prompt
=
_get_prompt
(
1
,
"Describe the audio above."
,
HF_PLACEHOLDER
)
...
...
@@ -175,7 +199,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
**
vllm_kwargs
,
)
...
...
@@ -183,9 +207,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[{},
CHUNKED_PREFILL_KWARGS
])
def
test_models_with_multiple_audios
(
vllm_runner
,
audio_assets
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
max_tokens
:
int
,
num_logprobs
:
int
,
vllm_kwargs
:
dict
)
->
None
:
vllm_prompt
=
_get_prompt
(
len
(
audio_assets
),
"Describe each of the audios above."
,
...
...
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
**
vllm_kwargs
,
)
@
pytest
.
mark
.
asyncio
async
def
test_online_inference
(
client
,
audio_assets
):
"""Exercises online inference with/without chunked prefill enabled."""
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
[{
"type"
:
"audio_url"
,
"audio_url"
:
{
"url"
:
audio
.
url
}
}
for
audio
in
audio_assets
],
{
"type"
:
"text"
,
"text"
:
f
"What's happening in these
{
len
(
audio_assets
)
}
audio clips?"
},
],
}]
chat_completion
=
await
client
.
chat
.
completions
.
create
(
model
=
MODEL_NAME
,
messages
=
messages
,
max_tokens
=
10
)
assert
len
(
chat_completion
.
choices
)
==
1
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
tests/multimodal/test_processor_kwargs.py
View file @
6c0b7f54
...
...
@@ -5,8 +5,8 @@ from unittest.mock import patch
import
pytest
import
torch
from
vllm.inputs
import
DecoderOnlyInputs
,
InputContext
,
token_inputs
from
vllm.inputs.registry
import
InputRegistry
from
vllm.inputs
import
(
DecoderOnlyInputs
,
DummyData
,
InputContext
,
InputRegistry
,
token_inputs
)
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
...
...
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
num_crops
=
DEFAULT_NUM_CROPS
):
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
*
num_crops
))
return
seq_data
,
None
return
DummyData
(
seq_data
,
None
)
with
patch
(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory"
,
...
...
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data
,
_
=
dummy_registry
.
dummy_data_for_profiling
(
dummy_data
=
dummy_registry
.
dummy_data_for_profiling
(
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
assert
len
(
seq_data
.
prompt_token_ids
)
==
expected_seq_count
assert
len
(
dummy_data
.
seq_data
.
prompt_token_ids
)
==
expected_seq_count
@
pytest
.
mark
.
parametrize
(
...
...
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data
,
_
=
dummy_registry
.
dummy_data_for_profiling
(
dummy_data
=
dummy_registry
.
dummy_data_for_profiling
(
ctx
.
model_config
,
seq_len
=-
1
,
mm_registry
=
mm_registry
)
assert
len
(
seq_data
.
prompt_token_ids
)
==
DEFAULT_NUM_CROPS
assert
len
(
dummy_data
.
seq_data
.
prompt_token_ids
)
==
DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance
...
...
tests/multimodal/test_utils.py
View file @
6c0b7f54
...
...
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
test_cases
=
[
(
"<image>"
,
2
,
"<image><image>"
,
[
32000
,
32000
]),
(
"<image><image>"
,
2
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
[
32000
,
32000
,
32000
,
32000
,
32000
]),
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
(
"<image>"
,
2
,
"<image><image>"
,
[
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
2
}],
),
(
"<image><image>"
,
2
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
2
}]),
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
[
32000
,
32000
,
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
3
},
{
"offset"
:
3
,
"length"
:
2
}],
),
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
"Image:<image><image><image>Image:<image><image>!"
,
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
]),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
]
for
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
in
test_cases
:
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[{
"offset"
:
2
,
"length"
:
3
},
{
"offset"
:
7
,
"length"
:
2
}],
),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
3
}],
),
]
# yapf: disable
for
(
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
,
expected_ranges
,
)
in
test_cases
:
new_prompt
,
new_token_ids
,
ranges
=
repeat_and_pad_placeholder_tokens
(
tokenizer
=
tokenizer
,
prompt
=
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
prompt
,
...
...
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
)
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
assert
ranges
==
expected_ranges
tests/worker/test_model_input.py
View file @
6c0b7f54
...
...
@@ -73,6 +73,7 @@ def test_model_runner_input():
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
...
...
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
model_input
=
ModelInputForGPUWithPoolingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
...
...
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
num_prefill_tokens
=
2
,
num_decode_tokens
=
3
,
slot_mapping
=
torch
.
zeros
(
1
),
multi_modal_placeholder_index_maps
=
None
,
)
frozen_model_input
=
ModelInputForGPUWithSamplingMetadata
(
input_tokens
=
torch
.
ones
(
10
),
...
...
vllm/attention/backends/abstract.py
View file @
6c0b7f54
...
...
@@ -7,6 +7,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import
torch
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
...
...
@@ -108,6 +110,15 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping
:
torch
.
Tensor
# The index maps that relate multi-modal embeddings to the corresponding
# placeholders.
#
# N.B. These aren't really related to attention and don't belong on this
# type -- this is just a temporary solution to make them available to
# `model_executable`.
multi_modal_placeholder_index_maps
:
Optional
[
Dict
[
str
,
MultiModalPlaceholderMap
.
IndexMap
]]
@
property
@
abstractmethod
def
prefill_metadata
(
self
)
->
Optional
[
"AttentionMetadata"
]:
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
6c0b7f54
...
...
@@ -215,6 +215,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
...
...
@@ -243,6 +245,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
...
...
vllm/attention/backends/flash_attn.py
View file @
6c0b7f54
"""Attention layer with FlashAttention."""
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
...
...
@@ -14,6 +15,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
make_tensor_with_pad
)
...
...
@@ -169,6 +171,8 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
...
...
@@ -198,6 +202,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_decode_query_len
=
self
.
max_decode_query_len
,
...
...
@@ -297,6 +302,9 @@ class FlashAttentionMetadataBuilder(
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
...
...
@@ -327,6 +335,12 @@ class FlashAttentionMetadataBuilder(
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
...
@@ -449,6 +463,11 @@ class FlashAttentionMetadataBuilder(
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -464,6 +483,7 @@ class FlashAttentionMetadataBuilder(
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_decode_query_len
=
max_decode_query_len
,
...
...
vllm/attention/backends/flashinfer.py
View file @
6c0b7f54
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
vllm.multimodal
import
MultiModalPlaceholderMap
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
...
...
@@ -215,6 +218,7 @@ class FlashInferState(AttentionState):
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
num_prefills
=
0
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
multi_modal_placeholder_index_maps
=
None
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
max_prefill_seq_len
=
0
,
...
...
@@ -470,6 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
...
...
@@ -519,6 +526,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
...
@@ -651,6 +663,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -694,6 +711,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
decode_query_len
=
decode_query_len
,
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
...
...
vllm/attention/backends/placeholder_attn.py
View file @
6c0b7f54
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -7,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionMetadataBuilder
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
...
@@ -135,6 +137,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_decode_query_len
=
0
,
...
...
@@ -167,6 +171,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_decode_query_len
=
self
.
max_decode_query_len
,
...
...
@@ -189,6 +194,9 @@ class PlaceholderAttentionMetadataBuilder(
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
...
...
@@ -213,6 +221,12 @@ class PlaceholderAttentionMetadataBuilder(
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
...
@@ -280,6 +294,11 @@ class PlaceholderAttentionMetadataBuilder(
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -296,6 +315,7 @@ class PlaceholderAttentionMetadataBuilder(
return
PlaceholderAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
6c0b7f54
...
...
@@ -150,6 +150,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
...
...
@@ -178,6 +180,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
...
...
vllm/attention/backends/utils.py
View file @
6c0b7f54
"""Attention backend utils"""
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Type
,
TypeVar
,
Union
...
...
@@ -7,6 +8,7 @@ import torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
...
...
@@ -123,6 +125,9 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
...
...
@@ -147,6 +152,12 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
...
...
@@ -242,6 +253,11 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -254,6 +270,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
return
self
.
_metadata_cls
(
# type: ignore
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
...
...
@@ -305,6 +322,7 @@ class CommonAttentionState(AttentionState):
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
multi_modal_placeholder_index_maps
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
1
,
...
...
vllm/attention/backends/xformers.py
View file @
6c0b7f54
...
...
@@ -212,6 +212,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
...
...
@@ -255,6 +257,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
...
...
vllm/core/scheduler.py
View file @
6c0b7f54
...
...
@@ -1308,6 +1308,8 @@ class Scheduler:
# `multi_modal_data` will be None.
multi_modal_data
=
seq_group
.
multi_modal_data
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
multi_modal_placeholders
=
seq_group
.
multi_modal_placeholders
if
scheduler_outputs
.
num_prefill_groups
>
0
else
None
,
mm_processor_kwargs
=
seq_group
.
mm_processor_kwargs
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
)
...
...
vllm/inputs/__init__.py
View file @
6c0b7f54
...
...
@@ -3,7 +3,7 @@ from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
token_inputs
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
DummyData
,
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
"""
...
...
@@ -29,6 +29,7 @@ __all__ = [
"to_enc_dec_tuple_list"
,
"zip_enc_dec_prompts"
,
"INPUT_REGISTRY"
,
"DummyData"
,
"InputContext"
,
"InputRegistry"
,
]
...
...
vllm/inputs/data.py
View file @
6c0b7f54
...
...
@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
,
MultiModalPlaceholderDict
class
TextPrompt
(
TypedDict
):
...
...
@@ -136,6 +136,12 @@ class TokenInputs(TypedDict):
if the model supports it.
"""
multi_modal_placeholders
:
NotRequired
[
Optional
[
"MultiModalPlaceholderDict"
]]
"""
Placeholder ranges for the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
Optional
[
Dict
[
str
,
Any
]]]
"""
Optional multi-modal processor kwargs to be forwarded to the
...
...
@@ -149,6 +155,7 @@ def token_inputs(
prompt_token_ids
:
List
[
int
],
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
"""Construct :class:`TokenInputs` from optional values."""
...
...
@@ -158,6 +165,8 @@ def token_inputs(
inputs
[
"prompt"
]
=
prompt
if
multi_modal_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
inputs
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
...
...
vllm/inputs/registry.py
View file @
6c0b7f54
import
functools
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
Protocol
,
Tuple
,
Type
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
NamedTuple
,
Optional
,
Protocol
,
Type
)
from
torch
import
nn
from
transformers
import
PretrainedConfig
...
...
@@ -16,7 +16,8 @@ from .data import DecoderOnlyInputs
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MultiModalDataDict
,
MultiModalRegistry
from
vllm.multimodal
import
(
MultiModalDataDict
,
MultiModalPlaceholderDict
,
MultiModalRegistry
)
from
vllm.sequence
import
SequenceData
logger
=
init_logger
(
__name__
)
...
...
@@ -63,6 +64,14 @@ class InputContext:
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
class
DummyData
(
NamedTuple
):
"""Dummy data used for profiling."""
seq_data
:
"SequenceData"
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
class
DummyDataFactory
(
Protocol
):
def
__call__
(
...
...
@@ -71,7 +80,7 @@ class DummyDataFactory(Protocol):
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
**
mm_processor_kwargs
:
Any
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
Create dummy data to be inputted into the model.
...
...
@@ -123,7 +132,7 @@ class InputRegistry:
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
The default dummy data factory represents the longest possible text
that can be inputted to the model.
...
...
@@ -134,10 +143,7 @@ class InputRegistry:
# Avoid circular import
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
))
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
return
DummyData
(
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)))
def
register_dummy_data
(
self
,
factory
:
DummyDataFactory
):
"""
...
...
@@ -195,7 +201,7 @@ class InputRegistry:
seq_len
:
int
,
mm_registry
:
"MultiModalRegistry"
,
is_encoder_data
:
bool
=
False
,
)
->
Tuple
[
"SequenceData"
,
Optional
[
"MultiModalDataDict"
]]
:
)
->
DummyData
:
"""
Create dummy data for profiling the memory usage of a model.
...
...
@@ -220,12 +226,12 @@ class InputRegistry:
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
)
seq_data
,
mm_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
dummy_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
**
mm_processor_kwargs
)
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
seq_data
.
prompt_token_ids
num_tokens
=
dummy_data
.
seq_data
.
prompt_token_ids
if
len
(
num_tokens
)
<
seq_len
:
if
is_encoder_data
:
print_warning_once
(
...
...
@@ -235,15 +241,15 @@ class InputRegistry:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
mm_data
is
not
None
:
for
k
,
v
in
mm_data
.
items
():
if
dummy_data
.
multi_modal_data
is
not
None
:
for
k
,
v
in
dummy_data
.
multi_modal_data
.
items
():
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
num_expected
=
mm_counts
[
k
]
assert
num_items
>=
num_expected
,
(
f
"Expected at least
{
num_expected
}
dummy '
{
k
}
' instances "
f
"for profiling, but found
{
num_items
}
instances instead."
)
return
seq_data
,
mm_data
return
dummy_data
def
_default_input_processor
(
self
,
...
...
vllm/model_executor/models/blip.py
View file @
6c0b7f54
...
...
@@ -98,6 +98,11 @@ def input_processor_for_blip(
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
inputs
if
"multi_modal_placeholders"
in
inputs
and
"image"
in
inputs
[
"multi_modal_placeholders"
]:
# The inputs already have placeholders.
return
inputs
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
if
image_feature_size_override
is
None
:
...
...
@@ -105,7 +110,7 @@ def input_processor_for_blip(
else
:
image_feature_size
=
image_feature_size_override
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
new_prompt
,
new_token_ids
,
ranges
=
repeat_and_pad_placeholder_tokens
(
tokenizer
,
inputs
.
get
(
"prompt"
),
inputs
[
"prompt_token_ids"
],
...
...
@@ -116,7 +121,8 @@ def input_processor_for_blip(
# NOTE: Create a defensive copy of the original inputs
return
token_inputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
ranges
})
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
...
...
vllm/model_executor/models/blip2.py
View file @
6c0b7f54
...
...
@@ -9,13 +9,14 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
consecutive_placeholder_ranges
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
...
...
@@ -425,7 +426,11 @@ def dummy_seq_data_for_blip2(
return
SequenceData
.
from_prompt_token_counts
(
(
image_token_id
,
image_feature_size
*
num_images
),
(
0
,
seq_len
-
image_feature_size
*
num_images
),
)
),
{
"image"
:
consecutive_placeholder_ranges
(
num_items
=
num_images
,
item_size
=
image_feature_size
)
}
def
dummy_data_for_blip2
(
ctx
:
InputContext
,
seq_len
:
int
,
...
...
@@ -434,7 +439,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
seq_data
=
dummy_seq_data_for_blip2
(
seq_data
,
ranges
=
dummy_seq_data_for_blip2
(
hf_config
,
seq_len
,
num_images
,
...
...
@@ -444,7 +449,7 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
if
isinstance
(
vision_config
,
Blip2VisionConfig
):
mm_data
=
dummy_image_for_blip
(
vision_config
,
num_images
)
return
seq_data
,
mm_data
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment