Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
541d1df4
Unverified
Commit
541d1df4
authored
Mar 28, 2025
by
Cyrus Leung
Committed by
GitHub
Mar 28, 2025
Browse files
[Bugfix] `embed_is_patch` for Idefics3 (#15696)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3b00ff91
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
320 additions
and
188 deletions
+320
-188
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+0
-1
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+318
-183
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+0
-1
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+1
-1
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+1
-2
No files found.
vllm/model_executor/models/commandr.py
View file @
541d1df4
...
@@ -24,7 +24,6 @@
...
@@ -24,7 +24,6 @@
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
,
Union
import
torch
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
transformers
import
CohereConfig
from
transformers
import
CohereConfig
...
...
vllm/model_executor/models/idefics3.py
View file @
541d1df4
...
@@ -17,16 +17,14 @@
...
@@ -17,16 +17,14 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
Dict
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
transformers
import
(
BatchFeature
,
Idefics3Config
,
Idefics3ImageProcessor
,
from
transformers
import
(
BatchFeature
,
Idefics3Config
,
Idefics3ImageProcessor
,
Idefics3Processor
)
Idefics3Processor
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
...
@@ -35,13 +33,16 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.parse
import
ImageProcessorItems
,
ImageSize
from
vllm.multimodal.parse
import
ImageProcessorItems
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
MultiModalDataItems
,
MultiModalDataItems
,
MultiModalFieldConfig
,
MultiModalFieldConfig
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
,
encode_tokens
)
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
...
@@ -53,18 +54,28 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from
.llama
import
LlamaModel
from
.llama
import
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
logger
=
init_logger
(
__name__
)
class
Idefics3ImagePixelInputs
(
TypedDict
):
class
Idefics3ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
pixel_values
:
torch
.
Tensor
"""
"""
Shape: `(batch_size * num_images * num_patches,
Shape: `(batch_size * num_images * num_patches,
num_channels, height, width)`
num_channels, height, width)`
"""
"""
pixel_attention_mask
:
Optional
[
torch
.
BoolTensor
]
pixel_attention_mask
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
Idefics3ImageEmbeddingInputs
(
TypedDict
):
class
Idefics3ImageEmbeddingInputs
(
TypedDict
):
...
@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
...
@@ -75,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
ImageInputs
=
Union
[
Idefics3ImagePixelInputs
,
Idefics3ImageEmbeddingInputs
]
...
@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -100,32 +119,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
)
->
Mapping
[
str
,
int
]:
hf_processor
=
self
.
get_hf_processor
()
return
{
"image"
:
self
.
get_max_image_tokens
()}
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
grid_w
,
grid_h
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_processor
.
size
[
'longest_edge'
],
image_height
=
image_processor
.
size
[
'longest_edge'
],
)
num_image_token
=
(
grid_w
*
grid_h
+
1
)
*
hf_processor
.
image_seq_len
# Calculate Non-image-token length
# NOTE: <row_1_col_1> and <global-img> are special token for SmolVLM
# but not for Idefic3, so we need to tokenize them to get actual length.
tokenizer
=
self
.
get_tokenizer
()
tile_token_len
=
len
(
tokenizer
.
tokenize
(
"<row_1_col_1>"
))
glob_token_len
=
len
(
tokenizer
.
tokenize
(
hf_processor
.
global_image_tag
))
# linebreak and <fake_token_around_image> always cost 1 token
fake_token_len
=
lb_len
=
1
non_image_token
=
(
grid_w
*
grid_h
)
*
(
tile_token_len
+
fake_token_len
)
+
glob_token_len
+
(
grid_h
+
1
)
*
lb_len
+
fake_token_len
return
{
"image"
:
num_image_token
+
non_image_token
}
def
_resize_output_size
(
self
,
def
_resize_output_size
(
self
,
*
,
*
,
height
:
int
,
height
:
int
,
width
:
int
,
width
:
int
,
max_len
:
Optional
[
int
]
=
None
,
max_len
:
Optional
[
int
]
=
None
,
min_len
:
Optional
[
int
]
=
1
,
min_len
:
int
=
1
,
max_size
:
Optional
[
int
]
=
None
)
->
tuple
[
int
,
int
]:
max_size
:
Optional
[
int
]
=
None
)
->
tuple
[
int
,
int
]:
# Set default value for max_len if not provided
# Set default value for max_len if not provided
max_len
=
max
(
height
,
width
)
if
max_len
is
None
else
max_len
max_len
=
max
(
height
,
width
)
if
max_len
is
None
else
max_len
...
@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -181,10 +182,13 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*
,
*
,
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
size
:
Optional
[
dict
[
str
,
object
]]
=
None
,
processor
:
Optional
[
Idefics3Processor
]
,
)
->
tuple
[
int
,
int
]:
)
->
tuple
[
int
,
int
]:
hf_processor
=
self
.
get_hf_processor
(
size
=
size
)
if
processor
is
None
:
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
processor
=
self
.
get_hf_processor
()
image_processor
:
Idefics3ImageProcessor
=
processor
.
image_processor
max_image_size
=
image_processor
.
max_image_size
[
'longest_edge'
]
max_image_size
=
image_processor
.
max_image_size
[
'longest_edge'
]
size
=
image_processor
.
size
[
'longest_edge'
]
size
=
image_processor
.
size
[
'longest_edge'
]
assert
size
%
max_image_size
==
0
,
(
assert
size
%
max_image_size
==
0
,
(
...
@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
...
@@ -204,6 +208,105 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
grid_h
=
grid_w
=
0
grid_h
=
grid_w
=
0
return
grid_w
,
grid_h
return
grid_w
,
grid_h
def
get_num_patches
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Optional
[
Idefics3Processor
],
)
->
int
:
grid_w
,
grid_h
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
)
return
grid_w
*
grid_h
+
1
def
get_image_repl
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Optional
[
Idefics3Processor
],
)
->
str
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
image_token
=
processor
.
image_token
.
content
fake_image_token
=
processor
.
fake_image_token
.
content
global_img_token
=
processor
.
global_image_tag
image_seq_len
=
processor
.
image_seq_len
grid_placeholder
=
"<row_{n_h}_col_{n_w}>"
p_img
=
image_token
*
image_seq_len
global_img_placeholder
=
fake_image_token
+
global_img_token
+
p_img
tile_img_placeholder
=
fake_image_token
+
grid_placeholder
+
p_img
grid_w
,
grid_h
=
self
.
_get_image_feature_grid_size
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
)
if
grid_w
==
0
and
grid_h
==
0
:
return
global_img_placeholder
+
fake_image_token
tiles_placeholder
=
list
[
str
]()
for
i
in
range
(
grid_h
):
for
j
in
range
(
grid_w
):
placeholder_per_tile
=
tile_img_placeholder
.
format
(
n_h
=
i
+
1
,
n_w
=
j
+
1
)
tiles_placeholder
.
append
(
placeholder_per_tile
)
# Add line break if it is the last tile in the row
if
j
==
grid_w
-
1
:
tiles_placeholder
.
append
(
"
\n
"
)
return
""
.
join
([
*
tiles_placeholder
,
"
\n
"
,
global_img_placeholder
,
fake_image_token
,
])
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
processor
:
Optional
[
Idefics3Processor
],
)
->
int
:
tokenizer
=
self
.
get_tokenizer
()
image_repl
=
self
.
get_image_repl
(
image_width
=
image_width
,
image_height
=
image_height
,
processor
=
processor
,
)
image_repl_tokens
=
encode_tokens
(
tokenizer
,
image_repl
,
add_special_tokens
=
False
,
)
return
len
(
image_repl_tokens
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
processor
=
self
.
get_hf_processor
()
image_processor
:
Idefics3ImageProcessor
=
processor
.
image_processor
return
ImageSize
(
width
=
image_processor
.
size
[
"longest_edge"
],
height
=
image_processor
.
size
[
"longest_edge"
],
)
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
return
self
.
get_num_image_tokens
(
image_width
=
target_width
,
image_height
=
target_height
,
processor
=
None
,
)
class
Idefics3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Idefics3ProcessingInfo
]
class
Idefics3DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Idefics3ProcessingInfo
]
):
):
...
@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
...
@@ -217,7 +320,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
hf_processor
=
self
.
info
.
get_hf_processor
()
hf_processor
=
self
.
info
.
get_hf_processor
()
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
image_processor
:
Idefics3ImageProcessor
=
hf_processor
.
image_processor
longest_edge
=
image_processor
.
max_image_size
[
'longest_edge'
]
longest_edge
=
image_processor
.
max_image_size
[
'longest_edge'
]
image_token
:
str
=
hf_processor
.
image_token
.
content
image_token
=
hf_processor
.
image_token
.
content
mm_data
=
{
mm_data
=
{
"image"
:
"image"
:
...
@@ -241,26 +344,61 @@ class Idefics3MultiModalProcessor(
...
@@ -241,26 +344,61 @@ class Idefics3MultiModalProcessor(
mm_data
:
Mapping
[
str
,
object
],
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
if
mm_data
:
# Text-only input not supported in composite processor
processed_outputs
=
super
().
_call_hf_processor
(
if
not
(
images
:
=
mm_data
.
get
(
"images"
,
[])):
prompt
,
mm_data
,
mm_kwargs
)
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
image_grids
=
[
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
self
.
info
.
_get_image_feature_grid_size
(
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
image_width
=
img
.
width
,
image_height
=
img
.
height
,
processed_outputs
=
super
().
_call_hf_processor
(
**
mm_kwargs
,
prompt
,
)
for
img
in
mm_data
[
"images"
]
mm_data
,
]
mm_kwargs
,
image_patches
=
list
(
map
(
lambda
x
:
math
.
prod
(
x
)
+
1
,
image_grids
))
)
for
key
in
(
"pixel_values"
,
"pixel_attention_mask"
):
data
=
processed_outputs
.
pop
(
key
)
parsed_images
=
(
self
.
_get_data_parser
().
parse_mm_data
({
data
=
data
.
flatten
(
0
,
1
).
split
(
image_patches
)
"image"
:
images
processed_outputs
[
key
]
=
data
}).
get_items
(
"image"
,
ImageProcessorItems
))
else
:
image_sizes
=
[
tokenizer
=
self
.
info
.
get_tokenizer
()
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
processed_outputs
=
tokenizer
(
prompt
,
]
add_special_tokens
=
True
,
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
return_tensors
=
"pt"
)
image_repl_features
=
[
self
.
info
.
get_image_repl
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
)
for
size
in
image_sizes
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
hf_processor
.
image_token
.
content
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
image_token_id
for
image_repl_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
num_patches
=
[
self
.
info
.
get_num_patches
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
,
)
for
size
in
image_sizes
]
processed_outputs
[
"num_patches"
]
=
torch
.
tensor
(
num_patches
)
# Remove the extra batch dimension
processed_outputs
[
"pixel_values"
].
squeeze_
(
0
)
processed_outputs
[
"pixel_attention_mask"
].
squeeze_
(
0
)
return
processed_outputs
return
processed_outputs
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
...
@@ -268,10 +406,16 @@ class Idefics3MultiModalProcessor(
...
@@ -268,10 +406,16 @@ class Idefics3MultiModalProcessor(
hf_inputs
:
BatchFeature
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
num_patches
=
hf_inputs
.
get
(
"num_patches"
,
torch
.
empty
(
0
))
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_attention_mask
=
MultiModalFieldConfig
.
batched
(
"image"
),
"image"
,
num_patches
),
pixel_attention_mask
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_patches
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -281,42 +425,18 @@ class Idefics3MultiModalProcessor(
...
@@ -281,42 +425,18 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
Sequence
[
PromptUpdate
]:
)
->
Sequence
[
PromptUpdate
]:
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
image_token
=
hf_processor
.
image_token
.
content
image_token
=
hf_processor
.
image_token
.
content
fake_image_token
=
hf_processor
.
fake_image_token
.
content
global_img_token
=
hf_processor
.
global_image_tag
image_seq_len
=
hf_processor
.
image_seq_len
grid_placeholder
=
"<row_{n_h}_col_{n_w}>"
p_img
=
image_token
*
image_seq_len
global_img_placeholder
=
fake_image_token
+
global_img_token
+
p_img
tile_img_placeholder
=
fake_image_token
+
grid_placeholder
+
p_img
def
get_replacement_idefics3
(
item_idx
:
int
)
->
str
:
def
get_replacement_idefics3
(
item_idx
:
int
)
->
str
:
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
grid_w
,
grid_h
=
self
.
info
.
_get_image_feature_grid_size
(
return
self
.
info
.
get_image_repl
(
image_width
=
image_size
.
width
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_height
=
image_size
.
height
,
**
hf_processor
_mm_kwargs
,
processor
=
hf_processor
,
)
)
if
grid_w
==
0
and
grid_h
==
0
:
image_placeholder
=
global_img_placeholder
else
:
tiles_placeholder
=
list
[
str
]()
for
i
in
range
(
grid_h
):
for
j
in
range
(
grid_w
):
placeholder_per_tile
=
tile_img_placeholder
.
format
(
n_h
=
i
+
1
,
n_w
=
j
+
1
)
tiles_placeholder
.
append
(
placeholder_per_tile
)
# Add line break if it is the last tile in the row
if
j
==
grid_w
-
1
:
tiles_placeholder
.
append
(
"
\n
"
)
image_placeholder
=
""
.
join
(
[
*
tiles_placeholder
,
"
\n
"
,
global_img_placeholder
])
return
image_placeholder
+
fake_image_token
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
...
@@ -424,73 +544,13 @@ class Idefics3Model(nn.Module):
config
.
vision_config
.
patch_size
)
**
2
)
/
(
config
.
scale_factor
**
2
))
config
.
vision_config
.
patch_size
)
**
2
)
/
(
config
.
scale_factor
**
2
))
self
.
image_token_id
=
self
.
config
.
image_token_id
self
.
image_token_id
=
self
.
config
.
image_token_id
def
_validate_pixel_values
(
def
image_pixels_to_features
(
self
,
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
[
1
:])
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
pixel_attention_mask
=
kwargs
.
pop
(
"pixel_attention_mask"
,
None
)
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Idefics3ImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
isinstance
(
pixel_values
,
list
):
pixel_values
=
torch
.
cat
(
pixel_values
,
dim
=
1
)
pixel_attention_mask
=
torch
.
cat
(
pixel_attention_mask
,
dim
=
1
)
else
:
pixel_values
=
flatten_bn
(
pixel_values
)
pixel_attention_mask
=
flatten_bn
(
pixel_attention_mask
)
return
Idefics3ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_attention_mask
=
pixel_attention_mask
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_image_pixels_to_features
(
self
,
self
,
pixel_values
:
torch
.
Tensor
,
pixel_values
:
torch
.
Tensor
,
pixel_attention_mask
:
Optional
[
torch
.
Bool
Tensor
]
=
None
,
pixel_attention_mask
:
torch
.
Tensor
,
)
->
Nested
Tensor
s
:
)
->
torch
.
Tensor
:
# NOTE: we skip the step to select the vision feature layer since
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
# this is already done inside the vision tower
num_patches
=
[
x
.
size
(
0
)
for
x
in
pixel_values
]
pixel_values
=
pixel_values
.
to
(
pixel_values
=
pixel_values
.
to
(
dtype
=
self
.
vision_model
.
embeddings
.
patch_embedding
.
weight
.
dtype
dtype
=
self
.
vision_model
.
embeddings
.
patch_embedding
.
weight
.
dtype
)
# fp16 compatibility
)
# fp16 compatibility
...
@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module):
...
@@ -502,17 +562,9 @@ class Idefics3Model(nn.Module):
pixel_values
=
pixel_values
[
real_images_inds
].
contiguous
()
pixel_values
=
pixel_values
[
real_images_inds
].
contiguous
()
# Handle the vision attention mask
# Handle the vision attention mask
if
pixel_attention_mask
is
None
:
# Remove padding images from the mask
pixel_attention_mask
=
torch
.
ones
(
pixel_attention_mask
=
pixel_attention_mask
[
size
=
(
pixel_values
.
size
(
0
),
pixel_values
.
size
(
2
),
real_images_inds
].
contiguous
()
pixel_values
.
size
(
3
)),
dtype
=
torch
.
bool
,
device
=
pixel_values
.
device
,
)
else
:
# Remove padding images from the mask
pixel_attention_mask
=
pixel_attention_mask
[
real_images_inds
].
contiguous
()
patch_size
=
self
.
config
.
vision_config
.
patch_size
patch_size
=
self
.
config
.
vision_config
.
patch_size
patches_subgrid
=
pixel_attention_mask
.
unfold
(
dimension
=
1
,
patches_subgrid
=
pixel_attention_mask
.
unfold
(
dimension
=
1
,
...
@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
...
@@ -529,27 +581,7 @@ class Idefics3Model(nn.Module):
patch_attention_mask
=
patch_attention_mask
,
patch_attention_mask
=
patch_attention_mask
,
)
)
return
image_hidden_states
.
split
(
num_patches
)
return
image_hidden_states
def
_process_image_pixels
(
self
,
inputs
:
Idefics3ImagePixelInputs
)
->
NestedTensors
:
assert
self
.
vision_model
is
not
None
pixel_values
=
inputs
[
"data"
]
pixel_attention_mask
=
inputs
[
"pixel_attention_mask"
]
return
self
.
_image_pixels_to_features
(
pixel_values
,
pixel_attention_mask
)
def
_process_image_input
(
self
,
image_input
:
ImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
num_patches
=
[
x
.
size
(
0
)
for
x
in
image_features
]
image_features
=
torch
.
cat
(
image_features
)
return
self
.
connector
(
image_features
).
split
(
num_patches
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -616,13 +648,113 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -616,13 +648,113 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
self
.
sampler
=
get_sampler
()
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
"The expected shape of pixel values per image per batch "
f
" per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Idefics3ImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
embed_is_patch
=
embed_is_patch
,
)
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
pixel_attention_mask
=
kwargs
.
pop
(
"pixel_attention_mask"
)
if
not
isinstance
(
pixel_attention_mask
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel_attention_mask. "
f
"Got type:
{
type
(
pixel_attention_mask
)
}
"
)
num_patches
=
kwargs
.
pop
(
"num_patches"
)
if
not
isinstance
(
num_patches
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_patches. "
f
"Got type:
{
type
(
num_patches
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_attention_mask
=
flatten_bn
(
pixel_attention_mask
,
concat
=
True
)
num_patches
=
flatten_bn
(
num_patches
,
concat
=
True
)
return
Idefics3ImagePixelInputs
(
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_attention_mask
=
pixel_attention_mask
,
num_patches
=
num_patches
,
embed_is_patch
=
embed_is_patch
,
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_image_pixels
(
self
,
inputs
:
Idefics3ImagePixelInputs
)
->
torch
.
Tensor
:
pixel_values
=
inputs
[
"pixel_values"
]
pixel_attention_mask
=
inputs
[
"pixel_attention_mask"
]
return
self
.
model
.
image_pixels_to_features
(
pixel_values
,
pixel_attention_mask
=
pixel_attention_mask
,
)
def
_process_image_input
(
self
,
image_input
:
ImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
image_features
=
self
.
_process_image_pixels
(
image_input
)
image_features
=
self
.
model
.
connector
(
image_features
)
num_patches
=
image_input
[
"num_patches"
]
return
image_features
.
split
(
num_patches
.
tolist
())
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
model
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
vision_embeddings
=
self
.
model
.
_process_image_input
(
image_input
)
return
vision_embeddings
image_features
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -632,8 +764,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -632,8 +764,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
input_ids
,
self
.
config
.
image_token_id
)
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
),
self
.
config
.
image_token_id
,
)
return
inputs_embeds
return
inputs_embeds
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/mllama.py
View file @
541d1df4
...
@@ -21,7 +21,6 @@ from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
...
@@ -21,7 +21,6 @@ from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
import
transformers.models.mllama.configuration_mllama
as
config_mllama
import
transformers.models.mllama.configuration_mllama
as
config_mllama
from
PIL.Image
import
Image
from
PIL.Image
import
Image
from
torch
import
nn
from
torch
import
nn
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
541d1df4
...
@@ -160,7 +160,7 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -160,7 +160,7 @@ class Qwen2AudioMultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
Any
],
mm_kwargs
:
Mapping
[
str
,
Any
],
)
->
BatchFeature
:
)
->
BatchFeature
:
# Text-only input not supported in composite processor
# Text-only input not supported in composite processor
if
not
mm_data
or
not
mm_data
.
get
(
"audios"
,
[]):
if
not
mm_data
.
get
(
"audios"
,
[]):
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
return
BatchFeature
(
dict
(
input_ids
=
[
prompt_ids
]),
tensor_type
=
"pt"
)
...
...
vllm/model_executor/models/ultravox.py
View file @
541d1df4
...
@@ -8,7 +8,6 @@ from functools import cached_property
...
@@ -8,7 +8,6 @@ from functools import cached_property
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
transformers
import
BatchFeature
,
ProcessorMixin
from
transformers
import
BatchFeature
,
ProcessorMixin
...
@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor(
...
@@ -160,7 +159,7 @@ class UltravoxMultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
# Text-only input not supported in composite processor
# Text-only input not supported in composite processor
if
not
mm_data
or
not
mm_data
.
get
(
"audios"
,
[]):
if
not
mm_data
.
get
(
"audios"
,
[]):
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt_ids
=
self
.
info
.
get_tokenizer
().
encode
(
prompt
,
add_special_tokens
=
False
)
prompt
,
add_special_tokens
=
False
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt_ids
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment