Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c23ce2d
"vllm/vscode:/vscode.git/clone" did not exist on "3de2ed767f64be006586b4c97e1f6524a75b4748"
Commit
3c23ce2d
authored
Dec 25, 2024
by
zhuwenwen
Browse files
[Model] Add Qwen2-Audio model support
parent
b4cf96af
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
902 additions
and
26 deletions
+902
-26
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-0
examples/offline_inference_audio_language.py
examples/offline_inference_audio_language.py
+39
-16
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+3
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+6
-3
vllm/inputs/data.py
vllm/inputs/data.py
+116
-6
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+41
-1
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+461
-0
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+3
-0
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+2
-0
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+225
-0
No files found.
docs/source/models/supported_models.rst
View file @
3c23ce2d
...
@@ -284,6 +284,11 @@ Multimodal Language Models
...
@@ -284,6 +284,11 @@ Multimodal Language Models
- Image\ :sup:`E+`
- Image\ :sup:`E+`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
-
-
* - :code:`Qwen2AudioForConditionalGeneration`
- Qwen2-Audio
- T + A\ :sup:`+`
- :code:`Qwen/Qwen2-Audio-7B-Instruct`
-
* - :code:`Qwen2VLForConditionalGeneration`
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL
- Qwen2-VL
- Image\ :sup:`+` / Video\ :sup:`+`
- Image\ :sup:`+` / Video\ :sup:`+`
...
...
examples/offline_inference_audio_language.py
View file @
3c23ce2d
...
@@ -12,14 +12,15 @@ from vllm.assets.audio import AudioAsset
...
@@ -12,14 +12,15 @@ from vllm.assets.audio import AudioAsset
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
audio_assets
=
[
AudioAsset
(
"mary_had_lamb"
),
AudioAsset
(
"winning_call"
)]
audio_assets
=
[
AudioAsset
(
"mary_had_lamb"
),
AudioAsset
(
"winning_call"
)]
question_per_audio_count
=
[
question_per_audio_count
=
{
"What is recited in the audio?"
,
0
:
"What is 1+1?"
,
"What sport and what nursery rhyme are referenced?"
1
:
"What is recited in the audio?"
,
]
2
:
"What sport and what nursery rhyme are referenced?"
}
# Ultravox 0.3
# Ultravox 0.3
def
run_ultravox
(
question
,
audio_count
):
def
run_ultravox
(
question
:
str
,
audio_count
:
int
):
model_name
=
"fixie-ai/ultravox-v0_3"
model_name
=
"fixie-ai/ultravox-v0_3"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
...
@@ -42,9 +43,29 @@ def run_ultravox(question, audio_count):
...
@@ -42,9 +43,29 @@ def run_ultravox(question, audio_count):
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
model_example_map
=
{
# Qwen2-Audio
"ultravox"
:
run_ultravox
,
def
run_qwen2_audio
(
question
:
str
,
audio_count
:
int
):
}
model_name
=
"Qwen/Qwen2-Audio-7B-Instruct"
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
limit_mm_per_prompt
=
{
"audio"
:
audio_count
})
audio_in_prompt
=
""
.
join
([
f
"Audio
{
idx
+
1
}
: "
f
"<|audio_bos|><|AUDIO|><|audio_eos|>
\n
"
for
idx
in
range
(
audio_count
)
])
prompt
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
"<|im_start|>user
\n
"
f
"
{
audio_in_prompt
}{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
model_example_map
=
{
"ultravox"
:
run_ultravox
,
"qwen2_audio"
:
run_qwen2_audio
}
def
main
(
args
):
def
main
(
args
):
...
@@ -54,7 +75,7 @@ def main(args):
...
@@ -54,7 +75,7 @@ def main(args):
audio_count
=
args
.
num_audios
audio_count
=
args
.
num_audios
llm
,
prompt
,
stop_token_ids
=
model_example_map
[
model
](
llm
,
prompt
,
stop_token_ids
=
model_example_map
[
model
](
question_per_audio_count
[
audio_count
-
1
],
audio_count
)
question_per_audio_count
[
audio_count
],
audio_count
)
# We set temperature to 0.2 so that outputs can be different
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
# even when all prompts are identical when running batch inference.
...
@@ -62,16 +83,18 @@ def main(args):
...
@@ -62,16 +83,18 @@ def main(args):
max_tokens
=
64
,
max_tokens
=
64
,
stop_token_ids
=
stop_token_ids
)
stop_token_ids
=
stop_token_ids
)
assert
args
.
num_prompts
>
0
mm_data
=
{}
inputs
=
{
if
audio_count
>
0
:
"prompt"
:
prompt
,
mm_data
=
{
"multi_modal_data"
:
{
"audio"
:
[
"audio"
:
[
asset
.
audio_and_sample_rate
asset
.
audio_and_sample_rate
for
asset
in
audio_assets
[:
audio_count
]
for
asset
in
audio_assets
[:
audio_count
]
]
]
},
}
}
assert
args
.
num_prompts
>
0
inputs
=
{
"prompt"
:
prompt
,
"multi_modal_data"
:
mm_data
}
if
args
.
num_prompts
>
1
:
if
args
.
num_prompts
>
1
:
# Batch inference
# Batch inference
inputs
=
[
inputs
]
*
args
.
num_prompts
inputs
=
[
inputs
]
*
args
.
num_prompts
...
@@ -100,7 +123,7 @@ if __name__ == "__main__":
...
@@ -100,7 +123,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--num-audios"
,
parser
.
add_argument
(
"--num-audios"
,
type
=
int
,
type
=
int
,
default
=
1
,
default
=
1
,
choices
=
[
1
,
2
],
choices
=
[
0
,
1
,
2
],
help
=
"Number of audio items per prompt."
)
help
=
"Number of audio items per prompt."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
vllm/entrypoints/chat_utils.py
View file @
3c23ce2d
...
@@ -168,6 +168,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -168,6 +168,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif
modality
==
"audio"
:
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
return
"<|reserved_special_token_0|>"
if
model_type
==
"qwen2_audio"
:
return
(
f
"Audio
{
current_count
}
: "
f
"<|audio_bos|><|AUDIO|><|audio_eos|>"
)
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
if
model_type
==
"qwen2_vl"
:
...
...
vllm/inputs/__init__.py
View file @
3c23ce2d
from
.data
import
(
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
from
.data
import
(
DecoderOnlyInputs
,
EncoderDecoderLLMInputs
,
ExplicitEncoderDecoderPrompt
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
,
TextPrompt
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
,
TextPrompt
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
TokensPrompt
,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
token_inputs
,
zip_enc_dec_prompts
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
...
@@ -20,6 +21,8 @@ __all__ = [
...
@@ -20,6 +21,8 @@ __all__ = [
"SingletonPromptInputs"
,
"SingletonPromptInputs"
,
"ExplicitEncoderDecoderPrompt"
,
"ExplicitEncoderDecoderPrompt"
,
"LLMInputs"
,
"LLMInputs"
,
"token_inputs"
,
"DecoderOnlyInputs"
,
"EncoderDecoderLLMInputs"
,
"EncoderDecoderLLMInputs"
,
"build_explicit_enc_dec_prompt"
,
"build_explicit_enc_dec_prompt"
,
"to_enc_dec_tuple_list"
,
"to_enc_dec_tuple_list"
,
...
...
vllm/inputs/data.py
View file @
3c23ce2d
from
typing
import
(
TYPE_CHECKING
,
Generic
,
Iterable
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Union
)
Optional
,
Tuple
,
Union
)
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
,
MultiModalPlaceholderDict
from
vllm.multimodal.inputs
import
MultiModalInputsV2
class
TextPrompt
(
TypedDict
):
class
TextPrompt
(
TypedDict
):
...
@@ -19,6 +20,14 @@ class TextPrompt(TypedDict):
...
@@ -19,6 +20,14 @@ class TextPrompt(TypedDict):
if the model supports it.
if the model supports it.
"""
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
class
TokensPrompt
(
TypedDict
):
class
TokensPrompt
(
TypedDict
):
"""Schema for a tokenized prompt."""
"""Schema for a tokenized prompt."""
...
@@ -28,10 +37,18 @@ class TokensPrompt(TypedDict):
...
@@ -28,10 +37,18 @@ class TokensPrompt(TypedDict):
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
"""
"""
Optional multi-modal data to pass to the model,
DEPRECATED:
Optional multi-modal data to pass to the model,
if the model supports it.
if the model supports it.
"""
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
SingletonPromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
SingletonPromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
...
@@ -88,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -88,6 +105,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt
:
Optional
[
_T2_co
]
decoder_prompt
:
Optional
[
_T2_co
]
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
PromptInputs
=
Union
[
SingletonPromptInputs
,
ExplicitEncoderDecoderPrompt
]
PromptInputs
=
Union
[
SingletonPromptInputs
,
ExplicitEncoderDecoderPrompt
]
"""
"""
...
@@ -101,6 +120,71 @@ both decoder-only and encoder/decoder input types:
...
@@ -101,6 +120,71 @@ both decoder-only and encoder/decoder input types:
"""
"""
class TokenInputs(TypedDict):
    """Represents token-based inputs.

    Only ``type`` and ``prompt_token_ids`` are required keys; every other
    key is declared ``NotRequired`` and may be absent from the dict.
    """

    type: Literal["token"]
    """The type of inputs."""

    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[str]
    """
    The original prompt text corresponding to the token IDs, if available.
    """

    multi_modal_data: NotRequired["MultiModalDataDict"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
    """
    Placeholder ranges for the multi-modal data.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """
def token_inputs(
    prompt_token_ids: List[int],
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """Construct :class:`TokenInputs` from optional values.

    Args:
        prompt_token_ids: Token IDs of the prompt; always stored.
        prompt: Original prompt text, stored only when not ``None``.
        multi_modal_data: Multi-modal data, stored only when not ``None``.
        multi_modal_placeholders: Placeholder ranges for the multi-modal
            data, stored only when not ``None``.
        mm_processor_kwargs: Processor kwargs, stored only when not
            ``None``.

    Returns:
        A :class:`TokenInputs` dict with ``type="token"``. Optional keys
        whose argument is ``None`` are omitted rather than set to ``None``.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Each optional key is added only when a value was supplied, so
    # callers can use ``"key" in inputs`` / ``inputs.get("key")``.
    if prompt is not None:
        inputs["prompt"] = prompt
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_placeholders is not None:
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return inputs
DecoderOnlyInputs
=
Union
[
TokenInputs
,
"MultiModalInputsV2"
]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
class
LLMInputs
(
TypedDict
):
class
LLMInputs
(
TypedDict
):
"""
"""
The inputs in :class:`~vllm.LLMEngine` before they are
The inputs in :class:`~vllm.LLMEngine` before they are
...
@@ -146,6 +230,32 @@ class EncoderDecoderLLMInputs(LLMInputs):
...
@@ -146,6 +230,32 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
"""
class
EncoderDecoderInputs
(
TypedDict
):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder
:
Union
[
TokenInputs
,
"MultiModalInputsV2"
]
"""The inputs for the encoder portion."""
decoder
:
Union
[
TokenInputs
,
"MultiModalInputsV2"
]
"""The inputs for the decoder portion."""
SingletonInputs
=
Union
[
TokenInputs
,
"MultiModalInputsV2"
]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
ProcessorInputs
=
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""
_T1
=
TypeVar
(
"_T1"
,
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPromptInputs
,
bound
=
SingletonPromptInputs
,
default
=
SingletonPromptInputs
)
default
=
SingletonPromptInputs
)
...
...
vllm/model_executor/models/__init__.py
View file @
3c23ce2d
...
@@ -103,6 +103,7 @@ _MULTIMODAL_MODELS = {
...
@@ -103,6 +103,7 @@ _MULTIMODAL_MODELS = {
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
"Qwen2VLForConditionalGeneration"
),
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
"MllamaForConditionalGeneration"
),
...
...
vllm/model_executor/models/interfaces.py
View file @
3c23ce2d
from
typing
import
(
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Dict
,
List
,
Literal
,
Optional
,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
Union
,
overload
,
runtime_checkable
)
import
torch
from
typing_extensions
import
TypeIs
from
typing_extensions
import
TypeIs
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
LoRAConfig
,
MultiModalConfig
,
SchedulerConfig
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -145,6 +150,41 @@ def _supports_lora(
...
@@ -145,6 +150,41 @@ def _supports_lora(
return
isinstance
(
model
,
SupportsLoRA
)
return
isinstance
(
model
,
SupportsLoRA
)
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    # NOTE(review): the second sentence of the docstring below looks
    # inverted -- presumably intermediate (non-last) PP ranks return
    # IntermediateTensors and the last rank returns a tensor. Confirm
    # against a concrete model implementation before relying on it.
    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.
        """
        ...
@
runtime_checkable
@
runtime_checkable
class
HasInnerState
(
Protocol
):
class
HasInnerState
(
Protocol
):
"""The interface required for all models that has inner state."""
"""The interface required for all models that has inner state."""
...
...
vllm/model_executor/models/qwen2_audio.py
0 → 100644
View file @
3c23ce2d
# coding=utf-8
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
import
librosa
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
transformers
import
Qwen2AudioConfig
,
Qwen2AudioEncoder
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
,
token_inputs
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalInputs
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
logger = init_logger(__name__)

# Substring rewrites applied to checkpoint parameter names in
# `load_weights` (via `str.replace`) before the parameter lookup.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}
# # === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
    """Audio features and their attention mask for Qwen2-Audio."""

    input_features: torch.Tensor
    """Shape:
    `(num_audios, num_mel_bins, 3000)`
    """

    feature_attention_mask: torch.Tensor
    """Shape: `(num_audios, 3000)`
    """
# === Audio Encoder === #
class Qwen2AudioMultiModalProjector(nn.Module):
    """Single biased linear layer from the audio width to the text width."""

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # The attribute name `linear` is part of the parameter names
        # (`linear.weight` / `linear.bias`), so it must stay as is.
        self.linear = nn.Linear(
            in_features=audio_hidden_size,
            out_features=text_hidden_size,
            bias=True,
        )

    def forward(self, audio_features):
        """Project ``audio_features`` with the linear layer and return it."""
        return self.linear(audio_features)
def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
                               mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy audio for ``mm_counts["audio"]``.

    Returns:
        A pair ``(dummy_seqdata, {"audio": [...]})`` where the sequence
        data holds a run of audio placeholder tokens followed by zeros,
        and the audio list repeats one all-zero clip at 16000 Hz once per
        audio.

    Raises:
        RuntimeError: If the audio tokens plus 2 extra positions do not
            fit in ``seq_len``.
    """
    num_audios = mm_counts["audio"]
    total_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios

    # NOTE(review): the extra 2 positions are presumably the audio
    # begin/end marker tokens -- confirm.
    if total_audio_tokens + 2 > seq_len:
        raise RuntimeError(
            f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
            "please increase max_model_len or reduce audio limit by "
            "--limit-mm-per-prompt.")

    placeholder_run = (ctx.model_config.hf_config.audio_token_index,
                       total_audio_tokens)
    filler_run = (0, seq_len - total_audio_tokens)
    dummy_seqdata = SequenceData.from_prompt_token_counts(
        placeholder_run, filler_run)

    # NOTE(review): 2 * 2 * 160 presumably converts output tokens back to
    # raw samples (two halvings and a 160-sample frame) -- confirm.
    num_samples = total_audio_tokens * 2 * 2 * 160
    silent_clip = np.full((num_samples, ), 0.)
    return dummy_seqdata, {"audio": [(silent_clip, 16000)] * num_audios}
def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = False,
    **kwargs,
):
    """Gets a processor for the given model name via HuggingFace.

    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.

    Raises:
        RuntimeError: If loading fails with ``ValueError`` while
            ``trust_remote_code`` is false (chained from the original).
        ValueError: Re-raised unchanged when ``trust_remote_code`` is true.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    try:
        return AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # Remote code was already allowed, so the hint below would not
        # help; surface the original error as is.
        if trust_remote_code:
            raise e

        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        err_msg = (
            "Failed to load the processor. If the processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI.")
        raise RuntimeError(err_msg) from e


# Memoized variant of `get_processor` (default `lru_cache` settings).
cached_get_processor = lru_cache(get_processor)
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
    """
    Computes the output length of the convolutional layers
    and the output length of the audio encoder

    Returns a pair ``(feat_lengths, output_lengths)`` where
    ``feat_lengths = (input_lengths - 1) // 2 + 1`` and
    ``output_lengths = (feat_lengths - 2) // 2 + 1``.
    """
    feat_lengths = (input_lengths - 1) // 2 + 1
    token_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, token_lengths
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
    """Return the maximum number of audio tokens for one audio.

    Computed as ``(max_source_positions - 2) // 2 + 1`` from the audio
    config found under ``ctx.model_config.hf_config``.
    """
    audio_config = ctx.model_config.hf_config.audio_config
    return (audio_config.max_source_positions - 2) // 2 + 1
def input_processor_for_qwen2_audio(
        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    """Expand each audio placeholder token to its audio's token count.

    Every occurrence of the audio token in ``prompt_token_ids`` is
    replaced by a run of that token whose length is the number of encoder
    output positions computed for the corresponding audio.

    Args:
        ctx: Input context; supplies the model name (for the processor)
            and ``hf_config.audio_token_index``.
        inputs: Inputs containing ``prompt_token_ids`` and, optionally,
            ``prompt`` and ``multi_modal_data["audio"]`` given as one
            ``(audio, sampling_rate)`` pair or a list of such pairs.

    Returns:
        ``inputs`` unchanged when there is no audio; otherwise new token
        inputs with the expanded token IDs, the original prompt text (if
        any) and the original multi-modal data.

    Raises:
        ValueError: If the number of audio placeholder tokens differs
            from the number of audios provided.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return inputs

    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]

    if not audios:
        return inputs

    processor = cached_get_processor(ctx.model_config.model)
    target_sr = processor.feature_extractor.sampling_rate
    resampled_audios = [
        librosa.resample(audio, orig_sr=sampling_rate, target_sr=target_sr)
        for audio, sampling_rate in audios
    ]
    # NOTE(review): 160 and 3000 are presumably the feature extractor's
    # hop length and maximum frame count -- confirm against its config.
    audio_input_lengths = np.array([
        min(3000, audio.shape[0] // 160 + 1) for audio in resampled_audios
    ])
    _, audio_output_lengths = _get_feat_extract_output_lengths(
        audio_input_lengths)

    audio_token_index = ctx.model_config.hf_config.audio_token_index
    input_ids = inputs["prompt_token_ids"]

    # Validate explicitly: an `assert` would be stripped under `-O` and a
    # mismatch would then silently misalign audios and placeholders.
    audio_num = input_ids.count(audio_token_index)
    if len(audio_input_lengths) != audio_num:
        raise ValueError(
            f'The text input contains {audio_num} audio tokens, '
            f'but {len(audio_input_lengths)} audios provided')

    new_input_ids = []
    start = 0
    for num_tokens in audio_output_lengths:
        end = input_ids.index(audio_token_index, start)
        new_input_ids.extend(input_ids[start:end])  # text part
        # int() turns the numpy integer into a plain repeat count.
        new_input_ids.extend([audio_token_index] * int(num_tokens))
        start = end + 1
    new_input_ids.extend(input_ids[start:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        # The prompt text is optional in token inputs; do not require it.
        prompt=inputs.get("prompt"),
        multi_modal_data=multi_modal_data,
    )
def input_mapper_for_qwen2_audio(
    ctx: InputContext,
    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalInputs:
    """Input mapper for Qwen2-Audio.

    Resamples every audio to the feature extractor's sampling rate, runs
    the HuggingFace feature extractor on the batch and renames its
    ``attention_mask`` output to ``feature_attention_mask``.

    Args:
        ctx: Input context; supplies the model name used to load the
            processor.
        multi_modal_data: One item or a list of items; each item is
            unpacked as an ``(audio, sampling_rate)`` pair.

    Returns:
        Empty :class:`MultiModalInputs` for an empty list, otherwise the
        feature extractor outputs wrapped in :class:`MultiModalInputs`.

    Raises:
        RuntimeError: If the processor has no feature extractor.
    """
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]

    if len(multi_modal_data) == 0:
        return MultiModalInputs()

    processor = cached_get_processor(ctx.model_config.model)
    audio_feature_extractor = processor.feature_extractor
    if audio_feature_extractor is None:
        raise RuntimeError(
            "No HuggingFace audio_feature_extractor is available "
            "to process the audio object")

    try:
        # Use one rate for both the resampling and the extractor call, so
        # the declared rate always matches the data actually passed in.
        target_sr = audio_feature_extractor.sampling_rate
        resampled_audios = [
            librosa.resample(audio,
                             orig_sr=sampling_rate,
                             target_sr=target_sr)
            for audio, sampling_rate in multi_modal_data
        ]
        batch_data = audio_feature_extractor(resampled_audios,
                                             sampling_rate=target_sr,
                                             return_attention_mask=True,
                                             padding="max_length",
                                             return_tensors="pt").data
        batch_data["feature_attention_mask"] = batch_data.pop(
            "attention_mask")
    except Exception:
        # Log for context, then let the original error propagate.
        logger.error("Failed to process audio (%s)", multi_modal_data)
        raise

    return MultiModalInputs(batch_data)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_qwen2_audio)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_qwen2_audio_audio_tokens)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
    """Qwen2-Audio: audio encoder + linear projector + Qwen2 language model.

    Audio features are encoded by ``audio_tower``, projected to the text
    hidden size by ``multi_modal_projector`` and written into the token
    embeddings at the positions of the audio placeholder token.
    """

    def __init__(self,
                 config: Qwen2AudioConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the audio tower, projector, language model and LM head."""
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
            config.audio_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.language_model = Qwen2Model(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        # With tied embeddings the input embedding module doubles as the
        # LM head; otherwise a separate head is created.
        if config.text_config.tie_word_embeddings:
            self.lm_head = self.language_model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
                                          config.text_config.hidden_size,
                                          quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

        # Delegate to the language model's implementation.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: Union[torch.Tensor,
                                                        List[torch.Tensor]],
                                        name: str) -> torch.Tensor:
        """Concatenate ``mm_input`` into a single tensor.

        A tensor is split along its first dimension and the pieces are
        concatenated; a list is concatenated directly.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
        """Extract audio inputs from ``kwargs``.

        Returns ``None`` when ``input_features`` is absent; otherwise
        both ``input_features`` and ``feature_attention_mask`` are
        concatenated via ``_validate_and_reshape_mm_tensor``.
        """
        input_features = kwargs.pop('input_features', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_features is None:
            return None
        input_features = self._validate_and_reshape_mm_tensor(
            input_features, 'input_features')
        feature_attention_mask = self._validate_and_reshape_mm_tensor(
            feature_attention_mask, 'feature_attention_mask')
        # NOTE(review): `input_features` is already a tensor here (the
        # helper above returns the result of `torch.concat`), so this
        # check does not appear to be reachable.
        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")
        return Qwen2AudioInputs(input_features=input_features,
                                feature_attention_mask=feature_attention_mask)

    def _process_audio_input(self,
                             audio_input: Qwen2AudioInputs) -> torch.Tensor:
        """Encode and project audio, returning only the valid positions.

        Returns a 2-D tensor ``(num_valid_audio_tokens, embed_dim)``
        holding, for each audio in order, its first
        ``audio_output_lengths[i]`` projected features.
        """

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # Per-audio lengths derived from the number of attended frames
        # (sum of the mask over its last dimension).
        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)))

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (torch.arange(
            0,
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device).unsqueeze(0).expand(
                batch_size, max_seq_len))
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len)
        # Create mask
        # (True where the position index is at or beyond the audio's
        # feature length, i.e. padding)
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(
            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
                                                  max_seq_len)
        # Cast the boolean mask to the encoder's dtype/device, then
        # overwrite the padded positions with -inf.
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device)
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(input_features,
                                         attention_mask=audio_attention_mask)
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        # True for the first `audio_output_lengths[i]` positions of audio i.
        audio_features_mask = torch.arange(max_audio_tokens).expand(
            num_audios, max_audio_tokens
        ).to(audio_output_lengths.device) < audio_output_lengths.unsqueeze(1)
        masked_audio_features = audio_features[audio_features_mask].view(
            -1, embed_dim)

        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, merging audio features when present.

        When ``intermediate_tensors`` is given, neither token IDs nor
        embeddings are passed on. Otherwise, if audio inputs are found in
        ``kwargs``, the token embeddings are computed here, the audio
        placeholder positions are overwritten with the audio features,
        and embeddings are passed instead of token IDs.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            audio_input = self._parse_and_validate_audio_input(**kwargs)

            if audio_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.language_model.embed_tokens(input_ids)
                masked_audio_features = self._process_audio_input(audio_input)
                # merge llm embeddings and audio features
                mask = (input_ids == self.config.audio_token_index)
                inputs_embeds[mask, :] = masked_audio_features

                # Embeddings now carry the full input; drop the IDs.
                input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module's parameters.

        Names are rewritten with ``_KEYS_TO_MODIFY_MAPPING``. Weights
        whose name contains a shard name from ``stacked_params_mapping``
        (and does not contain ``'audio'``) are loaded into the matching
        stacked parameter with their shard id; all others use the
        parameter's ``weight_loader`` or ``default_weight_loader``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # With tied embeddings the LM head shares the embedding
            # module, so a separate checkpoint entry is skipped.
            if (self.config.text_config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Names containing 'audio' never take the stacked path.
                if weight_name not in name or 'audio' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping consumed the
                # weight (the loop above ended without `break`).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
\ No newline at end of file
vllm/model_executor/models/ultravox.py
View file @
3c23ce2d
...
@@ -118,6 +118,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
...
@@ -118,6 +118,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
if
not
isinstance
(
data
,
list
):
if
not
isinstance
(
data
,
list
):
data
=
[
data
]
data
=
[
data
]
if
len
(
data
)
==
0
:
return
MultiModalInputs
()
audio_features
=
[]
audio_features
=
[]
for
audio_input
in
data
:
for
audio_input
in
data
:
if
not
isinstance
(
audio_input
,
tuple
):
if
not
isinstance
(
audio_input
,
tuple
):
...
...
vllm/multimodal/__init__.py
View file @
3c23ce2d
from
.base
import
(
BatchedTensorInputs
,
MultiModalDataBuiltins
,
from
.base
import
(
BatchedTensorInputs
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
NestedTensors
)
from
.inputs
import
MultiModalPlaceholderDict
from
.registry
import
MultiModalRegistry
from
.registry
import
MultiModalRegistry
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
...
@@ -17,6 +18,7 @@ __all__ = [
...
@@ -17,6 +18,7 @@ __all__ = [
"MultiModalDataBuiltins"
,
"MultiModalDataBuiltins"
,
"MultiModalDataDict"
,
"MultiModalDataDict"
,
"MultiModalInputs"
,
"MultiModalInputs"
,
"MultiModalPlaceholderDict"
,
"MultiModalPlugin"
,
"MultiModalPlugin"
,
"NestedTensors"
,
"NestedTensors"
,
"MULTIMODAL_REGISTRY"
,
"MULTIMODAL_REGISTRY"
,
...
...
vllm/multimodal/inputs.py
0 → 100644
View file @
3c23ce2d
from
collections
import
UserDict
,
defaultdict
from
typing
import
(
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Sequence
,
Tuple
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
import
numpy
as
np
import
torch
import
torch.types
from
PIL.Image
import
Image
from
typing_extensions
import
TypeAlias
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
# Type variable for the per-item type wrapped by ``MultiModalData`` below.
_T = TypeVar("_T")

# yapf: disable
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image,
which can be passed to a HuggingFace :code:`ImageProcessor`.
"""

VideoItem: TypeAlias = Union[
    List[Image],
    np.ndarray,
    torch.Tensor,
    List[np.ndarray],
    List[torch.Tensor],
]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video,
which can be passed to a HuggingFace :code:`VideoProcessor`.
"""

# NOTE(review): the ``float`` in the tuple form is presumably the sampling
# rate of the audio — confirm against the audio input mappers.
AudioItem: TypeAlias = Union[
    np.ndarray,
    List[float],
    Tuple[np.ndarray, float],  # DEPRECATED: Use mm_processor_kwargs instead
]
"""
Represents a single audio that can be inputted to a HuggingFace
:code:`AudioProcessor`.
"""
# yapf: enable

# Generic over the item type: ``MultiModalData[ImageItem]`` is either one
# ``ImageItem`` or a list of them.
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data item, or a list of data items.
The number of data items allowed per modality is restricted by
:code:`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM.

    Declared with ``total=False``, so every key is optional; each value is
    a single item or a list of items of the corresponding type.
    """

    image: MultiModalData[ImageItem]
    """The input image(s)."""

    video: MultiModalData[VideoItem]
    """The input video(s)."""

    audio: MultiModalData[AudioItem]
    """The input audio(s)."""
# Typed as a read-only ``Mapping`` with arbitrary string keys, so it is not
# restricted to the keys declared in ``MultiModalDataBuiltins``.
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]]
"""
A dictionary containing an entry for each modality type to input.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalDataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
class PlaceholderRange(TypedDict):
    """
    Placeholder location information for multi-modal data.
    For example:
    Prompt: AAAA BBBB What is in these images?
    Images A and B will have:
    A: { "offset": 0, "length": 4 }
    B: { "offset": 5, "length": 4 }
    """
    # NOTE(review): the example counts one position per character; the unit
    # used in practice is presumably token positions in the prompt — confirm
    # against the code that produces these ranges.

    offset: int
    """The start index of the placeholder in the prompt."""

    length: int
    """The length of the placeholder."""
# Recursive alias: the quoted "NestedTensors" is a forward reference to this
# same alias, allowing arbitrarily deep lists of tensors.
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

# Maps a keyword-argument name to its (possibly nested) batched value.
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalKwargs.batch`.
"""
class MultiModalKwargs(UserDict[str, NestedTensors]):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Stack the inner dimensions that have the same shape in
        a nested list of tensors.

        Thus, a dimension represented by a list means that the inner
        dimensions are different for each element along that dimension.

        Args:
            nested_tensors: A tensor, or an arbitrarily nested list of
                tensors. NumPy arrays and Python ``int``/``float`` scalars
                are also accepted and converted to tensors.

        Returns:
            A single stacked tensor when every element at a level is a
            tensor of identical shape; otherwise a list of the recursively
            processed elements. An empty list is returned unchanged.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        # TODO: Remove these once all models have been migrated
        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)
        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # An empty list vacuously passes the "all tensors" check below,
            # and torch.stack() raises on an empty sequence, so return the
            # empty list as-is.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        first_shape = tensors_[0].shape
        if any(t.shape != first_shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalKwargs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.

        Args:
            inputs_list: The per-request keyword-argument dictionaries.

        Returns:
            A dictionary keyed by the union of the input keys; empty when
            ``inputs_list`` is empty.
        """
        if not inputs_list:
            return {}

        # We need to consider the case where each item in the batch
        # contains different modalities (i.e. different keys).
        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalKwargs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move every tensor in ``batched_inputs`` to ``device``.

        Args:
            batched_inputs: Batched inputs, e.g. the output of :meth:`batch`.
            device: Target device, passed to :meth:`torch.Tensor.to`.

        Returns:
            The structure produced by ``json_map_leaves``, with each leaf
            replaced by the result of ``.to(device, non_blocking=True)``.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        # The transfer is requested with non_blocking=True for every leaf.
        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
# Maps a string key to a sequence of ``PlaceholderRange`` entries.
# NOTE(review): the keys are presumably modality names such as "image" or
# "audio" — confirm against the producers of this mapping.
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
"""
A dictionary containing placeholder ranges.
"""
class MultiModalInputsV2(TypedDict):
    """
    Represents the outputs of :class:`vllm.multimodal.MultiModalProcessor`,
    ready to be passed to vLLM internals.
    """
    # All keys are required: this TypedDict does not set ``total=False``.

    type: Literal["multimodal"]
    """The type of inputs."""

    prompt: str
    """
    The original, unprocessed prompt text.
    Note:
    Since prompt text is not required by vLLM internals, we leave this
    unprocessed to save CPU computation. You can still call
    :code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
    """

    prompt_token_ids: List[int]
    """The processed token IDs which includes placeholder tokens."""

    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_placeholders: MultiModalPlaceholderDict
    """
    For each modality, information about the placeholder tokens in
    :code:`prompt_token_ids`.
    """
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment