Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af51d80f
"vllm/vscode:/vscode.git/clone" did not exist on "3112271f6e5d50b3d94a2efa88a5a8e77826b897"
Unverified
Commit
af51d80f
authored
Apr 04, 2025
by
Roger Wang
Committed by
GitHub
Apr 04, 2025
Browse files
Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
parent
f5722a50
Changes
42
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
469 additions
and
219 deletions
+469
-219
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+101
-41
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+44
-6
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+61
-13
vllm/model_executor/models/nvlm_d.py
vllm/model_executor/models/nvlm_d.py
+30
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+3
-3
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+8
-3
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+40
-8
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+3
-3
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+3
-3
vllm/model_executor/models/skyworkr1v.py
vllm/model_executor/models/skyworkr1v.py
+39
-3
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+76
-1
vllm/multimodal/base.py
vllm/multimodal/base.py
+2
-2
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+3
-29
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+25
-52
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+1
-1
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+1
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+4
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-4
vllm/v1/request.py
vllm/v1/request.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-40
No files found.
vllm/model_executor/models/minicpmv.py
View file @
af51d80f
...
@@ -56,7 +56,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
...
@@ -56,7 +56,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem
,
VideoProcessorItems
)
VideoItem
,
VideoProcessorItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -67,6 +67,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...
@@ -67,6 +67,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal
,
SupportsPP
)
SupportsMultiModal
,
SupportsPP
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# For profile run
# For profile run
_MAX_FRAMES_PER_VIDEO
=
16
_MAX_FRAMES_PER_VIDEO
=
16
...
@@ -89,6 +90,14 @@ class MiniCPMVImagePixelInputs(TypedDict):
...
@@ -89,6 +90,14 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format.
This should be in `(height, width)` format.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices
:
torch
.
Tensor
num_slices
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
...
@@ -103,6 +112,14 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
...
@@ -103,6 +112,14 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor.
instead of a batched tensor.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageInputs
=
Union
[
MiniCPMVImagePixelInputs
,
MiniCPMVImageEmbeddingInputs
]
MiniCPMVImageEmbeddingInputs
]
...
@@ -228,10 +245,12 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
...
@@ -228,10 +245,12 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
video_pixel_values
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_pixel_values
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_image_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_tgt_sizes
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embeds
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embeds
=
MultiModalFieldConfig
.
batched
(
"video"
),
video_embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"video"
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
video_token_id
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
video_token_id
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
)
)
...
@@ -379,43 +398,22 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
...
@@ -379,43 +398,22 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id
=
use_image_id
,
use_image_id
=
use_image_id
,
)
)
def
get_sliced_grid
(
self
,
image_size
:
ImageSize
,
# For MiniCPM V/O 2.6
max_slice_nums
:
Optional
[
int
]
=
None
,
)
->
Optional
[
tuple
[
int
,
int
]]:
image_processor
=
self
.
get_image_processor
()
version
=
self
.
get_model_version
()
if
version
==
(
2
,
0
)
or
version
==
(
2
,
5
):
return
image_processor
.
get_sliced_grid
(
image_size
)
if
max_slice_nums
is
None
:
max_slice_nums
=
image_processor
.
max_slice_nums
return
image_processor
.
get_sliced_grid
(
image_size
,
max_slice_nums
=
max_slice_nums
,
)
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
image_size
:
ImageSize
,
image_size
:
ImageSize
,
max_slice_nums
:
Optional
[
int
]
=
None
,
max_slice_nums
:
Optional
[
int
]
=
None
,
use_image_id
:
bool
=
True
,
)
->
int
:
)
->
int
:
image_processor
=
self
.
get_image_processor
()
tokenizer
=
self
.
get_tokenizer
()
image_placeholders
=
self
.
get_slice_image_placeholder
(
grid
=
self
.
get_sliced_grid
(
image_size
,
image_size
,
max_slice_nums
=
max_slice_nums
,
max_slice_nums
=
max_slice_nums
,
use_image_id
=
use_image_id
,
)
)
if
grid
is
None
:
image_token_ids
=
tokenizer
.
encode
(
image_placeholders
,
ncols
=
nrows
=
0
add_special_tokens
=
False
)
else
:
ncols
,
nrows
=
grid
return
(
ncols
*
nrows
+
1
)
*
image_processor
.
image_feature_size
return
len
(
image_token_ids
)
def
get_max_image_tokens
(
self
)
->
int
:
def
get_max_image_tokens
(
self
)
->
int
:
image_size
=
self
.
get_image_size_with_most_features
()
image_size
=
self
.
get_image_size_with_most_features
()
...
@@ -435,6 +433,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
...
@@ -435,6 +433,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return
self
.
get_num_image_tokens
(
return
self
.
get_num_image_tokens
(
frame_size
,
frame_size
,
max_slice_nums
=
self
.
get_video_max_slice_num
(),
max_slice_nums
=
self
.
get_video_max_slice_num
(),
use_image_id
=
False
,
)
)
def
get_max_video_tokens
(
def
get_max_video_tokens
(
...
@@ -540,6 +539,14 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -540,6 +539,14 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id
=
False
,
use_image_id
=
False
,
)
*
num_frames
)
*
num_frames
def
get_embed_is_patch
(
self
,
input_ids
:
list
[
int
],
)
->
torch
.
Tensor
:
tokenizer
=
self
.
info
.
get_tokenizer
()
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
return
torch
.
tensor
(
input_ids
)
==
unk_token_id
def
process_images
(
def
process_images
(
self
,
self
,
mm_data
:
Mapping
[
str
,
object
],
mm_data
:
Mapping
[
str
,
object
],
...
@@ -563,7 +570,26 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -563,7 +570,26 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
)
)
image_sizes
=
[
parsed_images
.
get_image_size
(
i
)
for
i
in
range
(
len
(
parsed_images
))
]
image_repl_features
=
[
self
.
get_image_prompt_texts
(
size
,
idx
)
for
idx
,
size
in
enumerate
(
image_sizes
)
]
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
embed_is_patch
=
[
self
.
get_embed_is_patch
(
image_repl_tokens
)
for
image_repl_tokens
in
image_repls_feature_tokens
]
image_inputs
[
"embed_is_patch"
]
=
embed_is_patch
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
image_inputs
[
"image_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
image_inputs
[
"image_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
...
@@ -599,9 +625,31 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -599,9 +625,31 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
out_keys
=
{
"pixel_values"
,
"image_sizes"
,
"tgt_sizes"
},
)
)
video_inputs
=
{
f
"video_
{
k
}
"
:
v
for
k
,
v
in
video_inputs
.
items
()}
frame_sizes
=
[
parsed_videos
.
get_frame_size
(
i
)
for
i
in
range
(
len
(
parsed_videos
))
]
num_frames
=
[
parsed_videos
.
get_num_frames
(
i
)
for
i
in
range
(
len
(
parsed_videos
))
]
video_repl_features
=
[
self
.
get_video_prompt_texts
(
size
,
nframes
)
for
size
,
nframes
in
zip
(
frame_sizes
,
num_frames
)
]
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
video_repls_feature_tokens
=
[
tokenizer
.
encode
(
video_repl
,
add_special_tokens
=
False
)
for
video_repl
in
video_repl_features
]
embed_is_patch
=
[
self
.
get_embed_is_patch
(
video_repl_tokens
)
for
video_repl_tokens
in
video_repls_feature_tokens
]
video_inputs
[
"embed_is_patch"
]
=
embed_is_patch
video_inputs
=
{
f
"video_
{
k
}
"
:
v
for
k
,
v
in
video_inputs
.
items
()}
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
unk_token_id
=
tokenizer
.
get_vocab
()[
"<unk>"
]
video_inputs
[
"video_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
video_inputs
[
"video_token_id"
]
=
torch
.
tensor
(
unk_token_id
)
...
@@ -692,10 +740,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -692,10 +740,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
return
PromptUpdateDetails
.
select_text
(
return
self
.
get_image_prompt_texts
(
image_size
,
item_idx
)
self
.
get_image_prompt_texts
(
image_size
,
item_idx
),
"<unk>"
,
)
def
get_video_replacement
(
item_idx
:
int
):
def
get_video_replacement
(
item_idx
:
int
):
videos
=
mm_items
.
get_items
(
videos
=
mm_items
.
get_items
(
...
@@ -704,10 +749,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -704,10 +749,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size
=
videos
.
get_frame_size
(
item_idx
)
frame_size
=
videos
.
get_frame_size
(
item_idx
)
num_frames
=
videos
.
get_num_frames
(
item_idx
)
num_frames
=
videos
.
get_num_frames
(
item_idx
)
return
PromptUpdateDetails
.
select_text
(
return
self
.
get_video_prompt_texts
(
frame_size
,
num_frames
)
self
.
get_video_prompt_texts
(
frame_size
,
num_frames
),
"<unk>"
,
)
get_replacement
=
{
get_replacement
=
{
"image"
:
get_image_replacement
,
"image"
:
get_image_replacement
,
...
@@ -790,6 +832,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -790,6 +832,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert
isinstance
(
image_token_id
,
torch
.
Tensor
)
assert
isinstance
(
image_token_id
,
torch
.
Tensor
)
self
.
mm_token_ids
.
add
(
image_token_id
.
flatten
().
unique
().
item
())
self
.
mm_token_ids
.
add
(
image_token_id
.
flatten
().
unique
().
item
())
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of embed_is_patch for
{
modality
=
}
. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
raise
ValueError
(
...
@@ -801,6 +851,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -801,6 +851,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return
MiniCPMVImageEmbeddingInputs
(
return
MiniCPMVImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
image_embeds
=
image_embeds_flat
,
image_embeds
=
image_embeds_flat
,
embed_is_patch
=
embed_is_patch
,
)
)
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
...
@@ -828,6 +879,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -828,6 +879,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
pixel_values_flat
,
pixel_values
=
pixel_values_flat
,
tgt_sizes
=
tgt_sizes_flat
,
tgt_sizes
=
tgt_sizes_flat
,
embed_is_patch
=
embed_is_patch
,
num_slices
=
num_slices_flat
,
num_slices
=
num_slices_flat
,
)
)
...
@@ -884,11 +936,19 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -884,11 +936,19 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if
modality
==
"images"
:
if
modality
==
"images"
:
image_input
=
modalities
[
"images"
]
image_input
=
modalities
[
"images"
]
image_features
=
self
.
_process_vision_input
(
image_input
)
image_features
=
self
.
_process_vision_input
(
image_input
)
multimodal_embeddings
+=
tuple
(
image_features
)
multimodal_embeddings
+=
tuple
(
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
))
if
modality
==
"videos"
:
if
modality
==
"videos"
:
video_input
=
modalities
[
"videos"
]
video_input
=
modalities
[
"videos"
]
video_features
=
self
.
_process_vision_input
(
video_input
)
video_features
=
self
.
_process_vision_input
(
video_input
)
multimodal_embeddings
+=
tuple
(
video_features
)
multimodal_embeddings
+=
tuple
(
scatter_patch_features
(
video_features
,
video_input
[
"embed_is_patch"
],
))
return
multimodal_embeddings
return
multimodal_embeddings
...
@@ -911,7 +971,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -911,7 +971,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
multimodal_embeddings
,
select_patch_features
(
multimodal_embeddings
)
,
list
(
self
.
mm_token_ids
),
list
(
self
.
mm_token_ids
),
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/mistral3.py
View file @
af51d80f
...
@@ -27,8 +27,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...
@@ -27,8 +27,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
,
BaseProcessingInfo
,
ProcessingCache
,
PromptReplacement
,
PromptUpdate
,
PromptReplacement
,
PromptUpdate
)
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -36,7 +35,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...
@@ -36,7 +35,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralHFVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vision_encoder_info
from
.vision
import
(
get_vision_encoder_info
,
scatter_patch_features
,
select_patch_features
)
class
Mistral3ImagePixelInputs
(
TypedDict
):
class
Mistral3ImagePixelInputs
(
TypedDict
):
...
@@ -49,6 +49,14 @@ class Mistral3ImagePixelInputs(TypedDict):
...
@@ -49,6 +49,14 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
in which case the data is passed as a list instead of a batched tensor.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
class
Mistral3PatchMerger
(
nn
.
Module
):
class
Mistral3PatchMerger
(
nn
.
Module
):
"""
"""
...
@@ -258,6 +266,23 @@ class Mistral3MultiModalProcessor(
...
@@ -258,6 +266,23 @@ class Mistral3MultiModalProcessor(
p
[:,
:
h
,
:
w
]
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
p
[:,
:
h
,
:
w
]
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
]
]
hf_config
=
self
.
info
.
get_hf_config
()
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
encoder_info
=
PixtralHFEncoderInfo
(
vision_config
)
tile_sizes
=
[
encoder_info
.
get_patch_grid_size
(
image_width
=
pixel_value
.
shape
[
-
1
],
image_height
=
pixel_value
.
shape
[
-
2
],
)
for
pixel_value
in
processed_outputs
[
"pixel_values"
]
]
embed_is_patch
=
[
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
for
ncols
,
nrows
in
tile_sizes
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
def
_get_mm_fields_config
(
def
_get_mm_fields_config
(
...
@@ -267,6 +292,7 @@ class Mistral3MultiModalProcessor(
...
@@ -267,6 +292,7 @@ class Mistral3MultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -301,7 +327,7 @@ class Mistral3MultiModalProcessor(
...
@@ -301,7 +327,7 @@ class Mistral3MultiModalProcessor(
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
[
-
1
]
=
image_end_id
tokens
[
-
1
]
=
image_end_id
return
PromptUpdateDetails
.
select_token_id
(
tokens
,
image_token_id
)
return
tokens
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -392,6 +418,8 @@ def init_vision_tower_for_llava(
...
@@ -392,6 +418,8 @@ def init_vision_tower_for_llava(
)
)
# TODO(mgoin): Support V1, there are issues with image batching/chunking
# that need to be resolved first.
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
_build_mistral3_processor
,
_build_mistral3_processor
,
info
=
_build_mistral3_info
,
info
=
_build_mistral3_info
,
...
@@ -481,9 +509,16 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -481,9 +509,16 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
assert
self
.
config
.
vision_config
.
model_type
==
"pixtral"
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
return
Mistral3ImagePixelInputs
(
return
Mistral3ImagePixelInputs
(
type
=
"pixel_values_pixtral"
,
type
=
"pixel_values_pixtral"
,
pixel_values
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
embed_is_patch
=
flatten_bn
(
embed_is_patch
),
)
)
def
_process_image_input
(
def
_process_image_input
(
...
@@ -522,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -522,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
return
scatter_patch_features
(
vision_embeddings
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -534,7 +572,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -534,7 +572,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
multimodal_embeddings
,
select_patch_features
(
multimodal_embeddings
)
,
self
.
config
.
image_token_index
,
self
.
config
.
image_token_index
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/molmo.py
View file @
af51d80f
...
@@ -46,8 +46,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...
@@ -46,8 +46,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIndexTargets
,
BaseProcessingInfo
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
,
PromptInsertion
,
PromptUpdate
)
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -57,6 +56,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -57,6 +56,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
...
@@ -84,6 +84,14 @@ class MolmoImageInputs(TypedDict):
...
@@ -84,6 +84,14 @@ class MolmoImageInputs(TypedDict):
Shape: `(batch_size * num_images, num_crops, num_patch)`
Shape: `(batch_size * num_images, num_crops, num_patch)`
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_crops
:
torch
.
Tensor
num_crops
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
...
@@ -1138,6 +1146,30 @@ class MolmoProcessorWrapper:
...
@@ -1138,6 +1146,30 @@ class MolmoProcessorWrapper:
if
image_input_idx
is
not
None
:
if
image_input_idx
is
not
None
:
feat_is_patch
=
image_input_idx
>=
0
feat_is_patch
=
image_input_idx
>=
0
input_is_embed
=
torch
.
isin
(
input_ids
,
torch
.
tensor
([
self
.
image_patch_id
,
self
.
im_col_id
,
self
.
im_start_id
,
self
.
im_end_id
,
]),
)
embed_ids
=
input_ids
[
input_is_embed
]
embed_is_patch
=
embed_ids
==
self
.
image_patch_id
assert
embed_is_patch
.
sum
()
==
feat_is_patch
.
sum
()
# image_tokens = extra_joint + joint
# Both `extra_joint` and `joint` have `im_start_id` and `im_end_id`
embed_start
=
torch
.
nonzero
(
embed_ids
==
self
.
im_start_id
)[::
2
,
0
]
embed_end
=
torch
.
nonzero
(
embed_ids
==
self
.
im_end_id
)[
1
::
2
,
0
]
assert
len
(
embed_start
)
==
len
(
embed_end
)
==
len
(
images
)
embed_is_patch
=
[
embed_is_patch
[
start
:
end
+
1
]
for
start
,
end
in
zip
(
embed_start
,
embed_end
)
]
tilings
=
[
tilings
=
[
self
.
select_tiling
(
self
.
select_tiling
(
image_width
=
image
.
size
[
0
],
image_width
=
image
.
size
[
0
],
...
@@ -1149,6 +1181,7 @@ class MolmoProcessorWrapper:
...
@@ -1149,6 +1181,7 @@ class MolmoProcessorWrapper:
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
assert
num_crops
.
sum
()
==
len
(
feat_is_patch
)
outputs
[
"feat_is_patch"
]
=
feat_is_patch
outputs
[
"feat_is_patch"
]
=
feat_is_patch
outputs
[
"embed_is_patch"
]
=
embed_is_patch
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"num_crops"
]
=
num_crops
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
outputs
[
"img_patch_id"
]
=
self
.
image_patch_id
...
@@ -1187,13 +1220,17 @@ class MolmoProcessingInfo(BaseProcessingInfo):
...
@@ -1187,13 +1220,17 @@ class MolmoProcessingInfo(BaseProcessingInfo):
)
)
pooling_size
=
processor
.
pooling_size
pooling_size
=
processor
.
pooling_size
image_
token_length_w
=
processor
.
image_
token_length_w
base_
image_
input_size
=
processor
.
base_
image_
input_size
image_
token_length_h
=
processor
.
image_
token_length_h
base_
image_
input_d
=
processor
.
image_
patch_size
extra
=
image_token_length_w
*
image_token_length_h
crop_patches
=
base_image_input_size
[
0
]
//
base_image_input_d
joint
=
((
ncols
+
1
)
//
pooling_size
)
*
((
nrows
+
1
)
//
pooling_size
)
return
extra
+
joint
per_row
=
ncols
//
pooling_size
+
1
joint
=
per_row
*
(
nrows
//
pooling_size
)
+
2
image_token_length
=
(
crop_patches
+
pooling_size
-
1
)
//
pooling_size
resize
=
(
image_token_length
+
1
)
*
image_token_length
+
2
return
resize
+
joint
def
get_max_image_tokens
(
self
)
->
int
:
def
get_max_image_tokens
(
self
)
->
int
:
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
target_width
,
target_height
=
self
.
get_image_size_with_most_features
()
...
@@ -1291,6 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1291,6 +1328,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
"image"
,
num_crops
),
"image"
,
num_crops
),
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
feat_is_patch
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
),
"image"
,
num_crops
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
img_patch_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
...
@@ -1330,10 +1368,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
...
@@ -1330,10 +1368,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
joint
=
([
img_start_id
]
+
joint_row
*
joint
=
([
img_start_id
]
+
joint_row
*
((
nrows
+
1
)
//
pooling_size
)
+
[
img_end_id
])
((
nrows
+
1
)
//
pooling_size
)
+
[
img_end_id
])
return
PromptUpdateDetails
.
select_token_id
(
image_tokens
=
extra_joint
+
joint
extra_joint
+
joint
,
return
image_tokens
embed_token_id
=
img_patch_id
,
)
return
[
return
[
PromptInsertion
(
PromptInsertion
(
...
@@ -1439,6 +1475,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1439,6 +1475,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
raise
ValueError
(
"Incorrect type of feat_is_patch. "
raise
ValueError
(
"Incorrect type of feat_is_patch. "
f
"Got type:
{
type
(
feat_is_patch
)
}
"
)
f
"Got type:
{
type
(
feat_is_patch
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_crops. "
raise
ValueError
(
"Incorrect type of num_crops. "
...
@@ -1450,12 +1491,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1450,12 +1491,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
f
"Got type:
{
type
(
img_patch_id
)
}
"
)
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
self
.
img_patch_id
=
img_patch_id
.
flatten
().
unique
().
item
()
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
return
MolmoImageInputs
(
return
MolmoImageInputs
(
images
=
images
,
images
=
images
,
image_masks
=
image_masks
,
image_masks
=
image_masks
,
feat_is_patch
=
feat_is_patch
,
feat_is_patch
=
feat_is_patch
,
embed_is_patch
=
embed_is_patch
,
num_crops
=
num_crops
,
num_crops
=
num_crops
,
)
)
...
@@ -1494,7 +1537,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1494,7 +1537,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
return
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -1508,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1508,7 +1556,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
multimodal_embeddings
,
select_patch_features
(
multimodal_embeddings
)
,
self
.
img_patch_id
,
self
.
img_patch_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/nvlm_d.py
View file @
af51d80f
...
@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
...
@@ -57,7 +57,7 @@ class NVLMProcessor(BaseInternVLProcessor):
# when trying to find "<tile" as a subsequence of "<Image><tile"
# when trying to find "<tile" as a subsequence of "<Image><tile"
repl
=
"<Image>"
+
features
+
"</Image>"
repl
=
"<Image>"
+
features
+
"</Image>"
return
PromptUpdateDetails
.
select_text
(
repl
,
IMG_PAD
)
return
PromptUpdateDetails
(
full
=
repl
,
features
=
repl
)
class
NVLMProcessingInfo
(
BaseInternVLProcessingInfo
):
class
NVLMProcessingInfo
(
BaseInternVLProcessingInfo
):
...
@@ -84,6 +84,31 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
...
@@ -84,6 +84,31 @@ class NVLMProcessingInfo(BaseInternVLProcessingInfo):
**
kwargs
,
**
kwargs
,
)
)
def
get_max_image_tokens
(
self
)
->
int
:
hf_processor
=
self
.
get_hf_processor
()
tokenizer
=
hf_processor
.
tokenizer
max_num_patches
=
hf_processor
.
max_dynamic_patch
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers
=
[
f
"<tile_
{
i
+
1
}
>"
for
i
in
range
(
max_num_patches
)
]
if
hf_processor
.
use_thumbnail
and
max_num_patches
!=
1
:
tile_pos_identifiers
+=
[
"<tile_global_thumbnail>"
]
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
# so we include <tile_1> in the start_str
start_str
=
"<Image>"
+
tile_pos_identifiers
.
pop
(
0
)
end_str
=
"</Image>"
start_token_len
=
len
(
tokenizer
.
encode
(
start_str
))
end_token_len
=
len
(
tokenizer
.
encode
(
end_str
))
tile_token_len
=
sum
(
len
(
tokenizer
.
encode
(
identifier
))
for
identifier
in
tile_pos_identifiers
)
non_image_tokens_num
=
start_token_len
+
end_token_len
+
tile_token_len
return
super
().
get_max_image_tokens
()
+
non_image_tokens_num
class
NVLMDummyInputsBuilder
(
InternVLDummyInputsBuilder
[
NVLMProcessingInfo
]):
class
NVLMDummyInputsBuilder
(
InternVLDummyInputsBuilder
[
NVLMProcessingInfo
]):
...
@@ -152,7 +177,10 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
...
@@ -152,7 +177,10 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
repl
=
hf_processor
.
get_image_repl
(
feature_size
,
num_patches
)
repl
=
hf_processor
.
get_image_repl
(
feature_size
,
num_patches
)
return
PromptUpdateDetails
.
select_text
(
repl
.
full
+
"
\n
"
,
IMG_PAD
)
return
PromptUpdateDetails
(
full
=
repl
.
full
+
"
\n
"
,
features
=
repl
.
features
+
"
\n
"
,
)
# See note in dummy data regarding why we have the extra newline
# See note in dummy data regarding why we have the extra newline
return
[
return
[
...
...
vllm/model_executor/models/paligemma.py
View file @
af51d80f
...
@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
...
@@ -162,9 +162,9 @@ class PaliGemmaMultiModalProcessor(
modality
=
"image"
,
modality
=
"image"
,
target
=
PromptIndexTargets
.
prefix
(
target
=
PromptIndexTargets
.
prefix
(
[
bos_token_id
]
if
tokenizer
.
add_bos_token
else
[]),
[
bos_token_id
]
if
tokenizer
.
add_bos_token
else
[]),
insertion
=
PromptUpdateDetails
.
select_token_id
(
insertion
=
PromptUpdateDetails
(
image_tokens
+
[
bos_token_id
],
full
=
image_tokens
+
[
bos_token_id
],
embed_token_id
=
image_token
_id
,
features
=
image_token
s
,
),
),
)
)
]
]
...
...
vllm/model_executor/models/phi3v.py
View file @
af51d80f
...
@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
...
@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BoundPromptUpdate
,
BaseProcessingInfo
,
BoundPromptUpdate
,
PlaceholderFeaturesInfo
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
# yapf: enable
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -442,7 +443,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
...
@@ -442,7 +443,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
processor
=
hf_processor
,
processor
=
hf_processor
,
)
)
return
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
image_tokens
=
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
return
PromptUpdateDetails
(
full
=
image_tokens
,
features
=
image_tokens
,
)
num_images
=
mm_items
.
get_count
(
"image"
,
strict
=
False
)
num_images
=
mm_items
.
get_count
(
"image"
,
strict
=
False
)
...
@@ -511,7 +517,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
...
@@ -511,7 +517,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
item_idx
=
p
.
item_idx
,
item_idx
=
p
.
item_idx
,
start_idx
=
p
.
start_idx
-
1
,
start_idx
=
p
.
start_idx
-
1
,
tokens
=
p
.
tokens
,
tokens
=
p
.
tokens
,
is_embed
=
p
.
is_embed
,
)
for
p
in
ps
)
for
p
in
ps
]
]
for
modality
,
ps
in
placeholders
.
items
()
for
modality
,
ps
in
placeholders
.
items
()
...
...
vllm/model_executor/models/pixtral.py
View file @
af51d80f
...
@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...
@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
...
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
...
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
from
.vision
import
(
VisionEncoderInfo
,
resolve_visual_encoder_outputs
,
scatter_patch_features
,
select_patch_features
)
try
:
try
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
...
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
"""
"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
PixtralProcessorAdapter
:
class
PixtralProcessorAdapter
:
"""
"""
...
@@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
...
@@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
"For more info, see: "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
"https://github.com/vllm-project/vllm/issues/8411."
)
image_token_id
=
self
.
image_token_id
images_processed
=
list
[
torch
.
Tensor
]()
images_processed
=
list
[
torch
.
Tensor
]()
images_tokens
=
list
[
torch
.
Tensor
]()
images_tokens
=
list
[
torch
.
Tensor
]()
images_embed_is_patch
=
list
[
torch
.
Tensor
]()
for
image
in
images
:
for
image
in
images
:
image_inputs
=
self
.
image_processor
(
ImageChunk
(
image
=
image
))
image_inputs
=
self
.
image_processor
(
ImageChunk
(
image
=
image
))
...
@@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
...
@@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
images_processed
.
append
(
image_processed
)
images_processed
.
append
(
image_processed
)
images_tokens
.
append
(
image_tokens
)
images_tokens
.
append
(
image_tokens
)
images_embed_is_patch
.
append
(
image_tokens
==
image_token_id
)
return
{
return
{
"input_ids"
:
torch
.
cat
(
images_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"input_ids"
:
torch
.
cat
(
images_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"images"
:
images_processed
,
"images"
:
images_processed
,
"embed_is_patch"
:
images_embed_is_patch
,
}
}
...
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
...
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
ncols
,
nrows
=
processor
.
image_processor
.
_image_to_num_tokens
(
ncols
,
nrows
=
processor
.
image_processor
.
_image_to_num_tokens
(
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
)))
Image
.
new
(
"RGB"
,
(
image_width
,
image_height
)))
return
ncols
*
nrows
return
(
ncols
+
1
)
*
nrows
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
image_processor
=
self
.
get_hf_processor
().
image_processor
image_processor
=
self
.
get_hf_processor
().
image_processor
...
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
...
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
hf_inputs
:
Mapping
[
str
,
NestedTensors
],
hf_inputs
:
Mapping
[
str
,
NestedTensors
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
images
=
MultiModalFieldConfig
.
batched
(
"image"
))
return
dict
(
images
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
...
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
...
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
=
([
image_token_id
]
*
ncols
+
[
image_break_id
])
*
nrows
tokens
[
-
1
]
=
image_end_id
tokens
[
-
1
]
=
image_end_id
return
PromptUpdateDetails
.
select_token_id
(
tokens
,
image_token_id
)
return
tokens
return
[
return
[
PromptReplacement
(
PromptReplacement
(
...
@@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of images. "
raise
ValueError
(
"Incorrect type of images. "
f
"Got type:
{
type
(
images
)
}
"
)
f
"Got type:
{
type
(
images
)
}
"
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
PixtralImagePixelInputs
(
return
PixtralImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
images
=
flatten_bn
(
images
),
images
=
flatten_bn
(
images
),
embed_is_patch
=
embed_is_patch
,
)
)
def
_process_image_input
(
def
_process_image_input
(
...
@@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
return
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
multimodal_embeddings
,
select_patch_features
(
multimodal_embeddings
)
,
self
.
vision_args
.
image_token_id
,
self
.
vision_args
.
image_token_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
@@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
...
@@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
image_width
=
image_width
,
image_width
=
image_width
,
image_height
=
image_height
,
image_height
=
image_height
,
)
)
return
ncols
*
nrows
# Consider the image_break_token
return
(
ncols
+
1
)
*
nrows
def
get_max_image_tokens
(
self
)
->
int
:
def
get_max_image_tokens
(
self
)
->
int
:
image_size
=
self
.
get_image_size
()
image_size
=
self
.
get_image_size
()
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
af51d80f
...
@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
...
@@ -229,9 +229,9 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens
=
[
audio_token_id
]
*
num_features
audio_tokens
=
[
audio_token_id
]
*
num_features
return
PromptUpdateDetails
.
select_token_id
(
return
PromptUpdateDetails
(
[
audio_bos_id
]
+
audio_tokens
+
[
audio_eos_id
],
full
=
[
audio_bos_id
]
+
audio_tokens
+
[
audio_eos_id
],
embed_token_id
=
audio_token
_id
,
features
=
audio_token
s
,
)
)
return
[
return
[
...
...
vllm/model_executor/models/qwen_vl.py
View file @
af51d80f
...
@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
...
@@ -647,9 +647,9 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
[
img_start_id
,
img_end_id
],
target
=
[
img_start_id
,
img_end_id
],
replacement
=
PromptUpdateDetails
.
select_token_id
(
replacement
=
PromptUpdateDetails
(
[
img_start_id
]
+
image_tokens
+
[
img_end_id
],
full
=
[
img_start_id
]
+
image_tokens
+
[
img_end_id
],
embed_token_id
=
img_pad_id
,
features
=
image_tokens
,
),
),
)
)
]
]
...
...
vllm/model_executor/models/skyworkr1v.py
View file @
af51d80f
...
@@ -40,6 +40,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
...
@@ -40,6 +40,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
IMG_START
=
'<img>'
IMG_START
=
'<img>'
IMG_END
=
'</img>'
IMG_END
=
'</img>'
...
@@ -60,6 +61,14 @@ class SkyworkR1VImagePixelInputs(TypedDict):
...
@@ -60,6 +61,14 @@ class SkyworkR1VImagePixelInputs(TypedDict):
num_patches
:
torch
.
Tensor
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class
SkyworkR1VImageEmbeddingInputs
(
TypedDict
):
class
SkyworkR1VImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
...
@@ -410,13 +419,24 @@ class BaseSkyworkR1VProcessor(ABC):
...
@@ -410,13 +419,24 @@ class BaseSkyworkR1VProcessor(ABC):
torch
.
tensor
([
len
(
item
)
for
item
in
pixel_values_lst
]),
torch
.
tensor
([
len
(
item
)
for
item
in
pixel_values_lst
]),
}
}
tokenizer
=
self
.
tokenizer
image_token_id
=
self
.
image_token_id
embed_is_patch
=
list
[
torch
.
Tensor
]()
for
pixel_values
in
pixel_values_lst
:
for
pixel_values
in
pixel_values_lst
:
num_patches
=
pixel_values
.
shape
[
0
]
num_patches
=
pixel_values
.
shape
[
0
]
feature_size
=
num_patches
*
self
.
num_image_token
feature_size
=
num_patches
*
self
.
num_image_token
image_repl
=
self
.
get_image_repl
(
feature_size
,
num_patches
)
image_repl
=
self
.
get_image_repl
(
feature_size
,
num_patches
)
feature_tokens
=
tokenizer
.
encode
(
image_repl
.
features
,
add_special_tokens
=
False
)
text
=
[
t
.
replace
(
'<image>'
,
image_repl
.
full
,
1
)
for
t
in
text
]
text
=
[
t
.
replace
(
'<image>'
,
image_repl
.
full
,
1
)
for
t
in
text
]
embed_is_patch
.
append
(
torch
.
tensor
(
feature_tokens
)
==
image_token_id
)
image_inputs
[
"embed_is_patch"
]
=
embed_is_patch
text_inputs
=
self
.
tokenizer
(
text
)
text_inputs
=
self
.
tokenizer
(
text
)
...
@@ -440,7 +460,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
...
@@ -440,7 +460,7 @@ class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
repl_features
=
IMG_CONTEXT
*
feature_size
repl_features
=
IMG_CONTEXT
*
feature_size
repl_full
=
IMG_START
+
repl_features
+
IMG_END
repl_full
=
IMG_START
+
repl_features
+
IMG_END
return
PromptUpdateDetails
.
select_text
(
repl_full
,
IMG_CONTEXT
)
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
class
BaseSkyworkR1VProcessingInfo
(
BaseProcessingInfo
):
class
BaseSkyworkR1VProcessingInfo
(
BaseProcessingInfo
):
...
@@ -579,6 +599,7 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -579,6 +599,7 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_patches
),
"image"
,
image_num_patches
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
image_token_id
=
MultiModalFieldConfig
.
shared
(
"image"
,
num_images
),
)
)
...
@@ -814,6 +835,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -814,6 +835,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self
,
**
kwargs
:
object
)
->
Optional
[
SkyworkR1VImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
SkyworkR1VImageInputs
]:
pixel_values_flat
=
kwargs
.
pop
(
"pixel_values_flat"
,
None
)
pixel_values_flat
=
kwargs
.
pop
(
"pixel_values_flat"
,
None
)
image_num_patches
=
kwargs
.
pop
(
"image_num_patches"
,
None
)
image_num_patches
=
kwargs
.
pop
(
"image_num_patches"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values_flat
is
None
and
image_embeds
is
None
:
if
pixel_values_flat
is
None
and
image_embeds
is
None
:
...
@@ -842,14 +864,20 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -842,14 +864,20 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of image_num_patches. "
raise
ValueError
(
"Incorrect type of image_num_patches. "
f
"Got type:
{
type
(
image_num_patches
)
}
"
)
f
"Got type:
{
type
(
image_num_patches
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
pixel_values_flat
=
flatten_bn
(
pixel_values_flat
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
image_num_patches
=
flatten_bn
(
image_num_patches
,
concat
=
True
)
embed_is_patch
=
flatten_bn
(
embed_is_patch
)
return
SkyworkR1VImagePixelInputs
(
return
SkyworkR1VImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values_flat
=
self
.
_validate_pixel_values
(
pixel_values_flat
=
self
.
_validate_pixel_values
(
pixel_values_flat
),
pixel_values_flat
),
num_patches
=
image_num_patches
,
num_patches
=
image_num_patches
,
embed_is_patch
=
embed_is_patch
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
@@ -895,7 +923,15 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -895,7 +923,15 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
return
self
.
_process_image_input
(
image_input
)
image_features
=
self
.
_process_image_input
(
image_input
)
if
image_input
[
"type"
]
!=
"pixel_values"
:
return
image_features
return
scatter_patch_features
(
image_features
,
image_input
[
"embed_is_patch"
],
)
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -909,7 +945,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -909,7 +945,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
multimodal_embeddings
,
select_patch_features
(
multimodal_embeddings
)
,
self
.
img_context_token_id
,
self
.
img_context_token_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/vision.py
View file @
af51d80f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
from
collections.abc
import
Sequence
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
,
cast
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -9,9 +10,12 @@ from transformers import PretrainedConfig
...
@@ -9,9 +10,12 @@ from transformers import PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
backend_name_to_enum
,
from
vllm.attention.selector
import
(
backend_name_to_enum
,
get_global_forced_attn_backend
)
get_global_forced_attn_backend
)
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
.interfaces
import
MultiModalEmbeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
...
@@ -151,3 +155,74 @@ def resolve_visual_encoder_outputs(
...
@@ -151,3 +155,74 @@ def resolve_visual_encoder_outputs(
if
post_layer_norm
is
not
None
and
uses_last_layer
:
if
post_layer_norm
is
not
None
and
uses_last_layer
:
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
def
scatter_patch_features
(
patches
:
Union
[
torch
.
Tensor
,
Sequence
[
torch
.
Tensor
]],
embed_is_patch
:
Union
[
torch
.
Tensor
,
Sequence
[
torch
.
Tensor
]],
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func`select_patch_features`.
Args:
patches: The patch features for each image.
Shape: `(num_images, <patch_dims>, feature_depth)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
Note:
The original code only considers patch tokens as feature
tokens, but our processor considers all image-related tokens
as feature tokens because the feature tokens need to be
consecutive in `input_ids`.
Example:
A simplified example for one image:
.. code-block::
Embedding tokens (from HF processor):
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
embed_is_patch (from HF processor):
[ False True True False True True False False ]
Encoder outputs (from model):
[ p1 p2 p3 p4 ]
The resulting embedding tensor is:
[ nan p1 p2 nan p3 p4 nan nan ]
"""
if
len
(
patches
)
!=
len
(
embed_is_patch
):
raise
ValueError
(
f
"Inconsistent num_images:
{
len
(
patches
)
=
}
vs. "
f
"
{
len
(
embed_is_patch
)
=
}
"
)
def
get_embed_one
(
patches_one
:
torch
.
Tensor
,
e_is_patch
:
torch
.
Tensor
):
embed_one
=
patches_one
.
new_full
(
(
e_is_patch
.
shape
[
0
],
patches_one
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
embed_one
[
e_is_patch
]
=
patches_one
return
embed_one
return
tuple
(
get_embed_one
(
patches_one
,
e_is_patch
)
for
patches_one
,
e_is_patch
in
zip
(
patches
,
embed_is_patch
))
def
select_patch_features
(
multimodal_embeddings
:
MultiModalEmbeddings
)
->
MultiModalEmbeddings
:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
return
cast
(
MultiModalEmbeddings
,
selected_features
)
vllm/multimodal/base.py
View file @
af51d80f
...
@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
...
@@ -385,8 +385,8 @@ class MultiModalPlaceholderMap:
for
placeholder_dict
,
mm_item
in
zip
(
multi_modal_placeholders
,
for
placeholder_dict
,
mm_item
in
zip
(
multi_modal_placeholders
,
multi_modal_items
):
multi_modal_items
):
placeholder
=
range
(
placeholder
=
range
(
placeholder_dict
.
offset
,
placeholder_dict
[
"
offset
"
]
,
placeholder_dict
.
offset
+
placeholder_dict
.
length
,
placeholder_dict
[
"
offset
"
]
+
placeholder_dict
[
"
length
"
]
,
)
)
intersection
=
range
(
intersection
=
range
(
max
(
positions
.
start
,
placeholder
.
start
),
max
(
positions
.
start
,
placeholder
.
start
),
...
...
vllm/multimodal/inputs.py
View file @
af51d80f
...
@@ -109,8 +109,7 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
...
@@ -109,8 +109,7 @@ The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
"""
"""
@
dataclass
(
frozen
=
True
)
class
PlaceholderRange
(
TypedDict
):
class
PlaceholderRange
:
"""
"""
Placeholder location information for multi-modal data.
Placeholder location information for multi-modal data.
...
@@ -122,8 +121,8 @@ class PlaceholderRange:
...
@@ -122,8 +121,8 @@ class PlaceholderRange:
.. code-block::
.. code-block::
A:
PlaceholderRange(
offset
=
0, length
=4)
A:
{ "
offset
":
0,
"
length
": 4 }
B:
PlaceholderRange(
offset
=
5, length
=4)
B:
{ "
offset
":
5,
"
length
": 4 }
"""
"""
offset
:
int
offset
:
int
...
@@ -132,31 +131,6 @@ class PlaceholderRange:
...
@@ -132,31 +131,6 @@ class PlaceholderRange:
length
:
int
length
:
int
"""The length of the placeholder."""
"""The length of the placeholder."""
is_embed
:
Optional
[
torch
.
Tensor
]
=
None
"""
A boolean mask of shape `(length,)` indicating which positions
between `offset` and `offset + length` to assign embeddings to.
"""
def
get_num_embeds
(
self
)
->
int
:
if
self
.
is_embed
is
None
:
return
self
.
length
return
int
(
self
.
is_embed
.
sum
().
item
())
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
self
.
__class__
):
return
False
if
not
(
self
.
offset
,
self
.
length
)
==
(
other
.
offset
,
other
.
length
):
return
False
if
self
.
is_embed
is
None
:
return
other
.
is_embed
is
None
if
other
.
is_embed
is
None
:
return
self
.
is_embed
is
None
return
nested_tensors_equal
(
self
.
is_embed
,
other
.
is_embed
)
NestedTensors
=
Union
[
list
[
"NestedTensors"
],
list
[
torch
.
Tensor
],
torch
.
Tensor
,
NestedTensors
=
Union
[
list
[
"NestedTensors"
],
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]
tuple
[
torch
.
Tensor
,
...]]
...
...
vllm/multimodal/processing.py
View file @
af51d80f
...
@@ -108,46 +108,16 @@ class PromptUpdateDetails(Generic[_S]):
...
@@ -108,46 +108,16 @@ class PromptUpdateDetails(Generic[_S]):
full
:
_S
full
:
_S
"""The full content."""
"""The full content."""
is_embed
:
Optional
[
Callable
[[
"_BoundPromptSequence"
],
torch
.
Tensor
]]
=
None
features
:
_S
"""
"""
Given :attr:`full`, return a boolean mask of shape `(len(full),)`
The part of the content that corresponds to feature placeholders;
indicating which positions of `full` to assign embeddings to.
this will be replaced by the output of the vision encoder during model
inference.
`None` (default) means to assign embeddings to all positions of `full`.
The embeddings are obtained by calling
:class:`SupportsMultiModal.get_multimodal_embeddings`.
"""
"""
@
staticmethod
@
staticmethod
def
from_seq
(
seq
:
_S
)
->
"PromptUpdateDetails[_S]"
:
def
from_seq
(
seq
:
_S
)
->
"PromptUpdateDetails[_S]"
:
return
PromptUpdateDetails
(
full
=
seq
)
return
PromptUpdateDetails
(
full
=
seq
,
features
=
seq
)
@
staticmethod
def
select_text
(
seq
:
_S
,
embed_text
:
str
,
)
->
"PromptUpdateDetails[_S]"
:
def
is_embed
(
full
:
"_BoundPromptSequence"
)
->
torch
.
Tensor
:
embed_token_ids
=
encode_tokens
(
full
.
tokenizer
,
embed_text
)
return
torch
.
isin
(
torch
.
tensor
(
full
.
token_ids
),
torch
.
tensor
(
embed_token_ids
),
)
return
PromptUpdateDetails
(
full
=
seq
,
is_embed
=
is_embed
)
@
staticmethod
def
select_token_id
(
seq
:
_S
,
embed_token_id
:
int
,
)
->
"PromptUpdateDetails[_S]"
:
return
PromptUpdateDetails
(
full
=
seq
,
is_embed
=
lambda
f
:
torch
.
tensor
(
f
.
token_ids
)
==
embed_token_id
,
)
PromptUpdateInfo
=
Union
[
PromptSeq
,
PromptUpdateDetails
]
PromptUpdateInfo
=
Union
[
PromptSeq
,
PromptUpdateDetails
]
...
@@ -436,7 +406,7 @@ class _BoundPromptSequence:
...
@@ -436,7 +406,7 @@ class _BoundPromptSequence:
@
dataclass
@
dataclass
class
_BoundPromptContent
:
class
_BoundPromptContent
:
full
:
_BoundPromptSequence
full
:
_BoundPromptSequence
is_embed
:
Optional
[
Callable
[[
"
_BoundPromptSequence
"
],
torch
.
Tensor
]]
features
:
_BoundPromptSequence
@
dataclass
@
dataclass
...
@@ -496,8 +466,10 @@ class BoundPromptUpdate:
...
@@ -496,8 +466,10 @@ class BoundPromptUpdate:
bound_full
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
bound_full
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
content
.
full
)
content
.
full
)
bound_features
=
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
content
.
features
)
bound_content
=
_BoundPromptContent
(
full
=
bound_full
,
bound_content
=
_BoundPromptContent
(
full
=
bound_full
,
is_embed
=
content
.
is_embed
)
features
=
bound_features
)
if
cache_key
is
not
None
:
if
cache_key
is
not
None
:
self
.
_content_cache
[
cache_key
]
=
bound_content
self
.
_content_cache
[
cache_key
]
=
bound_content
...
@@ -633,19 +605,15 @@ class PlaceholderFeaturesInfo:
...
@@ -633,19 +605,15 @@ class PlaceholderFeaturesInfo:
item_idx
:
int
item_idx
:
int
start_idx
:
int
start_idx
:
int
tokens
:
list
[
int
]
tokens
:
list
[
int
]
is_embed
:
Optional
[
torch
.
Tensor
]
@
property
@
property
def
length
(
self
)
->
int
:
def
length
(
self
)
->
int
:
return
len
(
self
.
tokens
)
return
len
(
self
.
tokens
)
def
to_range
(
self
)
->
PlaceholderRange
:
def
to_range
(
self
)
->
PlaceholderRange
:
# TODO: Is it worth it to optimize this by stripping the
# leading and ending positions where `is_embed=False`?
return
PlaceholderRange
(
return
PlaceholderRange
(
offset
=
self
.
start_idx
,
offset
=
self
.
start_idx
,
length
=
self
.
length
,
length
=
self
.
length
,
is_embed
=
self
.
is_embed
,
)
)
...
@@ -838,17 +806,22 @@ def _iter_placeholders(
...
@@ -838,17 +806,22 @@ def _iter_placeholders(
continue
continue
if
prompt
[
start_idx
:
end_idx_full
]
==
content_tokens_full
:
if
prompt
[
start_idx
:
end_idx_full
]
==
content_tokens_full
:
content_is_embed
=
content
.
is_embed
content_tokens_feat
=
content
.
features
.
token_ids
if
content_is_embed
is
not
None
:
content_is_embed
=
content_is_embed
(
content
.
full
)
try
:
match
=
next
(
iter_token_matches
(
content_tokens_full
,
content_tokens_feat
))
yield
PlaceholderFeaturesInfo
(
yield
PlaceholderFeaturesInfo
(
modality
=
modality
,
modality
=
modality
,
item_idx
=
item_idx
,
item_idx
=
item_idx
,
start_idx
=
start_idx
,
start_idx
=
start_idx
+
match
.
start_idx
,
tokens
=
content_tokens_full
,
tokens
=
content_tokens_feat
,
is_embed
=
content_is_embed
,
)
)
except
StopIteration
:
raise
AssertionError
(
f
"
{
content_tokens_feat
=
}
should be a "
f
"subsequence of
{
content_tokens_full
=
}
"
)
from
None
# Exclude overlapping matches
# Exclude overlapping matches
start_idx
=
end_idx_full
start_idx
=
end_idx_full
...
...
vllm/multimodal/profiling.py
View file @
af51d80f
...
@@ -180,7 +180,7 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -180,7 +180,7 @@ class MultiModalProfiler(Generic[_I]):
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
total_placeholders_by_modality
=
{
total_placeholders_by_modality
=
{
modality
:
sum
(
item
.
get_num_embeds
()
for
item
in
placeholders
)
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
}
expected_placeholders_by_modality
=
{
expected_placeholders_by_modality
=
{
...
...
vllm/multimodal/utils.py
View file @
af51d80f
...
@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
...
@@ -340,7 +340,7 @@ def merge_and_sort_multimodal_metadata(
all_items
.
append
((
modality
,
placeholder
,
hash_value
))
all_items
.
append
((
modality
,
placeholder
,
hash_value
))
# Sort all items by offset
# Sort all items by offset
all_items
.
sort
(
key
=
lambda
x
:
x
[
1
]
.
offset
)
all_items
.
sort
(
key
=
lambda
x
:
x
[
1
]
[
'
offset
'
]
)
# Split into separate lists
# Split into separate lists
sorted_modalities
=
[
item
[
0
]
for
item
in
all_items
]
sorted_modalities
=
[
item
[
0
]
for
item
in
all_items
]
...
...
vllm/v1/core/kv_cache_utils.py
View file @
af51d80f
...
@@ -310,7 +310,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
...
@@ -310,7 +310,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
].
offset
+
mm_positions
[
-
1
].
length
<
start_token_idx
:
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
return
extra_keys
,
start_mm_idx
return
extra_keys
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
# Support start_mm_idx == -1 to indicate the last mm input.
...
@@ -321,8 +322,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
...
@@ -321,8 +322,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx
=
start_mm_idx
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
]
.
offset
offset
=
mm_positions
[
curr_mm_idx
]
[
"
offset
"
]
length
=
mm_positions
[
curr_mm_idx
]
.
length
length
=
mm_positions
[
curr_mm_idx
]
[
"
length
"
]
if
end_token_idx
>
offset
:
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
# This block has passed the current mm input.
...
...
vllm/v1/core/sched/scheduler.py
View file @
af51d80f
...
@@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
...
@@ -505,8 +505,8 @@ class Scheduler(SchedulerInterface):
assert
mm_positions
is
not
None
assert
mm_positions
is
not
None
assert
len
(
mm_positions
)
>
0
assert
len
(
mm_positions
)
>
0
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
.
offset
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
.
length
num_encoder_tokens
=
pos_info
[
"
length
"
]
# The encoder output is needed if the two ranges overlap:
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
...
@@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
...
@@ -596,8 +596,8 @@ class Scheduler(SchedulerInterface):
if
cached_encoder_input_ids
:
if
cached_encoder_input_ids
:
for
input_id
in
list
(
cached_encoder_input_ids
):
for
input_id
in
list
(
cached_encoder_input_ids
):
mm_positions
=
request
.
mm_positions
[
input_id
]
mm_positions
=
request
.
mm_positions
[
input_id
]
start_pos
=
mm_positions
.
offset
start_pos
=
mm_positions
[
"
offset
"
]
num_tokens
=
mm_positions
.
length
num_tokens
=
mm_positions
[
"
length
"
]
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
if
start_pos
+
num_tokens
<=
request
.
num_computed_tokens
:
# The encoder output is already processed and stored
# The encoder output is already processed and stored
# in the decoder's KV cache.
# in the decoder's KV cache.
...
...
vllm/v1/request.py
View file @
af51d80f
...
@@ -121,7 +121,7 @@ class Request:
...
@@ -121,7 +121,7 @@ class Request:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
def
get_num_encoder_tokens
(
self
,
input_id
:
int
)
->
int
:
assert
input_id
<
len
(
self
.
mm_positions
)
assert
input_id
<
len
(
self
.
mm_positions
)
num_tokens
=
self
.
mm_positions
[
input_id
]
.
length
num_tokens
=
self
.
mm_positions
[
input_id
]
[
"
length
"
]
return
num_tokens
return
num_tokens
@
property
@
property
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
af51d80f
...
@@ -19,8 +19,7 @@ from vllm.logger import init_logger
...
@@ -19,8 +19,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -44,8 +43,7 @@ from vllm.v1.utils import bind_kv_cache
...
@@ -44,8 +43,7 @@ from vllm.v1.utils import bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
(
gather_mm_placeholders
,
sanity_check_mm_encoder_outputs
,
from
.utils
import
sanity_check_mm_encoder_outputs
scatter_mm_placeholders
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
import
xgrammar
as
xgr
...
@@ -831,22 +829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -831,22 +829,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
return
metadata
return
metadata
def
_execute_
mm_
encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
def
_execute_encoder
(
self
,
scheduler_output
:
"SchedulerOutput"
):
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
scheduled_encoder_inputs
=
scheduler_output
.
scheduled_encoder_inputs
if
not
scheduled_encoder_inputs
:
if
not
scheduled_encoder_inputs
:
return
return
# Batch the multi-modal inputs.
# Batch the multi-modal inputs.
mm_inputs
=
list
[
MultiModalKwargs
]
()
mm_inputs
:
list
[
MultiModalKwargs
]
=
[]
req_i
ds_pos
=
list
[
tuple
[
str
,
int
,
PlaceholderRange
]]()
req_i
nput_ids
:
list
[
tuple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
for
input_id
,
pos_info
in
zip
(
for
input_id
in
encoder_input_ids
:
encoder_input_ids
,
req_state
.
mm_positions
,
):
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
mm_inputs
.
append
(
req_state
.
mm_inputs
[
input_id
])
req_i
ds_po
s
.
append
((
req_id
,
input_id
,
pos_info
))
req_i
nput_id
s
.
append
((
req_id
,
input_id
))
# Batch mm inputs as much as we can: if a request in the batch has
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# multiple modalities or a different modality than the previous one,
...
@@ -882,23 +877,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -882,23 +877,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
encoder_outputs
.
append
(
output
)
encoder_outputs
.
append
(
output
)
# Cache the encoder outputs.
# Cache the encoder outputs.
for
(
req_id
,
input_id
,
pos_info
),
output
in
zip
(
for
(
req_id
,
input_id
),
output
in
zip
(
req_input_ids
,
encoder_outputs
):
req_ids_pos
,
encoder_outputs
,
):
if
req_id
not
in
self
.
encoder_cache
:
if
req_id
not
in
self
.
encoder_cache
:
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
]
=
{}
self
.
encoder_cache
[
req_id
][
input_id
]
=
output
self
.
encoder_cache
[
req_id
][
input_id
]
=
scatter_mm_placeholders
(
def
_gather_encoder_outputs
(
output
,
is_embed
=
pos_info
.
is_embed
,
)
def
_gather_mm_embeddings
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
mm_embed
s
:
list
[
torch
.
Tensor
]
=
[]
encoder_output
s
:
list
[
torch
.
Tensor
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
for
req_id
in
self
.
input_batch
.
req_ids
:
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
...
@@ -906,8 +894,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -906,8 +894,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens
=
req_state
.
num_computed_tokens
num_computed_tokens
=
req_state
.
num_computed_tokens
mm_positions
=
req_state
.
mm_positions
mm_positions
=
req_state
.
mm_positions
for
i
,
pos_info
in
enumerate
(
mm_positions
):
for
i
,
pos_info
in
enumerate
(
mm_positions
):
start_pos
=
pos_info
.
offset
start_pos
=
pos_info
[
"
offset
"
]
num_encoder_tokens
=
pos_info
.
length
num_encoder_tokens
=
pos_info
[
"
length
"
]
# The encoder output is needed if the two ranges overlap:
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# [num_computed_tokens,
...
@@ -929,16 +917,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -929,16 +917,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
req_id
in
self
.
encoder_cache
assert
req_id
in
self
.
encoder_cache
assert
i
in
self
.
encoder_cache
[
req_id
]
assert
i
in
self
.
encoder_cache
[
req_id
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_output
=
self
.
encoder_cache
[
req_id
][
i
]
encoder_outputs
.
append
(
encoder_output
[
start_idx
:
end_idx
])
if
(
is_embed
:
=
pos_info
.
is_embed
)
is
not
None
:
return
encoder_outputs
is_embed
=
is_embed
[
start_idx
:
end_idx
]
mm_embeds_item
=
gather_mm_placeholders
(
encoder_output
[
start_idx
:
end_idx
],
is_embed
=
is_embed
,
)
mm_embeds
.
append
(
mm_embeds_item
)
return
mm_embeds
def
get_model
(
self
)
->
nn
.
Module
:
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
return
self
.
model
...
@@ -1003,10 +983,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1003,10 +983,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
is_multimodal_model
:
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
# Run the multimodal encoder if any.
self
.
_execute_
mm_
encoder
(
scheduler_output
)
self
.
_execute_encoder
(
scheduler_output
)
mm_embed
s
=
self
.
_gather_
mm_embedding
s
(
scheduler_output
)
encoder_output
s
=
self
.
_gather_
encoder_output
s
(
scheduler_output
)
else
:
else
:
mm_embed
s
=
[]
encoder_output
s
=
[]
# Prepare the decoder inputs.
# Prepare the decoder inputs.
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
attn_metadata
,
logits_indices
,
spec_decode_metadata
=
(
...
@@ -1028,9 +1008,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1028,9 +1008,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
# as input to the multimodal model, even when the input is text.
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
input_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
if
mm_embed
s
:
if
encoder_output
s
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
,
mm_embed
s
)
input_ids
,
encoder_output
s
)
else
:
else
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
# TODO(woosuk): Avoid the copy. Optimize.
# TODO(woosuk): Avoid the copy. Optimize.
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment