Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c4ecf15
"examples/vscode:/vscode.git/clone" did not exist on "0f2fa9282858d7cc422a0f1bdd08684e5e703d6a"
Commit
9c4ecf15
authored
Apr 14, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-ori
parents
bfc2d6f7
dc1b4a6f
Changes
342
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
717 additions
and
993 deletions
+717
-993
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+47
-101
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+12
-0
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+17
-67
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+4
-4
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+26
-26
vllm/model_executor/models/llama_eagle.py
vllm/model_executor/models/llama_eagle.py
+151
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+21
-67
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+3
-0
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+17
-29
vllm/model_executor/models/llava_onevision.py
vllm/model_executor/models/llava_onevision.py
+17
-23
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+11
-17
vllm/model_executor/models/minicpmo.py
vllm/model_executor/models/minicpmo.py
+28
-93
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+61
-127
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+23
-71
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+90
-86
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+7
-1
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+64
-38
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+81
-115
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+25
-87
vllm/model_executor/models/nvlm_d.py
vllm/model_executor/models/nvlm_d.py
+12
-41
No files found.
vllm/model_executor/models/idefics3.py
View file @
9c4ecf15
...
@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...
@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
ImageProcessorItems
,
ImageSize
from
vllm.multimodal.parse
import
ImageProcessorItems
,
ImageSize
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
MultiModalDataItems
,
MultiModalDataItems
,
PromptReplacement
,
MultiModalFieldConfig
,
PromptUpdate
,
PromptUpdateDetails
)
PromptReplacement
,
PromptUpdate
,
encode_tokens
)
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
# yapf: disable
# yapf: disable
...
@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
...
@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from
.llama
import
LlamaModel
from
.llama
import
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
class
Idefics3ImagePixelInputs
(
TypedDict
):
class
Idefics3ImagePixelInputs
(
TypedDict
):
...
@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
...
@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
Idefics3ImageEmbeddingInputs
(
TypedDict
):
class
Idefics3ImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
...
@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
...
@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
...
@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
_resize_output_size
(
self
,
def
_resize_output_size
(
self
,
*
,
*
,
height
:
int
,
height
:
int
,
...
@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return
grid_w
*
grid_h
+
1
return
grid_w
*
grid_h
+
1
def
_get_image_token
(
self
,
processor
:
Optional
[
Idefics3Processor
])
->
tuple
[
str
,
str
,
str
]:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image_token
=
processor
.
image_token
.
content
fake_image_token
=
processor
.
fake_image_token
.
content
global_image_token
=
processor
.
global_image_tag
return
image_token
,
fake_image_token
,
global_image_token
def
get_image_repl
(
def
get_image_repl
(
self
,
self
,
*
,
*
,
...
@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
if
processor
is
None
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
image_token
=
processor
.
image_token
.
content
image_token
,
fake_image_token
,
global_img_token
=
self
.
_get_image_token
(
fake_image_token
=
processor
.
fake_image_token
.
content
processor
)
global_img_token
=
processor
.
global_image_tag
image_seq_len
=
processor
.
image_seq_len
image_seq_len
=
processor
.
image_seq_len
grid_placeholder
=
"<row_{n_h}_col_{n_w}>"
grid_placeholder
=
"<row_{n_h}_col_{n_w}>"
...
@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height
:
int
,
image_height
:
int
,
processor
:
Optional
[
Idefics3Processor
],
processor
:
Optional
[
Idefics3Processor
],
)
->
int
:
)
->
int
:
tokenizer
=
self
.
get_tokenizer
()
if
processor
is
None
:
image_repl
=
self
.
get_image_repl
(
processor
=
self
.
get_hf_processor
()
num_patches
=
self
.
get_num_patches
(
image_width
=
image_width
,
image_width
=
image_width
,
image_height
=
image_height
,
image_height
=
image_height
,
processor
=
processor
,
processor
=
processor
,
)
)
image_repl_tokens
=
encode_tokens
(
return
num_patches
*
processor
.
image_seq_len
tokenizer
,
image_repl
,
add_special_tokens
=
False
,
)
return
len
(
image_repl_tokens
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
...
@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height
=
image_processor
.
size
[
"longest_edge"
],
height
=
image_processor
.
size
[
"longest_edge"
],
)
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
class
Idefics3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Idefics3ProcessingInfo
]
class
Idefics3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Idefics3ProcessingInfo
]
):
):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
,
_
,
_
=
self
.
info
.
_get_image_token
(
processor
)
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
hf_processor
=
self
.
info
.
get_hf_processor
()
hf_processor
=
self
.
info
.
get_hf_processor
()
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
longest_edge
=
image_processor
.
max_image_size
[
'longest_edge'
]
longest_edge
=
image_processor
.
max_image_size
[
'longest_edge'
]
image_token
=
hf_processor
.
image_token
.
content
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
longest_edge
,
self
.
_get_dummy_images
(
width
=
longest_edge
,
height
=
longest_edge
,
height
=
longest_edge
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
Idefics3MultiModalProcessor
(
class
Idefics3MultiModalProcessor
(
BaseMultiModalProcessor
[
Idefics3ProcessingInfo
]):
BaseMultiModalProcessor
[
Idefics3ProcessingInfo
]):
...
@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor(
...
@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor(
]
]
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_repl_features
=
[
self
.
info
.
get_image_repl
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
)
for
size
in
image_sizes
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
hf_processor
.
image_token
.
content
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
image_token_id
for
image_repl_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
num_patches
=
[
num_patches
=
[
self
.
info
.
get_num_patches
(
self
.
info
.
get_num_patches
(
image_width
=
size
.
width
,
image_width
=
size
.
width
,
...
@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor(
...
@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor(
"image"
,
num_patches
),
"image"
,
num_patches
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor(
...
@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
.
content
image_token
,
_
,
_
=
self
.
info
.
_get_image_token
(
hf_processor
)
def
get_replacement_idefics3
(
item_idx
:
int
)
->
str
:
def
get_replacement_idefics3
(
item_idx
:
int
)
->
PromptUpdateDetails
:
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
return
self
.
info
.
get_image_repl
(
image_repl
=
self
.
info
.
get_image_repl
(
image_width
=
image_size
.
width
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_height
=
image_size
.
height
,
processor
=
hf_processor
,
processor
=
hf_processor
,
)
)
return
PromptUpdateDetails
.
select_text
(
image_repl
,
embed_text
=
image_token
,
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
...
@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if
pixel_values
is
None
and
image_embeds
is
None
:
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
return
None
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image embeddings. "
raise
ValueError
(
"Incorrect type of image embeddings. "
...
@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return
Idefics3ImageEmbeddingInputs
(
return
Idefics3ImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
embed_is_patch
=
embed_is_patch
,
)
)
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
...
@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_attention_mask
=
pixel_attention_mask
,
pixel_attention_mask
=
pixel_attention_mask
,
num_patches
=
num_patches
,
num_patches
=
num_patches
,
embed_is_patch
=
embed_is_patch
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e
.
flatten
(
0
,
1
)
for
e
in
image_features
.
split
(
num_patches
.
tolist
())
e
.
flatten
(
0
,
1
)
for
e
in
image_features
.
split
(
num_patches
.
tolist
())
]
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
config
.
image_token_id
,
self
.
config
.
image_token_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/interfaces.py
View file @
9c4ecf15
...
@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
...
@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
"""
"""
...
...
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
# TODO(ywang96): Remove this overload once v0 is deprecated
@
overload
@
overload
...
...
vllm/model_executor/models/internvl.py
View file @
9c4ecf15
...
@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
...
@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel
)
InternVisionPatchModel
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModal
FieldConfig
,
MultiModal
Kwargs
,
from
vllm.multimodal.inputs
import
(
MultiModal
DataDict
,
MultiModal
FieldConfig
,
NestedTensors
)
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
IMG_START
=
'<img>'
IMG_START
=
'<img>'
IMG_END
=
'</img>'
IMG_END
=
'</img>'
...
@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
...
@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
...
@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
...
@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
torch
.
tensor
([
len
(
item
)
for
item
in
pixel_values_lst
]),
torch
.
tensor
([
len
(
item
)
for
item
in
pixel_values_lst
]),
}
}
tokenizer
=
self
.
tokenizer
image_token_id
=
self
.
image_token_id
embed_is_patch
=
list
[
torch
.
Tensor
]()
for
pixel_values
in
pixel_values_lst
:
for
pixel_values
in
pixel_values_lst
:
num_patches
=
pixel_values
.
shape
[
0
]
num_patches
=
pixel_values
.
shape
[
0
]
feature_size
=
num_patches
*
self
.
num_image_token
feature_size
=
num_patches
*
self
.
num_image_token
image_repl
=
self
.
get_image_repl
(
feature_size
,
num_patches
)
image_repl
=
self
.
get_image_repl
(
feature_size
,
num_patches
)
feature_tokens
=
tokenizer
.
encode
(
image_repl
.
features
,
add_special_tokens
=
False
)
text
=
[
t
.
replace
(
'<image>'
,
image_repl
.
full
,
1
)
for
t
in
text
]
text
=
[
t
.
replace
(
'<image>'
,
image_repl
.
full
,
1
)
for
t
in
text
]
embed_is_patch
.
append
(
torch
.
tensor
(
feature_tokens
)
==
image_token_id
)
image_inputs
[
"embed_is_patch"
]
=
embed_is_patch
text_inputs
=
self
.
tokenizer
(
text
)
text_inputs
=
self
.
tokenizer
(
text
)
...
@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
...
@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features
=
IMG_CONTEXT
*
feature_size
repl_features
=
IMG_CONTEXT
*
feature_size
repl_full
=
IMG_START
+
repl_features
+
IMG_END
repl_full
=
IMG_START
+
repl_features
+
IMG_END
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
return
PromptUpdateDetails
.
select_text
(
repl_full
,
IMG_CONTEXT
)
class
BaseInternVLProcessingInfo
(
BaseProcessingInfo
):
class
BaseInternVLProcessingInfo
(
BaseProcessingInfo
):
...
@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
...
@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
*
,
*
,
...
@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
...
@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
image_height
=
image_height
,
image_height
=
image_height
,
)
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
...
@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
...
@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class
InternVLDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
class
InternVLDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
return
"<image>"
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
"<image>"
*
num_images
,
mm_data
=
mm_data
,
)
class
InternVLMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
class
InternVLMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
...
@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_patches
),
"image"
,
image_num_patches
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
...
@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
InternVLImageInputs
]:
pixel_values_flat
=
kwargs
.
pop
(
"pixel_values_flat"
,
None
)
pixel_values_flat
=
kwargs
.
pop
(
"pixel_values_flat"
,
None
)
image_num_patches
=
kwargs
.
pop
(
"image_num_patches"
,
None
)
image_num_patches
=
kwargs
.
pop
(
"image_num_patches"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values_flat
is
None
and
image_embeds
is
None
:
if
pixel_values_flat
is
None
and
image_embeds
is
None
:
...
@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of image_num_patches. "
raise
ValueError
(
"Incorrect type of image_num_patches. "
f
"Got type:
{
type
(
image_num_patches
)
}
"
)
f
"Got type:
{
type
(
image_num_patches
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
InternVLImagePixelInputs
(
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values_flat
=
self
.
_validate_pixel_values
(
pixel_values_flat
=
self
.
_validate_pixel_values
(
pixel_values_flat
),
pixel_values_flat
),
num_patches
=
image_num_patches
,
num_patches
=
image_num_patches
,
embed_is_patch
=
embed_is_patch
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else
:
else
:
self
.
visual_token_mask
=
None
self
.
visual_token_mask
=
None
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
if
image_input
[
"type"
]
!=
"pixel_values"
:
return
image_features
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
img_context_token_id
,
self
.
img_context_token_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/llama.py
View file @
9c4ecf15
...
@@ -22,7 +22,7 @@
...
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -294,7 +294,7 @@ class LlamaModel(nn.Module):
...
@@ -294,7 +294,7 @@ class LlamaModel(nn.Module):
*
,
*
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
layer_type
:
T
ype
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
layer_type
:
t
ype
[
nn
.
Module
]
=
LlamaDecoderLayer
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -475,7 +475,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -475,7 +475,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
*
,
*
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
layer_type
:
T
ype
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
layer_type
:
t
ype
[
nn
.
Module
]
=
LlamaDecoderLayer
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
...
@@ -523,7 +523,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -523,7 +523,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
_init_model
(
self
,
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
layer_type
:
T
ype
[
LlamaDecoderLayer
]
=
LlamaDecoderLayer
):
layer_type
:
t
ype
[
nn
.
Module
]
=
LlamaDecoderLayer
):
return
LlamaModel
(
vllm_config
=
vllm_config
,
return
LlamaModel
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
layer_type
=
layer_type
)
layer_type
=
layer_type
)
...
...
vllm/model_executor/models/llama4.py
View file @
9c4ecf15
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
.llama
import
LlamaDecoderLayer
,
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.llama
import
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
fast_topk
,
is_pp_missing_parameter
)
is_pp_missing_parameter
)
...
@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
...
@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
router_scores
,
router_indices
=
torch
.
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores
,
router_indices
=
fast_
topk
(
gating_output
,
topk
,
dim
=-
1
)
router_scores
=
torch
.
sigmoid
(
router_scores
.
float
()).
to
(
router_scores
=
torch
.
sigmoid
(
router_scores
.
float
()).
to
(
hidden_states
.
dtype
)
hidden_states
.
dtype
)
return
(
router_scores
,
router_indices
.
to
(
torch
.
int32
))
return
(
router_scores
,
router_indices
.
to
(
torch
.
int32
))
...
@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
...
@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
self
.
rope_theta
=
rope_theta
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
self
.
n_rep
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
n_rep
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
q_norm
=
RMSNorm
(
self
.
qk_norm
=
RMSNorm
(
hidden_size
=
self
.
q_size
,
hidden_size
=
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
,
has_weight
=
False
,
dtype
=
torch
.
float32
,
)
if
self
.
use_qk_norm
else
None
self
.
k_norm
=
RMSNorm
(
hidden_size
=
self
.
kv_size
,
eps
=
config
.
rms_norm_eps
,
eps
=
config
.
rms_norm_eps
,
has_weight
=
False
,
has_weight
=
False
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
...
@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
...
@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
if
self
.
rotary_emb
is
not
None
:
if
self
.
rotary_emb
is
not
None
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
if
self
.
q_norm
is
not
None
:
if
self
.
qk_norm
is
not
None
:
q
=
self
.
q_norm
(
q
.
float
()).
to
(
q
.
dtype
)
q
=
q
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
if
self
.
k_norm
is
not
None
:
q
=
self
.
qk_norm
(
q
.
float
()).
reshape
(
-
1
,
self
.
q_size
).
to
(
q
.
dtype
)
k
=
self
.
k_norm
(
k
.
float
()).
to
(
k
.
dtype
)
k
=
k
.
reshape
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
k
=
self
.
qk_norm
(
k
.
float
()).
reshape
(
-
1
,
self
.
kv_size
).
to
(
k
.
dtype
)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function
# to NoPE layers, where the inference-time temperature tuning function
...
@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module):
...
@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module):
return
output
return
output
class
Llama4DecoderLayer
(
LlamaDecoderLayer
):
class
Llama4DecoderLayer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer):
...
@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
self
.
layer_idx
=
extract_layer_index
(
prefix
)
self
.
layer_idx
=
extract_layer_index
(
prefix
)
nn
.
Module
.
__init__
(
self
)
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
config
.
rope_theta
rope_theta
=
config
.
rope_theta
rope_scaling
=
config
.
rope_scaling
rope_scaling
=
config
.
rope_scaling
...
@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel):
...
@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel):
*
,
*
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
layer_type
:
T
ype
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
layer_type
:
t
ype
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
self
.
num_experts
=
vllm_config
.
model_config
.
hf_config
.
num_local_experts
self
.
num_experts
=
vllm_config
.
model_config
.
hf_config
.
num_local_experts
super
().
__init__
(
vllm_config
=
vllm_config
,
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
...
@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM):
...
@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM):
}
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
#
U
pdate temperature tuning config from generation config
#
u
pdate temperature tuning config from generation config
gen_config
=
vllm_config
.
model_config
.
try_get_generation_config
()
gen_config
=
vllm_config
.
model_config
.
try_get_generation_config
()
gen_config
.
update
(
vllm_config
.
model_config
.
override_generation_config
)
gen_config
.
update
(
vllm_config
.
model_config
.
override_generation_config
)
# enable temperature tuning by default when max_model_len > 32K
default_attn_temperature_tuning
=
\
vllm_config
.
model_config
.
max_model_len
>
32768
vllm_config
.
model_config
.
hf_config
.
attn_temperature_tuning
\
vllm_config
.
model_config
.
hf_config
.
attn_temperature_tuning
\
=
gen_config
.
get
(
"attn_temperature_tuning"
,
False
)
=
gen_config
.
get
(
LlamaForCausalLM
.
__init__
(
self
,
"attn_temperature_tuning"
,
default_attn_temperature_tuning
)
vllm_config
=
vllm_config
,
prefix
=
prefix
,
super
().
__init__
(
vllm_config
=
vllm_config
,
layer_type
=
Llama4DecoderLayer
)
prefix
=
prefix
,
layer_type
=
Llama4DecoderLayer
)
def
_init_model
(
self
,
def
_init_model
(
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
layer_type
:
T
ype
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
layer_type
:
t
ype
[
Llama4DecoderLayer
]
=
Llama4DecoderLayer
):
return
Llama4Model
(
vllm_config
=
vllm_config
,
return
Llama4Model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
prefix
=
prefix
,
layer_type
=
layer_type
)
layer_type
=
layer_type
)
...
...
vllm/model_executor/models/llama_eagle.py
0 → 100644
View file @
9c4ecf15
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
Set
,
Tuple
import
torch
import
torch.nn
as
nn
from
transformers
import
LlamaConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
(
LlamaDecoderLayer
,
LlamaForCausalLM
)
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
logger
=
init_logger
(
__name__
)
class
LlamaDecoderLayer
(
LlamaDecoderLayer
):
def
__init__
(
self
,
config
:
LlamaConfig
,
disable_input_layernorm
:
bool
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
(
config
,
prefix
=
prefix
)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
if
disable_input_layernorm
:
del
self
.
input_layernorm
self
.
input_layernorm
=
nn
.
Identity
()
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
model_config
:
ModelConfig
,
start_layer_id
:
int
=
0
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
model_config
.
hf_config
self
.
vocab_size
=
self
.
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
prefix
=
maybe_prefix
(
prefix
,
"embed_tokens"
),
)
self
.
layers
=
nn
.
ModuleList
([
LlamaDecoderLayer
(
self
.
config
,
i
==
0
,
prefix
=
maybe_prefix
(
prefix
,
f
"layers.
{
i
+
start_layer_id
}
"
),
)
for
i
in
range
(
self
.
config
.
num_hidden_layers
)
])
self
.
fc
=
torch
.
nn
.
Linear
(
self
.
config
.
hidden_size
*
2
,
self
.
config
.
hidden_size
,
bias
=
False
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
input_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
fc
(
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
)
return
hidden_states
+
residual
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
EagleLlamaForCausalLM
(
LlamaForCausalLM
):
def
__init__
(
self
,
*
,
model_config
:
ModelConfig
,
start_layer_id
:
int
=
0
):
nn
.
Module
.
__init__
(
self
)
self
.
config
=
model_config
.
hf_config
self
.
model
=
LlamaModel
(
model_config
=
model_config
,
start_layer_id
=
start_layer_id
,
prefix
=
"model"
)
logit_scale
=
getattr
(
self
.
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
,
scale
=
logit_scale
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
return
self
.
model
(
input_ids
,
positions
,
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
model_weights
=
{}
for
name
,
loaded_weight
in
weights
:
if
"lm_head"
not
in
name
:
name
=
"model."
+
name
model_weights
[
name
]
=
loaded_weight
loader
.
load_weights
(
model_weights
.
items
())
vllm/model_executor/models/llava.py
View file @
9c4ecf15
...
@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
...
@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
,
BaseProcessingInfo
,
ProcessingCache
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.clip
import
CLIPVisionModel
from
.clip
import
CLIPVisionModel
...
@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
...
@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
(
get_vision_encoder_info
,
scatter_patch_features
,
from
.vision
import
get_vision_encoder_info
select_patch_features
)
class
LlavaImagePixelInputs
(
TypedDict
):
class
LlavaImagePixelInputs
(
TypedDict
):
...
@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
...
@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
in which case the data is passed as a list instead of a batched tensor.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
LlavaImageEmbeddingInputs
(
TypedDict
):
class
LlavaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
...
@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
...
@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
_apply_feature_select_strategy
(
def
_apply_feature_select_strategy
(
self
,
self
,
strategy
:
str
,
strategy
:
str
,
...
@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
...
@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class
LlavaDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
class
LlavaDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
LlavaProcessingInfo
(
BaseLlavaProcessingInfo
):
class
LlavaProcessingInfo
(
BaseLlavaProcessingInfo
):
...
@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor(
...
@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor(
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
]
]
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
encoder_info
=
PixtralHFEncoderInfo
(
vision_config
)
tile_sizes
=
[
encoder_info
.
get_patch_grid_size
(
image_width
=
pixel_value
.
shape
[
-
1
],
image_height
=
pixel_value
.
shape
[
-
2
],
)
for
pixel_value
in
processed_outputs
[
"pixel_values"
]
]
embed_is_patch
=
[
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
for
ncols
,
nrows
in
tile_sizes
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
...
@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor(
...
@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor(
...
@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor(
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
[
-
1
]
=
image_end_id
tokens
[
-
1
]
=
image_end_id
return
tokens
return
PromptUpdateDetails
.
select_token_id
(
tokens
,
image_token_id
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
self
.
config
.
vision_config
.
model_type
==
"pixtral"
:
if
self
.
config
.
vision_config
.
model_type
==
"pixtral"
:
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
PixtralHFImagePixelInputs
(
return
PixtralHFImagePixelInputs
(
type
=
"pixel_values_pixtral"
,
type
=
"pixel_values_pixtral"
,
pixel_values
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
embed_is_patch
=
embed_is_patch
,
)
)
return
LlavaImagePixelInputs
(
return
LlavaImagePixelInputs
(
...
@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
return
image_embeds
return
image_embeds
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
if
image_input
[
"type"
]
!=
"pixel_values_pixtral"
:
# The path is used for pixtral (V0 only) and llava (V0/V1)
return
image_features
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
self
.
config
.
image_token_index
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/llava_next.py
View file @
9c4ecf15
...
@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for
i
,
patch_features_batch
in
enumerate
(
patch_embeddings
)
for
i
,
patch_features_batch
in
enumerate
(
patch_embeddings
)
]
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
...
vllm/model_executor/models/llava_next_video.py
View file @
9c4ecf15
...
@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...
@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageSize
,
MultiModalDataItems
,
from
vllm.multimodal.parse
import
(
ImageSize
,
MultiModalDataItems
,
VideoEmbeddingItems
,
VideoProcessorItems
)
VideoEmbeddingItems
,
VideoProcessorItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
...
@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
...
@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"video"
:
1
}
return
{
"video"
:
1
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
max_video_tokens
=
self
.
get_num_video_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
num_frames
=
self
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
),
)
return
{
"video"
:
max_video_tokens
}
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
vision_encoder_info
=
self
.
get_vision_encoder_info
()
vision_encoder_info
=
self
.
get_vision_encoder_info
()
width
=
height
=
vision_encoder_info
.
get_image_size
()
width
=
height
=
vision_encoder_info
.
get_image_size
()
...
@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
...
@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
class
LlavaNextVideoDummyInputsBuilder
(
class
LlavaNextVideoDummyInputsBuilder
(
BaseDummyInputsBuilder
[
LlavaNextVideoProcessingInfo
]):
BaseDummyInputsBuilder
[
LlavaNextVideoProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
processor
=
self
.
info
.
get_hf_processor
()
video_token
=
processor
.
video_token
video_token
=
processor
.
video_token
return
video_token
*
num_videos
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalDataDict
:
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
target_num_frames
=
\
target_num_frames
=
\
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
mm_data
=
{
return
{
"video"
:
"video"
:
self
.
_get_dummy_videos
(
self
.
_get_dummy_videos
(
width
=
target_width
,
width
=
target_width
,
...
@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder(
...
@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder(
)
)
}
}
return
ProcessorInputs
(
prompt_text
=
video_token
*
num_videos
,
mm_data
=
mm_data
,
)
class
LlavaNextVideoMultiModalProcessor
(
class
LlavaNextVideoMultiModalProcessor
(
BaseMultiModalProcessor
[
LlavaNextVideoProcessingInfo
]):
BaseMultiModalProcessor
[
LlavaNextVideoProcessingInfo
]):
...
@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return
[
e
.
flatten
(
0
,
1
)
for
e
in
embeds
]
return
[
e
.
flatten
(
0
,
1
)
for
e
in
embeds
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
...
...
vllm/model_executor/models/llava_onevision.py
View file @
9c4ecf15
...
@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn
...
@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageSize
,
MultiModalDataItems
,
from
vllm.multimodal.parse
import
(
ImageSize
,
MultiModalDataItems
,
VideoEmbeddingItems
,
VideoProcessorItems
)
VideoEmbeddingItems
,
VideoProcessorItems
)
from
vllm.multimodal.processing
import
PromptReplacement
,
PromptUpdate
from
vllm.multimodal.processing
import
PromptReplacement
,
PromptUpdate
from
vllm.multimodal.profiling
import
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.clip
import
CLIPVisionModel
from
.clip
import
CLIPVisionModel
...
@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
...
@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
,
"video"
:
None
}
return
{
"image"
:
None
,
"video"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
(),
"video"
:
self
.
get_max_video_tokens
(
seq_len
,
mm_counts
),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
# with additional logic afterwards taken from LlavaOnevisionProcessor
def
_get_num_unpadded_features
(
def
_get_num_unpadded_features
(
...
@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
...
@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
class
LlavaOnevisionDummyInputsBuilder
(
class
LlavaOnevisionDummyInputsBuilder
(
LlavaDummyInputsBuilder
[
LlavaOnevisionProcessingInfo
]):
LlavaDummyInputsBuilder
[
LlavaOnevisionProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
...
@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder(
...
@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder(
image_token
=
processor
.
image_token
image_token
=
processor
.
image_token
video_token
=
processor
.
video_token
video_token
=
processor
.
video_token
return
image_token
*
num_images
+
video_token
*
num_videos
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
target_num_frames
=
\
target_num_frames
=
\
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
mm_counts
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
...
@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder(
...
@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder(
)
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
+
video_token
*
num_videos
,
mm_data
=
mm_data
,
)
class
LlavaOnevisionMultiModalProcessor
(
class
LlavaOnevisionMultiModalProcessor
(
BaseLlavaNextMultiModalProcessor
[
LlavaOnevisionProcessingInfo
]):
BaseLlavaNextMultiModalProcessor
[
LlavaOnevisionProcessingInfo
]):
...
@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature
=
image_feature
.
view
(
batch_frames
,
-
1
,
dim
)
image_feature
=
image_feature
.
view
(
batch_frames
,
-
1
,
dim
)
return
image_feature
return
image_feature
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
...
...
vllm/model_executor/models/mamba2.py
View file @
9c4ecf15
...
@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
...
@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba2_metadata
import
(
Mamba2Metadata
,
prepare_mamba2_metadata
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
(
MambaMixer2
,
extra_groups_for_head_shards
)
MambaMixer2
,
extra_groups_for_head_shards
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
head_dim
=
config
.
head_dim
,
head_dim
=
config
.
head_dim
,
rms_norm_eps
=
config
.
layer_norm_epsilon
,
rms_norm_eps
=
config
.
layer_norm_epsilon
,
activation
=
config
.
hidden_act
,
activation
=
config
.
hidden_act
,
chunk_size
=
config
.
chunk_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
mamba_cache_params
:
MambaCacheParams
,
mamba_cache_params
:
MambaCacheParams
,
sequence_idx
:
Optional
[
torch
.
Tensor
]
,
mamba2_metadata
:
Mamba2Metadata
,
**
kwargs
,
**
kwargs
,
):
):
if
residual
is
None
:
if
residual
is
None
:
...
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
...
@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
residual
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
hidden_states
=
self
.
mixer
(
hidden_states
,
mamba_cache_params
,
sequence_idx
)
mamba2_metadata
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
...
@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx
=
None
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
attn_metadata
:
AttentionMetadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
num_prefills
>
0
:
seq_idx
=
torch
.
zeros_like
(
input_ids
,
dtype
=
torch
.
int32
)
mamba2_metadata
=
prepare_mamba2_metadata
(
for
i
,
(
srt
,
end
)
in
enumerate
(
chunk_size
=
self
.
config
.
chunk_size
,
zip
(
input_ids
=
input_ids
,
attn_metadata
.
query_start_loc
,
attn_metadata
=
attn_metadata
,
attn_metadata
.
query_start_loc
[
1
:],
)
)):
seq_idx
[
srt
:
end
]
=
i
seq_idx
.
unsqueeze_
(
0
)
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
...
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
...
@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
residual
=
residual
,
residual
=
residual
,
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
mamba_cache_params
=
mamba_cache_params
.
at_layer_idx
(
i
-
self
.
start_layer
),
i
-
self
.
start_layer
),
sequence_idx
=
seq_idx
)
mamba2_metadata
=
mamba2_metadata
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
return
IntermediateTensors
({
...
...
vllm/model_executor/models/minicpmo.py
View file @
9c4ecf15
...
@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import (
...
@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import (
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
NestedTensors
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
AudioItem
,
AudioProcessorItems
,
from
vllm.multimodal.parse
import
(
AudioItem
,
AudioProcessorItems
,
DictEmbeddingItems
,
ModalityData
,
DictEmbeddingItems
,
ModalityData
,
ModalityDataItems
,
MultiModalDataItems
,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
MultiModalDataParser
)
from
vllm.multimodal.processing
import
PromptReplacement
,
PromptUpdate
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
from
vllm.multimodal.profiling
import
ProcessorInputs
PromptUpdateDetails
)
from
.minicpmv
import
(
_MAX_FRAMES_PER_VIDEO
,
MiniCPMV2_6
,
from
.minicpmv
import
(
_MAX_FRAMES_PER_VIDEO
,
MiniCPMV2_6
,
MiniCPMVDummyInputsBuilder
,
MiniCPMVDummyInputsBuilder
,
...
@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
...
@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config
)
_minicpmv_field_config
)
from
.utils
import
(
AutoWeightsLoader
,
cast_overflow_tensors
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
cast_overflow_tensors
,
flatten_bn
,
maybe_prefix
)
maybe_prefix
)
from
.vision
import
scatter_patch_features
CPU_DEVICE
=
torch
.
device
(
"cpu"
)
CPU_DEVICE
=
torch
.
device
(
"cpu"
)
...
@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
...
@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]`
which equals to `audio_features.shape[-1]`
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class
MiniCPMOAudioEmbeddingInputs
(
TypedDict
):
class
MiniCPMOAudioEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"audio_embeds"
]
type
:
Literal
[
"audio_embeds"
]
...
@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
...
@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list.
Length of each slice may vary, so pass it as a list.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs
=
Union
[
MiniCPMOAudioFeatureInputs
,
MiniCPMOAudioInputs
=
Union
[
MiniCPMOAudioFeatureInputs
,
MiniCPMOAudioEmbeddingInputs
]
MiniCPMOAudioEmbeddingInputs
]
...
@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
...
@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_features
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_feature_lens
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_feature_lens
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embeds
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"audio"
),
audio_token_id
=
MultiModalFieldConfig
.
shared
(
"audio"
,
num_audios
),
audio_token_id
=
MultiModalFieldConfig
.
shared
(
"audio"
,
num_audios
),
)
)
...
@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
...
@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def
_parse_audio_data
(
def
_parse_audio_data
(
self
,
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
AudioItem
]],
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
AudioItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
return
MiniCPMOAudioEmbeddingItems
(
return
MiniCPMOAudioEmbeddingItems
(
data
,
data
,
...
@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
...
@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
**
super
().
get_supported_mm_limits
(),
"audio"
:
None
}
return
{
**
super
().
get_supported_mm_limits
(),
"audio"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
**
super
().
get_mm_max_tokens_per_item
(
seq_len
,
mm_counts
),
"audio"
:
self
.
get_max_audio_tokens
(),
}
def
get_audio_placeholder
(
def
get_audio_placeholder
(
self
,
self
,
audio_lens
:
int
,
audio_lens
:
int
,
...
@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
...
@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step
=
self
.
get_default_audio_pool_step
()
pool_step
=
self
.
get_default_audio_pool_step
()
fbank_feat_in_chunk
=
100
fbank_feat_in_chunk
=
100
cnn_feat_in_chunk
=
(
fbank_feat_in_chunk
-
1
)
//
2
+
1
cnn_feat_in_chunk
=
(
fbank_feat_in_chunk
-
1
)
//
2
+
1
num_audio_tokens
=
(
cnn_feat_in_chunk
-
pool_step
)
//
pool_step
+
1
return
(
cnn_feat_in_chunk
-
pool_step
)
//
pool_step
+
1
return
num_audio_tokens
+
2
# <audio>(<unk>*N)</audio>
def
get_max_audio_chunks_with_most_features
(
self
)
->
int
:
def
get_max_audio_chunks_with_most_features
(
self
)
->
int
:
return
30
return
30
...
@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
...
@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def
get_audio_len_by_num_chunks
(
self
,
num_chunks
:
int
)
->
int
:
def
get_audio_len_by_num_chunks
(
self
,
num_chunks
:
int
)
->
int
:
sampling_rate
=
self
.
get_default_audio_sampling_rate
()
sampling_rate
=
self
.
get_default_audio_sampling_rate
()
# exclude <audio> </audio>
num_tokens_per_chunk
=
self
.
get_max_audio_tokens_per_chunk
()
num_tokens_per_chunk
=
self
.
get_max_audio_tokens_per_chunk
()
-
2
return
int
(
num_chunks
*
sampling_rate
/
num_tokens_per_chunk
)
+
1
return
int
(
num_chunks
*
sampling_rate
/
num_tokens_per_chunk
)
+
1
def
get_num_frames_with_most_features
(
def
get_num_frames_with_most_features
(
...
@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
...
@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
class
MiniCPMODummyInputsBuilder
(
class
MiniCPMODummyInputsBuilder
(
MiniCPMVDummyInputsBuilder
[
MiniCPMOProcessingInfo
]):
MiniCPMVDummyInputsBuilder
[
MiniCPMOProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
int
])
->
ProcessorInputs
:
audio_prompt_texts
=
self
.
info
.
audio_pattern
*
num_audios
return
super
().
get_dummy_text
(
mm_counts
)
+
audio_prompt_texts
def
get_dummy_mm_data
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalDataDict
:
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
num_audios
=
mm_counts
.
get
(
"audio"
,
0
)
audio_len
=
self
.
info
.
get_max_audio_chunks_with_most_features
()
*
\
audio_len
=
self
.
info
.
get_max_audio_chunks_with_most_features
()
*
\
self
.
info
.
get_default_audio_sampling_rate
()
self
.
info
.
get_default_audio_sampling_rate
()
processor_inputs
=
super
().
get_dummy_processor_inputs
(
seq_len
,
mm_counts
)
audio_prompt_texts
=
self
.
info
.
audio_pattern
*
num_audios
audio_mm_data
=
{
audio_mm_data
=
{
"audio"
:
"audio"
:
self
.
_get_dummy_audios
(
length
=
audio_len
,
num_audios
=
num_audios
)
self
.
_get_dummy_audios
(
length
=
audio_len
,
num_audios
=
num_audios
)
}
}
return
ProcessorInputs
(
return
{
prompt_text
=
processor_inputs
.
prompt_text
+
audio_prompt_texts
,
**
super
().
get_dummy_mm_data
(
seq_len
,
mm_counts
),
mm_data
=
{
**
audio_mm_data
,
**
processor_inputs
.
mm_data
,
}
**
audio_mm_data
,
},
)
class
MiniCPMOMultiModalProcessor
(
class
MiniCPMOMultiModalProcessor
(
...
@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor(
...
@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor(
if
isinstance
(
parsed_audios
,
MiniCPMOAudioEmbeddingItems
):
if
isinstance
(
parsed_audios
,
MiniCPMOAudioEmbeddingItems
):
audio_inputs
=
{}
audio_inputs
=
{}
audio_lens
=
[
self
.
info
.
get_audio_len_by_num_chunks
(
sum
(
map
(
len
,
parsed_audios
.
get
(
i
)[
"audio_embeds"
])))
for
i
in
range
(
len
(
parsed_audios
))
]
else
:
else
:
audio_inputs
=
self
.
_base_call_hf_processor
(
audio_inputs
=
self
.
_base_call_hf_processor
(
prompts
=
[
self
.
info
.
audio_pattern
]
*
len
(
parsed_audios
),
prompts
=
[
self
.
info
.
audio_pattern
]
*
len
(
parsed_audios
),
...
@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor(
...
@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor(
]
]
audio_inputs
[
"audio_features"
]
=
unpadded_audio_features
audio_inputs
[
"audio_features"
]
=
unpadded_audio_features
audio_lens
=
[
parsed_audios
.
get_audio_length
(
i
)
for
i
in
range
(
len
(
parsed_audios
))
]
audio_repl_features
=
[
self
.
get_audio_prompt_texts
(
audio_len
)
for
audio_len
in
audio_lens
]
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
audio_repls_feature_tokens
=
[
tokenizer
.
encode
(
audio_repl
,
add_special_tokens
=
False
)
for
audio_repl
in
audio_repl_features
]
embed_is_patch
=
[
self
.
get_embed_is_patch
(
audio_repl_tokens
)
for
audio_repl_tokens
in
audio_repls_feature_tokens
]
audio_inputs
[
"audio_embed_is_patch"
]
=
embed_is_patch
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
audio_inputs
[
"audio_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
audio_inputs
[
"audio_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
...
@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor(
...
@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor(
else
:
else
:
audio_len
=
audios
.
get_audio_length
(
item_idx
)
audio_len
=
audios
.
get_audio_length
(
item_idx
)
return
self
.
get_audio_prompt_texts
(
audio_len
)
return
PromptUpdateDetails
.
select_text
(
self
.
get_audio_prompt_texts
(
audio_len
),
"<unk>"
,
)
return
[
return
[
*
base_updates
,
*
base_updates
,
...
@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6):
assert
isinstance
(
audio_token_id
,
torch
.
Tensor
)
assert
isinstance
(
audio_token_id
,
torch
.
Tensor
)
self
.
mm_token_ids
.
add
(
audio_token_id
.
flatten
().
unique
().
item
())
self
.
mm_token_ids
.
add
(
audio_token_id
.
flatten
().
unique
().
item
())
audio_embed_is_patch
=
kwargs
.
pop
(
"audio_embed_is_patch"
)
if
not
isinstance
(
audio_embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio_embed_is_patch. "
f
"Got type:
{
type
(
audio_embed_is_patch
)
}
"
)
audio_embed_is_patch
=
flatten_bn
(
audio_embed_is_patch
)
if
audio_embeds
is
not
None
:
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of audio_embeds. "
raise
ValueError
(
"Incorrect type of audio_embeds. "
...
@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6):
return
MiniCPMOAudioEmbeddingInputs
(
return
MiniCPMOAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
type
=
"audio_embeds"
,
audio_embeds
=
audio_embeds_flat
,
audio_embeds
=
audio_embeds_flat
,
embed_is_patch
=
audio_embed_is_patch
,
)
)
if
not
isinstance
(
audio_features
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
audio_features
,
(
torch
.
Tensor
,
list
)):
...
@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6):
type
=
"audio_features"
,
type
=
"audio_features"
,
audio_features
=
audio_features_flat
,
audio_features
=
audio_features_flat
,
audio_feature_lens
=
audio_feature_lens_flat
,
audio_feature_lens
=
audio_feature_lens_flat
,
embed_is_patch
=
audio_embed_is_patch
,
)
)
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
...
@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6):
...
@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6):
if
modality
==
"audios"
:
if
modality
==
"audios"
:
audio_input
=
modalities
[
"audios"
]
audio_input
=
modalities
[
"audios"
]
audio_features
=
self
.
_process_audio_input
(
audio_input
)
audio_features
=
self
.
_process_audio_input
(
audio_input
)
multimodal_embeddings
+=
tuple
(
multimodal_embeddings
+=
tuple
(
audio_features
)
scatter_patch_features
(
audio_features
,
audio_input
[
"embed_is_patch"
],
))
return
multimodal_embeddings
return
multimodal_embeddings
vllm/model_executor/models/minicpmv.py
View file @
9c4ecf15
...
@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
...
@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from
vllm.model_executor.models.qwen2
import
Qwen2ForCausalLM
from
vllm.model_executor.models.qwen2
import
Qwen2ForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
NestedTensors
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageItem
,
from
vllm.multimodal.parse
import
(
DictEmbeddingItems
,
ImageItem
,
ImageProcessorItems
,
ImageSize
,
ImageProcessorItems
,
ImageSize
,
ModalityData
,
ModalityDataItems
,
ModalityData
,
ModalityDataItems
,
...
@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
...
@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem
,
VideoProcessorItems
)
VideoItem
,
VideoProcessorItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
flatten_2d_lists
from
vllm.utils
import
flatten_2d_lists
...
@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...
@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal
,
SupportsPP
)
SupportsMultiModal
,
SupportsPP
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# For profile run
# For profile run
_MAX_FRAMES_PER_VIDEO
=
16
_MAX_FRAMES_PER_VIDEO
=
16
...
@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
...
@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format.
This should be in `(height, width)` format.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices
:
torch
.
Tensor
num_slices
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
...
@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
...
@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor.
instead of a batched tensor.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageEmbeddingInputs
]
MiniCPMVImageEmbeddingInputs
]
...
@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
...
@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
video_pixel_values
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_pixel_values
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embeds
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embeds
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"video"
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
video_token_id
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
video_token_id
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
)
)
...
@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
...
@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def
_parse_image_data
(
def
_parse_image_data
(
self
,
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
ImageItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
return
MiniCPMVImageEmbeddingItems
(
return
MiniCPMVImageEmbeddingItems
(
data
,
data
,
...
@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
...
@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def
_parse_video_data
(
def
_parse_video_data
(
self
,
self
,
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
data
:
Union
[
dict
[
str
,
torch
.
Tensor
],
ModalityData
[
VideoItem
]],
)
->
ModalityDataItems
[
Any
,
Any
]:
)
->
Optional
[
ModalityDataItems
[
Any
,
Any
]
]
:
if
isinstance
(
data
,
dict
):
if
isinstance
(
data
,
dict
):
return
MiniCPMVVideoEmbeddingItems
(
return
MiniCPMVVideoEmbeddingItems
(
data
,
data
,
...
@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
...
@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return
mm_limits
return
mm_limits
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
mm_max_tokens
=
{
"image"
:
self
.
get_max_image_tokens
()}
if
self
.
get_model_version
()
==
(
2
,
6
):
mm_max_tokens
[
"video"
]
=
self
.
get_max_video_tokens
(
seq_len
,
mm_counts
)
return
mm_max_tokens
def
get_slice_image_placeholder
(
def
get_slice_image_placeholder
(
self
,
self
,
image_size
:
ImageSize
,
image_size
:
ImageSize
,
...
@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
...
@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id
=
use_image_id
,
use_image_id
=
use_image_id
,
)
)
def
get_sliced_grid
(
self
,
image_size
:
ImageSize
,
# For MiniCPM V/O 2.6
max_slice_nums
:
Optional
[
int
]
=
None
,
)
->
Optional
[
tuple
[
int
,
int
]]:
image_processor
=
self
.
get_image_processor
()
version
=
self
.
get_model_version
()
if
version
==
(
2
,
0
)
or
version
==
(
2
,
5
):
return
image_processor
.
get_sliced_grid
(
image_size
)
if
max_slice_nums
is
None
:
max_slice_nums
=
image_processor
.
max_slice_nums
return
image_processor
.
get_sliced_grid
(
image_size
,
max_slice_nums
=
max_slice_nums
,
)
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
image_size
:
ImageSize
,
image_size
:
ImageSize
,
max_slice_nums
:
Optional
[
int
]
=
None
,
max_slice_nums
:
Optional
[
int
]
=
None
,
use_image_id
:
bool
=
True
,
)
->
int
:
)
->
int
:
tokenizer
=
self
.
get_tokenizer
()
image_processor
=
self
.
get_image_processor
()
image_placeholders
=
self
.
get_slice_image_placeholder
(
grid
=
self
.
get_sliced_grid
(
image_size
,
image_size
,
max_slice_nums
=
max_slice_nums
,
max_slice_nums
=
max_slice_nums
,
use_image_id
=
use_image_id
,
)
)
image_token_ids
=
tokenizer
.
encode
(
image_placeholders
,
if
grid
is
None
:
add_special_tokens
=
False
)
ncols
=
nrows
=
0
else
:
ncols
,
nrows
=
grid
return
len
(
image_token_ids
)
return
(
ncols
*
nrows
+
1
)
*
image_processor
.
image_feature_size
def
get_max_image_tokens
(
self
)
->
int
:
def
get_max_image_tokens
(
self
)
->
int
:
image_size
=
self
.
get_image_size_with_most_features
()
image_size
=
self
.
get_image_size_with_most_features
()
...
@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
...
@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return
self
.
get_num_image_tokens
(
return
self
.
get_num_image_tokens
(
frame_size
,
frame_size
,
max_slice_nums
=
self
.
get_video_max_slice_num
(),
max_slice_nums
=
self
.
get_video_max_slice_num
(),
use_image_id
=
False
,
)
)
def
get_max_video_tokens
(
def
get_max_video_tokens
(
...
@@ -482,11 +472,20 @@ _I = TypeVar("_I",
...
@@ -482,11 +472,20 @@ _I = TypeVar("_I",
class
MiniCPMVDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
class
MiniCPMVDummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
image_prompt_texts
=
self
.
info
.
image_pattern
*
num_images
video_prompt_texts
=
self
.
info
.
video_pattern
*
num_videos
return
image_prompt_texts
+
video_prompt_texts
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
num_videos
=
mm_counts
.
get
(
"video"
,
0
)
...
@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
...
@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
num_video_frames
=
\
num_video_frames
=
\
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
self
.
info
.
get_num_frames_with_most_features
(
seq_len
,
mm_counts
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
image_width
,
self
.
_get_dummy_images
(
width
=
image_width
,
height
=
image_height
,
height
=
image_height
,
...
@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
...
@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
]
*
num_videos
,
]
*
num_videos
,
}
}
image_prompt_texts
=
self
.
info
.
image_pattern
*
num_images
video_prompt_texts
=
self
.
info
.
video_pattern
*
num_videos
return
ProcessorInputs
(
prompt_text
=
image_prompt_texts
+
video_prompt_texts
,
mm_data
=
mm_data
)
class
MiniCPMVMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
class
MiniCPMVMultiModalProcessor
(
BaseMultiModalProcessor
[
_I
]):
...
@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id
=
False
,
use_image_id
=
False
,
)
*
num_frames
)
*
num_frames
def
get_embed_is_patch
(
self
,
input_ids
:
list
[
int
],
)
->
torch
.
Tensor
:
tokenizer
=
self
.
info
.
get_tokenizer
()
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
return
torch
.
tensor
(
input_ids
)
==
unk_token_id
def
process_images
(
def
process_images
(
self
,
self
,
mm_data
:
Mapping
[
str
,
object
],
mm_data
:
Mapping
[
str
,
object
],
...
@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
)
)
image_sizes
=
[
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
]
image_repl_features
=
[
self
.
get_image_prompt_texts
(
size
,
idx
)
for
idx
,
size
in
enumerate
(
image_sizes
)
]
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
embed_is_patch
=
[
self
.
get_embed_is_patch
(
image_repl_tokens
)
for
image_repl_tokens
in
image_repls_feature_tokens
]
image_inputs
[
"embed_is_patch"
]
=
embed_is_patch
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
image_inputs
[
"image_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
image_inputs
[
"image_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
...
@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
)
)
frame_sizes
=
[
parsed_videos
.
get_frame_size
(
i
)
for
i
in
range
(
len
(
parsed_videos
))
]
num_frames
=
[
parsed_videos
.
get_num_frames
(
i
)
for
i
in
range
(
len
(
parsed_videos
))
]
video_repl_features
=
[
self
.
get_video_prompt_texts
(
size
,
nframes
)
for
size
,
nframes
in
zip
(
frame_sizes
,
num_frames
)
]
tokenizer
=
self
.
info
.
get_tokenizer
()
video_repls_feature_tokens
=
[
tokenizer
.
encode
(
video_repl
,
add_special_tokens
=
False
)
for
video_repl
in
video_repl_features
]
embed_is_patch
=
[
self
.
get_embed_is_patch
(
video_repl_tokens
)
for
video_repl_tokens
in
video_repls_feature_tokens
]
video_inputs
[
"embed_is_patch"
]
=
embed_is_patch
video_inputs
=
{
f
"video_
{
k
}
"
:
v
for
k
,
v
in
video_inputs
.
items
()}
video_inputs
=
{
f
"video_
{
k
}
"
:
v
for
k
,
v
in
video_inputs
.
items
()}
tokenizer
=
self
.
info
.
get_tokenizer
()
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
video_inputs
[
"video_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
video_inputs
[
"video_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
...
@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
return
self
.
get_image_prompt_texts
(
image_size
,
item_idx
)
return
PromptUpdateDetails
.
select_text
(
self
.
get_image_prompt_texts
(
image_size
,
item_idx
),
"<unk>"
,
)
def
get_video_replacement
(
item_idx
:
int
):
def
get_video_replacement
(
item_idx
:
int
):
videos
=
mm_items
.
get_items
(
videos
=
mm_items
.
get_items
(
...
@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size
=
videos
.
get_frame_size
(
item_idx
)
frame_size
=
videos
.
get_frame_size
(
item_idx
)
num_frames
=
videos
.
get_num_frames
(
item_idx
)
num_frames
=
videos
.
get_num_frames
(
item_idx
)
return
self
.
get_video_prompt_texts
(
frame_size
,
num_frames
)
return
PromptUpdateDetails
.
select_text
(
self
.
get_video_prompt_texts
(
frame_size
,
num_frames
),
"<unk>"
,
)
get_replacement
=
{
get_replacement
=
{
"image"
:
get_image_replacement
,
"image"
:
get_image_replacement
,
...
@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert
isinstance
(
image_token_id
,
torch
.
Tensor
)
assert
isinstance
(
image_token_id
,
torch
.
Tensor
)
self
.
mm_token_ids
.
add
(
image_token_id
.
flatten
().
unique
().
item
())
self
.
mm_token_ids
.
add
(
image_token_id
.
flatten
().
unique
().
item
())
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of embed_is_patch for
{
modality
=
}
. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
raise
ValueError
(
...
@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return
MiniCPMVImageEmbeddingInputs
(
return
MiniCPMVImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
image_embeds
=
image_embeds_flat
,
image_embeds
=
image_embeds_flat
,
embed_is_patch
=
embed_is_patch
,
)
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
...
@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
pixel_values_flat
,
pixel_values
=
pixel_values_flat
,
tgt_sizes
=
tgt_sizes_flat
,
tgt_sizes
=
tgt_sizes_flat
,
embed_is_patch
=
embed_is_patch
,
num_slices
=
num_slices_flat
,
num_slices
=
num_slices_flat
,
)
)
...
@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if
modality
==
"images"
:
if
modality
==
"images"
:
image_input
=
modalities
[
"images"
]
image_input
=
modalities
[
"images"
]
image_features
=
self
.
_process_vision_input
(
image_input
)
image_features
=
self
.
_process_vision_input
(
image_input
)
multimodal_embeddings
+=
tuple
(
multimodal_embeddings
+=
tuple
(
image_features
)
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
))
if
modality
==
"videos"
:
if
modality
==
"videos"
:
video_input
=
modalities
[
"videos"
]
video_input
=
modalities
[
"videos"
]
video_features
=
self
.
_process_vision_input
(
video_input
)
video_features
=
self
.
_process_vision_input
(
video_input
)
multimodal_embeddings
+=
tuple
(
multimodal_embeddings
+=
tuple
(
video_features
)
scatter_patch_features
(
video_features
,
video_input
[
"embed_is_patch"
],
))
return
multimodal_embeddings
return
multimodal_embeddings
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
llm
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
...
@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
list
(
self
.
mm_token_ids
),
list
(
self
.
mm_token_ids
),
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/mistral3.py
View file @
9c4ecf15
...
@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
,
BaseProcessingInfo
,
ProcessingCache
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
(
get_vision_encoder_info
,
scatter_patch_features
,
from
.vision
import
get_vision_encoder_info
select_patch_features
)
class
Mistral3ImagePixelInputs
(
TypedDict
):
class
Mistral3ImagePixelInputs
(
TypedDict
):
...
@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict):
...
@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
in which case the data is passed as a list instead of a batched tensor.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
class
Mistral3PatchMerger
(
nn
.
Module
):
class
Mistral3PatchMerger
(
nn
.
Module
):
"""
"""
...
@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
...
@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
*
,
*
,
...
@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
...
@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width
=
height
=
vision_encoder_info
.
get_image_size
()
width
=
height
=
vision_encoder_info
.
get_image_size
()
return
ImageSize
(
width
=
width
,
height
=
height
)
return
ImageSize
(
width
=
width
,
height
=
height
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
)
_I
=
TypeVar
(
"_I"
,
bound
=
BaseLlavaProcessingInfo
)
_I
=
TypeVar
(
"_I"
,
bound
=
BaseLlavaProcessingInfo
)
class
Mistral3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
class
Mistral3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
_I
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
Mistral3ProcessingInfo
(
BaseLlavaProcessingInfo
):
class
Mistral3ProcessingInfo
(
BaseLlavaProcessingInfo
):
...
@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor(
...
@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor(
p
[:,
:
h
,
:
w
]
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
p
[:,
:
h
,
:
w
]
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
]
]
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
encoder_info
=
PixtralHFEncoderInfo
(
vision_config
)
tile_sizes
=
[
encoder_info
.
get_patch_grid_size
(
image_width
=
pixel_value
.
shape
[
-
1
],
image_height
=
pixel_value
.
shape
[
-
2
],
)
for
pixel_value
in
processed_outputs
[
"pixel_values"
]
]
embed_is_patch
=
[
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
for
ncols
,
nrows
in
tile_sizes
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
...
@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor(
...
@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor(
...
@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor(
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
[
-
1
]
=
image_end_id
tokens
[
-
1
]
=
image_end_id
return
tokens
return
PromptUpdateDetails
.
select_token_id
(
tokens
,
image_token_id
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -418,8 +379,6 @@ def init_vision_tower_for_llava(
...
@@ -418,8 +379,6 @@ def init_vision_tower_for_llava(
)
)
# TODO(mgoin): Support V1, there are issues with image batching/chunking
# that need to be resolved first.
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
_build_mistral3_processor
,
_build_mistral3_processor
,
info
=
_build_mistral3_info
,
info
=
_build_mistral3_info
,
...
@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
assert
self
.
config
.
vision_config
.
model_type
==
"pixtral"
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
return
Mistral3ImagePixelInputs
(
return
Mistral3ImagePixelInputs
(
type
=
"pixel_values_pixtral"
,
type
=
"pixel_values_pixtral"
,
pixel_values
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
embed_is_patch
=
flatten_bn
(
embed_is_patch
),
)
)
def
_process_image_input
(
def
_process_image_input
(
...
@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds
=
(
image_embeds
,
)
image_embeds
=
(
image_embeds
,
)
return
image_embeds
return
image_embeds
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
return
vision_embeddings
vision_embeddings
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
self
.
config
.
image_token_index
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/mixtral.py
View file @
9c4ecf15
...
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -49,7 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -260,6 +260,8 @@ class MixtralModel(nn.Module):
...
@@ -260,6 +260,8 @@ class MixtralModel(nn.Module):
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
@@ -313,88 +315,6 @@ class MixtralModel(nn.Module):
...
@@ -313,88 +315,6 @@ class MixtralModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
# LoRA specific attributes
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
MixtralModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
stacked_params_mapping
=
[
...
@@ -415,9 +335,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -415,9 +335,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
# Loading kv cache quantization scales
...
@@ -489,3 +406,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -489,3 +406,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
return
loaded_params
return
loaded_params
class
MixtralForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
fall_back_to_pt_during_load
=
False
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
}
# LoRA specific attributes
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
quant_config
=
quant_config
self
.
model
=
MixtralModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
quant_config
=
quant_config
,
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"rotary_emb.inv_freq"
])
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/mixtral_quant.py
View file @
9c4ecf15
...
@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
...
@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
name
.
endswith
(
"scale"
):
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
...
vllm/model_executor/models/mllama.py
View file @
9c4ecf15
...
@@ -52,16 +52,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -52,16 +52,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalEncDecInputs
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalKwargs
)
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataDict
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
.clip
import
CLIPMLP
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
,
SupportsV0Only
from
.interfaces
import
SupportsMultiModal
,
SupportsV0Only
...
@@ -106,16 +107,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
...
@@ -106,16 +107,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size
=
self
.
get_hf_config
().
vision_config
.
image_size
image_size
=
self
.
get_hf_config
().
vision_config
.
image_size
return
calc_token_per_chunk
(
image_size
)
return
calc_token_per_chunk
(
image_size
)
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
vision_config
=
self
.
get_hf_config
().
vision_config
token_per_chunk
=
self
.
get_token_per_chunk_from_config
()
mm_max_tokens
=
vision_config
.
max_num_tiles
*
token_per_chunk
return
{
"image"
:
mm_max_tokens
}
def
get_num_tiles_per_image
(
self
,
image_height
:
int
,
def
get_num_tiles_per_image
(
self
,
image_height
:
int
,
image_width
:
int
)
->
int
:
image_width
:
int
)
->
int
:
vision_config
=
self
.
get_hf_config
().
vision_config
vision_config
=
self
.
get_hf_config
().
vision_config
...
@@ -141,31 +132,31 @@ class MllamaProcessingInfo(BaseProcessingInfo):
...
@@ -141,31 +132,31 @@ class MllamaProcessingInfo(BaseProcessingInfo):
class
MllamaDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MllamaProcessingInfo
]):
class
MllamaDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MllamaProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
hf_processor
=
self
.
info
.
get_hf_processor
()
image_token
:
str
=
hf_processor
.
image_token
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
MllamaMultiModalProcessor
(
EncDecMultiModalProcessor
[
MllamaProcessingInfo
]
class
MllamaMultiModalProcessor
(
EncDecMultiModalProcessor
[
MllamaProcessingInfo
]
):
):
...
@@ -211,6 +202,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
...
@@ -211,6 +202,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
# }
# }
if
mm_data
:
if
mm_data
:
hf_processor
=
self
.
info
.
get_hf_processor
()
image_token
:
str
=
hf_processor
.
image_token
# Since only the last group of consecutive images
# Since only the last group of consecutive images
# are attended by the decoded tokens, we only need to
# are attended by the decoded tokens, we only need to
# get the number of tokens for those images.
# get the number of tokens for those images.
...
@@ -227,7 +221,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
...
@@ -227,7 +221,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
num_tokens
=
decode_tiles
*
token_per_chunk
num_tokens
=
decode_tiles
*
token_per_chunk
mm_inputs
[
"encoder_prompt_token_ids"
]
=
[
image_token_id
mm_inputs
[
"encoder_prompt_token_ids"
]
=
[
image_token_id
]
*
num_tokens
]
*
num_tokens
mm_inputs
[
"encoder_prompt"
]
=
"<|
image
|>"
*
num_tokens
mm_inputs
[
"encoder_prompt"
]
=
image
_token
*
num_tokens
return
mm_inputs
return
mm_inputs
...
@@ -1188,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1188,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
super
().
__init__
()
super
().
__init__
()
config
:
MllamaConfig
=
vllm_config
.
model_config
.
hf_config
config
:
MllamaConfig
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
vocab_size
=
config
.
text_config
.
vocab_size
self
.
hidden_size
=
config
.
text_config
.
hidden_size
self
.
hidden_size
=
config
.
text_config
.
hidden_size
...
@@ -1306,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1306,6 +1301,31 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_get_and_validate_encoder_lens
(
self
,
encoder_seq_lens
:
List
[
int
],
num_tiles
:
List
[
List
[
int
]],
num_tokens_per_tile
:
int
,
)
->
List
[
int
]:
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
]
# remove 0 encoder len entries for text-only requests for these
# assertions
attn_metadata_lens
=
[
x
for
x
in
encoder_seq_lens
if
x
>
0
]
assert
len
(
actual_encoder_seq_lens
)
==
len
(
attn_metadata_lens
)
for
actual_len
,
last_group_len
in
zip
(
actual_encoder_seq_lens
,
attn_metadata_lens
):
assert
actual_len
>=
last_group_len
return
actual_encoder_seq_lens
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
def
flat_encoder_result
(
self
,
cross_attention_states
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
actual_encoder_seq_lens
:
List
[
int
]):
actual_encoder_seq_lens
:
List
[
int
]):
...
@@ -1325,6 +1345,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1325,6 +1345,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states
=
cross_attention_states_flat
cross_attention_states
=
cross_attention_states_flat
return
cross_attention_states
return
cross_attention_states
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_cross_attention_states
(
def
get_cross_attention_states
(
self
,
self
,
image_inputs
:
MllamaImagePixelInputs
,
image_inputs
:
MllamaImagePixelInputs
,
...
@@ -1430,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1430,20 +1453,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
else
:
else
:
skip_cross_attention
=
False
skip_cross_attention
=
False
# Get the actual number of encoder tokens for each sample.
num_tiles
=
[
t
.
tolist
()
for
t
in
kwargs
.
pop
(
"num_tiles"
)]
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles_tensor
=
kwargs
.
pop
(
"num_tiles"
)
num_tiles
=
[
t
.
tolist
()
for
t
in
num_tiles_tensor
]
num_tokens_per_tile
=
calc_token_per_chunk
(
self
.
image_size
)
num_tokens_per_tile
=
calc_token_per_chunk
(
self
.
image_size
)
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
actual_encoder_seq_lens
=
self
.
_get_and_validate_encoder_lens
(
]
attn_metadata
.
encoder_seq_lens
,
for
actual_len
,
last_group_len
in
zip
(
num_tiles
,
actual_encoder_seq_lens
,
attn_metadata
.
encoder_seq_lens
):
num_tokens_per_tile
,
assert
actual_len
>=
last_group_len
)
cross_attention_states
=
self
.
get_cross_attention_states
(
cross_attention_states
=
self
.
get_cross_attention_states
(
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
image_inputs
,
attn_metadata
,
actual_encoder_seq_lens
)
...
@@ -1521,6 +1538,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -1521,6 +1538,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
updated_params
.
add
(
name
)
updated_params
.
add
(
name
)
return
updated_params
return
updated_params
def
get_mm_mapping
(
self
)
->
MultiModelKeys
:
"""
Get the module prefix in multimodal models
"""
return
MultiModelKeys
.
from_string_field
(
language_model
=
"language_model"
,
connector
=
"multi_modal_projector"
,
tower_model
=
"vision_model"
)
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
def
skip_attention_mask
(
sparse_mask
:
List
[
List
[
int
]])
->
bool
:
for
mask
in
sparse_mask
:
for
mask
in
sparse_mask
:
...
...
vllm/model_executor/models/mllama4.py
View file @
9c4ecf15
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
functools
import
cached_property
from
itertools
import
tee
from
itertools
import
tee
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
...
@@ -24,7 +25,6 @@ import torch
...
@@ -24,7 +25,6 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BatchFeature
,
Llama4Config
,
Llama4VisionConfig
from
transformers
import
BatchFeature
,
Llama4Config
,
Llama4VisionConfig
from
transformers.image_utils
import
SizeDict
from
transformers.image_utils
import
SizeDict
from
transformers.modeling_outputs
import
BaseModelOutput
from
transformers.models.llama4
import
Llama4Processor
from
transformers.models.llama4
import
Llama4Processor
from
transformers.models.llama4.image_processing_llama4_fast
import
(
from
transformers.models.llama4.image_processing_llama4_fast
import
(
find_supported_resolutions
,
get_best_fit
)
find_supported_resolutions
,
get_best_fit
)
...
@@ -33,33 +33,30 @@ from vllm.attention.layer import MultiHeadAttention
...
@@ -33,33 +33,30 @@ from vllm.attention.layer import MultiHeadAttention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader.loader
import
_initialize_model
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModal
FieldConfig
,
MultiModal
Kwargs
,
from
vllm.multimodal.inputs
import
(
MultiModal
DataDict
,
MultiModal
FieldConfig
,
NestedTensors
)
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.llama4
import
Llama4ForCausalLM
maybe_prefix
,
merge_multimodal_embeddings
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.vision
import
scatter_patch_features
,
select_patch_features
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
class
Llama4ImagePatchInputs
(
TypedDict
):
class
Llama4ImagePatchInputs
(
TypedDict
):
...
@@ -76,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
...
@@ -76,11 +73,7 @@ class Llama4ImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
flattened just like `flat_data`.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
"""
aspect_ratios
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
aspect_ratios
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
"""
A list of aspect ratios corresponding to the number of tiles
A list of aspect ratios corresponding to the number of tiles
...
@@ -345,7 +338,7 @@ class Llama4VisionEncoder(nn.Module):
...
@@ -345,7 +338,7 @@ class Llama4VisionEncoder(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
)
->
BaseModelOutput
:
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
inputs_embeds (`torch.FloatTensor` of shape
inputs_embeds (`torch.FloatTensor` of shape
...
@@ -361,7 +354,7 @@ class Llama4VisionEncoder(nn.Module):
...
@@ -361,7 +354,7 @@ class Llama4VisionEncoder(nn.Module):
layer_outputs
=
encoder_layer
(
hidden_states
)
layer_outputs
=
encoder_layer
(
hidden_states
)
hidden_states
=
layer_outputs
[
0
]
hidden_states
=
layer_outputs
[
0
]
return
BaseModelOutput
(
last_hidden_state
=
hidden_states
,
)
return
hidden_states
class
Llama4UnfoldConvolution
(
nn
.
Module
):
class
Llama4UnfoldConvolution
(
nn
.
Module
):
...
@@ -433,7 +426,7 @@ class Llama4VisionModel(nn.Module):
...
@@ -433,7 +426,7 @@ class Llama4VisionModel(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
images_flattened
:
torch
.
Tensor
,
images_flattened
:
torch
.
Tensor
,
)
->
BaseModelOutput
:
)
->
torch
.
Tensor
:
# Patch embedding
# Patch embedding
hidden_state
=
self
.
patch_embedding
(
images_flattened
)
hidden_state
=
self
.
patch_embedding
(
images_flattened
)
num_tiles
,
num_patches
,
hidden_dim
=
hidden_state
.
shape
num_tiles
,
num_patches
,
hidden_dim
=
hidden_state
.
shape
...
@@ -458,8 +451,7 @@ class Llama4VisionModel(nn.Module):
...
@@ -458,8 +451,7 @@ class Llama4VisionModel(nn.Module):
hidden_state
=
hidden_state
.
view
(
num_tiles
,
-
1
,
hidden_dim
)
hidden_state
=
hidden_state
.
view
(
num_tiles
,
-
1
,
hidden_dim
)
# Apply encoder
# Apply encoder
output
=
self
.
model
(
hidden_state
)
hidden_state
=
self
.
model
(
hidden_state
)
hidden_state
=
output
.
last_hidden_state
hidden_state
=
self
.
layernorm_post
(
hidden_state
)
hidden_state
=
self
.
layernorm_post
(
hidden_state
)
# Remove CLS token output
# Remove CLS token output
...
@@ -468,10 +460,7 @@ class Llama4VisionModel(nn.Module):
...
@@ -468,10 +460,7 @@ class Llama4VisionModel(nn.Module):
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
# now, we use Llama4VisionPixelShuffle + mlp to project embeddings
hidden_state
=
self
.
vision_adapter
(
hidden_state
)
hidden_state
=
self
.
vision_adapter
(
hidden_state
)
return
BaseModelOutput
(
return
hidden_state
last_hidden_state
=
hidden_state
,
attentions
=
None
,
)
class
Mllama4ProcessingInfo
(
BaseProcessingInfo
):
class
Mllama4ProcessingInfo
(
BaseProcessingInfo
):
...
@@ -488,7 +477,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
...
@@ -488,7 +477,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
**
kwargs
)
**
kwargs
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
10
}
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
return
{
"image"
:
None
}
@
staticmethod
@
staticmethod
def
get_patch_per_chunk
(
vision_config
:
Llama4VisionConfig
)
->
int
:
def
get_patch_per_chunk
(
vision_config
:
Llama4VisionConfig
)
->
int
:
...
@@ -507,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
...
@@ -507,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor
=
self
.
get_hf_processor
().
image_processor
image_processor
=
self
.
get_hf_processor
().
image_processor
return
image_processor
.
max_patches
return
image_processor
.
max_patches
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
vision_config
=
self
.
get_hf_config
().
vision_config
# image_start + local tiles * (patches + 1 x separator) +
# 1 global tile * (image x 1 + patches) + image_end
token_per_chunk
=
self
.
get_patch_per_chunk
(
vision_config
)
+
1
mm_max_tokens
=
(
self
.
get_max_num_tiles
()
+
1
)
*
token_per_chunk
+
2
return
{
"image"
:
mm_max_tokens
}
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
vision_config
=
self
.
get_hf_config
().
vision_config
vision_config
=
self
.
get_hf_config
().
vision_config
...
@@ -581,33 +561,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
...
@@ -581,33 +561,9 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
for
(
r_h
,
r_w
)
in
aspect_ratios
for
(
r_h
,
r_w
)
in
aspect_ratios
]
]
# embed_is_patch should have one feature per image-related token:
# <|image_start|>, <|tile_*_separator|>, <|image|>, <|image_end|>
# -> False
# <|patch|> -> True
# embed_is_patch has no entries corresponding to non-image-related
# tokens.
patch_id
=
tokenizer
.
get_vocab
()[
processor
.
img_patch_token
]
num_patches_per_chunk
=
self
.
info
.
get_patch_per_chunk
(
vision_config
)
expanded_image_tokens_list
=
[
processor
.
_prompt_split_image
(
aspect_ratio
,
num_patches_per_chunk
)
for
aspect_ratio
in
aspect_ratios
]
expanded_image_token_ids
=
[
tokenizer
.
encode
(
image_tokens
,
add_special_tokens
=
False
)
for
image_tokens
in
expanded_image_tokens_list
]
embed_is_patch
=
[
torch
.
tensor
(
tokens
)
==
patch_id
for
tokens
in
expanded_image_token_ids
]
processed_outputs
[
"aspect_ratios"
]
=
aspect_ratios
processed_outputs
[
"aspect_ratios"
]
=
aspect_ratios
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
processed_outputs
[
"patches_per_image"
]
=
torch
.
tensor
(
patches_per_image
)
patches_per_image
)
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
...
@@ -622,7 +578,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
...
@@ -622,7 +578,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
"image"
,
patches_per_image
),
"image"
,
patches_per_image
),
patches_per_image
=
MultiModalFieldConfig
.
batched
(
"image"
),
patches_per_image
=
MultiModalFieldConfig
.
batched
(
"image"
),
aspect_ratios
=
MultiModalFieldConfig
.
batched
(
"image"
),
aspect_ratios
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -642,12 +597,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
...
@@ -642,12 +597,17 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
num_patches_per_chunk
=
self
.
info
.
get_patch_per_chunk
(
vision_config
)
num_patches_per_chunk
=
self
.
info
.
get_patch_per_chunk
(
vision_config
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
image_token
=
hf_processor
.
image_token
img_patch_token
=
hf_processor
.
img_patch_token
def
get_replacement
(
item_idx
:
int
):
def
get_replacement
(
item_idx
:
int
):
aspect_ratio
=
out_mm_kwargs
[
"aspect_ratios"
][
item_idx
]
aspect_ratio
=
out_mm_kwargs
[
"aspect_ratios"
][
item_idx
]
return
hf_processor
.
_prompt_split_image
(
repl
=
hf_processor
.
_prompt_split_image
(
aspect_ratio
=
aspect_ratio
,
aspect_ratio
=
aspect_ratio
,
num_patches_per_chunk
=
num_patches_per_chunk
)
num_patches_per_chunk
=
num_patches_per_chunk
,
)
return
PromptUpdateDetails
.
select_text
(
repl
,
img_patch_token
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -660,36 +620,39 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
...
@@ -660,36 +620,39 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
class
Mllama4DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Mllama4ProcessingInfo
]):
class
Mllama4DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Mllama4ProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
processor
=
self
.
info
.
get_hf_processor
()
image_token
=
processor
.
fake_image_token
return
image_token
*
num_images
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
(
target_width
,
(
target_width
,
target_height
)
=
self
.
info
.
get_image_size_with_most_features
()
target_height
)
=
self
.
info
.
get_image_size_with_most_features
()
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
image_token
=
self
.
info
.
get_hf_processor
().
fake_image_token
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
Mllama4MultiModalProcessor
,
Mllama4MultiModalProcessor
,
info
=
Mllama4ProcessingInfo
,
info
=
Mllama4ProcessingInfo
,
dummy_inputs
=
Mllama4DummyInputsBuilder
,
dummy_inputs
=
Mllama4DummyInputsBuilder
,
)
)
class
Llama4ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
Llama4ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
}
}
...
@@ -710,13 +673,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -710,13 +673,22 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
config
,
self
.
config
,
None
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
self
.
language_model
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
hf_config
=
config
.
text_config
,
architectures
=
[
"Llama4ForCausalLM"
],
prefix
=
maybe_prefix
(
prefix
,
"language_model"
))
self
.
tokenizer
=
cached_tokenizer_from_config
(
vllm_config
.
model_config
)
self
.
language_model
=
_initialize_model
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
model_class
=
Llama4ForCausalLM
,
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
@
cached_property
def
sampler
(
self
):
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
get_sampler
()
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
Llama4ImagePatchInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Llama4ImagePatchInputs
]:
...
@@ -730,11 +702,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -730,11 +702,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
flat_pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
flat_pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
patches_per_image
=
flatten_bn
(
kwargs
.
pop
(
"patches_per_image"
))
patches_per_image
=
flatten_bn
(
kwargs
.
pop
(
"patches_per_image"
))
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
aspect_ratios
=
kwargs
.
pop
(
"aspect_ratios"
,
None
)
aspect_ratios
=
kwargs
.
pop
(
"aspect_ratios"
,
None
)
if
not
isinstance
(
aspect_ratios
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
aspect_ratios
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of aspect_ratios. "
raise
ValueError
(
"Incorrect type of aspect_ratios. "
...
@@ -744,7 +711,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -744,7 +711,6 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
type
=
"pixel_values"
,
type
=
"pixel_values"
,
flat_data
=
flat_pixel_values
,
flat_data
=
flat_pixel_values
,
patches_per_image
=
patches_per_image
,
patches_per_image
=
patches_per_image
,
embed_is_patch
=
embed_is_patch
,
aspect_ratios
=
aspect_ratios
,
aspect_ratios
=
aspect_ratios
,
)
)
...
@@ -752,8 +718,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -752,8 +718,18 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
image_input
:
Llama4ImagePatchInputs
)
->
MultiModalEmbeddings
:
self
,
image_input
:
Llama4ImagePatchInputs
)
->
MultiModalEmbeddings
:
flat_data
=
image_input
[
"flat_data"
]
flat_data
=
image_input
[
"flat_data"
]
patches_per_image
=
image_input
[
"patches_per_image"
].
tolist
()
patches_per_image
=
image_input
[
"patches_per_image"
].
tolist
()
vision_embeddings_flat
=
self
.
vision_model
(
flat_data
).
last_hidden_state
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
vision_embeddings_flat
=
self
.
vision_model
(
flat_data
)
vision_embeddings_flat
=
self
.
multi_modal_projector
(
vision_embeddings_flat
)
return
[
img
.
flatten
(
0
,
1
)
for
img
in
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
MultiModalEmbeddings
]:
**
kwargs
)
->
Optional
[
MultiModalEmbeddings
]:
...
@@ -761,20 +737,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -761,20 +737,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
# num_images x [num_chunks, num_patches, hidden_dim]
return
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
# num_images x [num_chunks x num_patches, hidden_dim]
image_features_flat
=
[
img
.
flatten
(
0
,
1
)
for
img
in
image_features
]
# num_images x [1, input_len] -> num_images x [input_len]
embed_is_patch_flat
=
[
is_patch
.
flatten
(
0
,
1
)
for
is_patch
in
image_input
[
"embed_is_patch"
]
]
return
scatter_patch_features
(
image_features_flat
,
embed_is_patch_flat
,
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -784,11 +747,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -784,11 +747,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
multimodal_embeddings
=
torch
.
cat
(
multimodal_embeddings
)
mm_embeddings
=
self
.
multi_modal_projector
(
multimodal_embeddings
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
select_patch_features
(
mm_embeddings
),
input_ids
,
self
.
config
.
image_token_index
)
inputs_embeds
,
multimodal_embeddings
,
self
.
config
.
image_token_index
,
)
return
inputs_embeds
return
inputs_embeds
...
@@ -800,9 +764,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -800,9 +764,12 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# NOTE: In v1, inputs_embeds is always generated at model runner, this
if
intermediate_tensors
is
not
None
:
# condition is for v0 compatibility.
inputs_embeds
=
None
if
"pixel_values"
in
kwargs
:
# NOTE: In v1, inputs_embeds is always generated at model runner,
# this condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
vision_embeddings
)
vision_embeddings
)
...
@@ -857,9 +824,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -857,9 +824,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal):
# language_model is an Llama4ForCausalLM instance. We load it's
# language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine.
# using llama4's load_weights routine.
language_model_prefix
=
"language_model.model."
language_model_weights
,
other_weights
=
self
.
separate_weights
(
language_model_weights
,
other_weights
=
self
.
separate_weights
(
weights
,
prefix
=
language_model
_prefix
)
weights
,
prefix
=
"
language_model
.model."
)
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
loaded_language_model_params
=
loader
.
load_weights
(
loaded_language_model_params
=
loader
.
load_weights
(
language_model_weights
)
language_model_weights
)
...
...
vllm/model_executor/models/molmo.py
View file @
9c4ecf15
...
@@ -41,13 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -41,13 +41,15 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIndexTargets
,
BaseProcessingInfo
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
PromptInsertion
,
PromptUpdate
,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -56,7 +58,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -56,7 +58,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
...
@@ -84,14 +85,6 @@ class MolmoImageInputs(TypedDict):
...
@@ -84,14 +85,6 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)`
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops
:
torch
.
Tensor
num_crops
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
...
@@ -1146,30 +1139,6 @@ class MolmoProcessorWrapper:
...
@@ -1146,30 +1139,6 @@ class MolmoProcessorWrapper:
if
image_input_idx
is
not
None
:
if
image_input_idx
is
not
None
:
feat_is_patch
=
image_input_idx
>=
0
feat_is_patch
=
image_input_idx
>=
0
input_is_embed
=
torch
.
isin
(
input_ids
,
torch
.
tensor
([
self
.
image_patch_id
,
self
.
im_col_id
,
self
.
im_start_id
,
self
.
im_end_id
,
]),
)
embed_ids
=
input_ids
[
input_is_embed
]
embed_is_patch
=
embed_ids
==
self
.
image_patch_id
assert
embed_is_patch
.
sum
()
==
feat_is_patch
.
sum
()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start
=
torch
.
nonzero
(
embed_ids
==
self
.
im_start_id
)[::
2
,
0
]
embed_end
=
torch
.
nonzero
(
embed_ids
==
self
.
im_end_id
)[
1
::
2
,
0
]
assert
len
(
embed_start
)
==
len
(
embed_end
)
==
len
(
images
)
embed_is_patch
=
[
embed_is_patch
[
start
:
end
+
1
]
for
start
,
end
in
zip
(
embed_start
,
embed_end
)
]
tilings
=
[
tilings
=
[
self
.
select_tiling
(
self
.
select_tiling
(
image_width
=
image
.
size
[
0
],
image_width
=
image
.
size
[
0
],
...
@@ -1181,7 +1150,6 @@ class MolmoProcessorWrapper:
...
@@ -1181,7 +1150,6 @@ class MolmoProcessorWrapper:
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
outputs
[
"feat_is_patch"
]
=
feat_is_patch
outputs
[
"feat_is_patch"
]
=
feat_is_patch
outputs
[
"embed_is_patch"
]
=
embed_is_patch
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
...
@@ -1197,13 +1165,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1197,13 +1165,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
*
,
*
,
...
@@ -1220,26 +1181,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1220,26 +1181,13 @@ class MolmoProcessingInfo(BaseProcessingInfo):
)
)
pooling_size
=
processor
.
pooling_size
pooling_size
=
processor
.
pooling_size
base_image_input_size
=
processor
.
base_image_input_size
image_token_length_w
=
processor
.
image_token_length_w
base_image_input_d
=
processor
.
image_patch_size
image_token_length_h
=
processor
.
image_token_length_h
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
per_row
=
ncols
//
pooling_size
+
1
joint
=
per_row
*
(
nrows
//
pooling_size
)
+
2
image_token_length
=
(
crop_patches
+
pooling_size
-
1
)
//
pooling_size
resize
=
(
image_token_length
+
1
)
*
image_token_length
+
2
return
resize
+
joint
extra
=
image_token_length_w
*
image_token_length_h
joint
=
((
ncols
+
1
)
//
pooling_size
)
*
((
nrows
+
1
)
//
pooling_size
)
def
get_max_image_tokens
(
self
)
->
int
:
return
extra
+
joint
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
...
@@ -1269,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1269,27 +1217,25 @@ class MolmoProcessingInfo(BaseProcessingInfo):
class
MolmoDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MolmoProcessingInfo
]):
class
MolmoDummyInputsBuilder
(
BaseDummyInputsBuilder
[
MolmoProcessingInfo
]):
def
get_dummy_processor_inputs
(
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
return
""
def
get_dummy_mm_data
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
MolmoMultiModalProcessor
(
BaseMultiModalProcessor
[
MolmoProcessingInfo
]):
class
MolmoMultiModalProcessor
(
BaseMultiModalProcessor
[
MolmoProcessingInfo
]):
...
@@ -1328,7 +1274,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1328,7 +1274,6 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image"
,
num_crops
),
"image"
,
num_crops
),
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
"image"
,
num_crops
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
...
@@ -1368,8 +1313,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1368,8 +1313,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint
=
([
img_start_id
]
+
joint_row
*
joint
=
([
img_start_id
]
+
joint_row
*
((
nrows
+
1
)
//
pooling_size
)
+
[
img_end_id
])
((
nrows
+
1
)
//
pooling_size
)
+
[
img_end_id
])
image_tokens
=
extra_joint
+
joint
return
PromptUpdateDetails
.
select_token_id
(
return
image_tokens
extra_joint
+
joint
,
embed_token_id
=
img_patch_id
,
)
return
[
return
[
PromptInsertion
(
PromptInsertion
(
...
@@ -1475,11 +1422,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1475,11 +1422,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise
ValueError
(
"Incorrect type of feat_is_patch. "
raise
ValueError
(
"Incorrect type of feat_is_patch. "
f
"Got type:
{
type
(
feat_is_patch
)
}
"
)
f
"Got type:
{
type
(
feat_is_patch
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_crops. "
raise
ValueError
(
"Incorrect type of num_crops. "
...
@@ -1491,14 +1433,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1491,14 +1433,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
return
MolmoImageInputs
(
return
MolmoImageInputs
(
images
=
images
,
images
=
images
,
image_masks
=
image_masks
,
image_masks
=
image_masks
,
feat_is_patch
=
feat_is_patch
,
feat_is_patch
=
feat_is_patch
,
embed_is_patch
=
embed_is_patch
,
num_crops
=
num_crops
,
num_crops
=
num_crops
,
)
)
...
@@ -1531,18 +1471,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1531,18 +1471,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
)
)
]
]
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -1556,7 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1556,7 +1494,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
)
,
multimodal_embeddings
,
self
.
img_patch_id
,
self
.
img_patch_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/nvlm_d.py
View file @
9c4ecf15
...
@@ -15,12 +15,11 @@ from transformers import PretrainedConfig
...
@@ -15,12 +15,11 @@ from transformers import PretrainedConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
ProcessorInputs
from
.intern_vit
import
InternVisionModel
from
.intern_vit
import
InternVisionModel
from
.internvl
import
(
BaseInternVLProcessingInfo
,
BaseInternVLProcessor
,
from
.internvl
import
(
BaseInternVLProcessingInfo
,
BaseInternVLProcessor
,
...
@@ -57,7 +56,7 @@ class NVLMProcessor(BaseInternVLProcessor):
...
@@ -57,7 +56,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find "<tile" as a subsequence of "<Image><tile"
# when trying to find "<tile" as a subsequence of "<Image><tile"
repl
=
"<Image>"
+
features
+
"</Image>"
repl
=
"<Image>"
+
features
+
"</Image>"
return
PromptUpdateDetails
(
full
=
repl
,
features
=
repl
)
return
PromptUpdateDetails
.
select_text
(
repl
,
IMG_PAD
)
class
NVLMProcessingInfo
(
BaseInternVLProcessingInfo
):
class
NVLMProcessingInfo
(
BaseInternVLProcessingInfo
):
...
@@ -84,57 +83,32 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
...
@@ -84,57 +83,32 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**
kwargs
,
**
kwargs
,
)
)
def
get_max_image_tokens
(
self
)
->
int
:
hf_processor
=
self
.
get_hf_processor
()
tokenizer
=
hf_processor
.
tokenizer
max_num_patches
=
hf_processor
.
max_dynamic_patch
class
NVLMDummyInputsBuilder
(
InternVLDummyInputsBuilder
[
NVLMProcessingInfo
]):
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers
=
[
f
"<tile_
{
i
+
1
}
>"
for
i
in
range
(
max_num_patches
)
]
if
hf_processor
.
use_thumbnail
and
max_num_patches
!=
1
:
tile_pos_identifiers
+=
[
"<tile_global_thumbnail>"
]
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
def
get_dummy_text
(
self
,
mm_counts
:
Mapping
[
str
,
int
])
->
str
:
# so we include <tile_1> in the start_str
num_images
=
mm_counts
.
get
(
"image"
,
0
)
start_str
=
"<Image>"
+
tile_pos_identifiers
.
pop
(
0
)
end_str
=
"</Image>"
start_token_len
=
len
(
tokenizer
.
encode
(
start_str
))
end_token_len
=
len
(
tokenizer
.
encode
(
end_str
))
tile_token_len
=
sum
(
len
(
tokenizer
.
encode
(
identifier
))
for
identifier
in
tile_pos_identifiers
)
non_image_tokens_num
=
start_token_len
+
end_token_len
+
tile_token_len
return
super
().
get_max_image_tokens
()
+
non_image_tokens_num
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
return
"<image>
\n
"
*
num_images
class
NVLMDummyInputsBuilder
(
InternVLDummyInputsBuilder
[
NVLMProcessingInfo
]):
def
get_dummy_mm_data
(
def
get_dummy_processor_inputs
(
self
,
self
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
MultiModalDataDict
:
target_width
,
target_height
=
\
target_width
,
target_height
=
\
self
.
info
.
get_image_size_with_most_features
()
self
.
info
.
get_image_size_with_most_features
()
num_images
=
mm_counts
.
get
(
"image"
,
0
)
num_images
=
mm_counts
.
get
(
"image"
,
0
)
mm_data
=
{
return
{
"image"
:
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
height
=
target_height
,
num_images
=
num_images
)
num_images
=
num_images
)
}
}
return
ProcessorInputs
(
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
prompt_text
=
"<image>
\n
"
*
num_images
,
mm_data
=
mm_data
,
)
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
class
NVLMMultiModalProcessor
(
InternVLMultiModalProcessor
[
NVLMProcessingInfo
]):
...
@@ -177,10 +151,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
...
@@ -177,10 +151,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl
=
hf_processor
.
get_image_repl
(
feature_size
,
num_patches
)
repl
=
hf_processor
.
get_image_repl
(
feature_size
,
num_patches
)
return
PromptUpdateDetails
(
return
PromptUpdateDetails
.
select_text
(
repl
.
full
+
"
\n
"
,
IMG_PAD
)
full
=
repl
.
full
+
"
\n
"
,
features
=
repl
.
features
+
"
\n
"
,
)
# See note in dummy data regarding why we have the extra newline
# See note in dummy data regarding why we have the extra newline
return
[
return
[
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment