Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a11f3265
Unverified
Commit
a11f3265
authored
Dec 08, 2024
by
Roger Wang
Committed by
GitHub
Dec 08, 2024
Browse files
[V1] Initial support of multimodal models for V1 re-arch (#10699)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
fd57d2b5
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
284 additions
and
69 deletions
+284
-69
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-8
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+5
-0
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+57
-11
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+63
-9
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+92
-29
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+27
-1
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+2
-1
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+6
-4
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+21
-3
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+1
-1
No files found.
vllm/engine/arg_utils.py
View file @
a11f3265
...
@@ -1050,9 +1050,12 @@ class EngineArgs:
...
@@ -1050,9 +1050,12 @@ class EngineArgs:
# long context (> 32K) models. This is to avoid OOM errors in the
# long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase.
# initial memory profiling phase.
# Chunked prefill is currently disabled for multimodal models by
# For multimodal models, chunked prefill is disabled by default in
# default.
# V0, but enabled by design in V1
if
use_long_context
and
not
model_config
.
is_multimodal_model
:
if
model_config
.
is_multimodal_model
:
self
.
enable_chunked_prefill
=
bool
(
envs
.
VLLM_USE_V1
)
elif
use_long_context
:
is_gpu
=
device_config
.
device_type
==
"cuda"
is_gpu
=
device_config
.
device_type
==
"cuda"
use_sliding_window
=
(
model_config
.
get_sliding_window
()
use_sliding_window
=
(
model_config
.
get_sliding_window
()
is
not
None
)
is
not
None
)
...
@@ -1241,12 +1244,9 @@ class EngineArgs:
...
@@ -1241,12 +1244,9 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1.
Override the EngineConfig's configs based on the usage context for V1.
"""
"""
assert
envs
.
VLLM_USE_V1
,
"V1 is not enabled"
assert
envs
.
VLLM_USE_V1
,
"V1 is not enabled"
# TODO (ywang96): Enable APC by default when VLM supports it.
if
engine_config
.
model_config
.
is_multimodal_model
:
if
engine_config
.
model_config
.
is_multimodal_model
:
logger
.
warning
(
# TODO (ywang96): Enable APC by default when VLM supports it.
"Prefix caching is currently not supported for multimodal "
assert
not
engine_config
.
cache_config
.
enable_prefix_caching
"models and has been disabled."
)
engine_config
.
cache_config
.
enable_prefix_caching
=
False
@
dataclass
@
dataclass
...
...
vllm/model_executor/models/interfaces.py
View file @
a11f3265
...
@@ -36,6 +36,11 @@ class SupportsMultiModal(Protocol):
...
@@ -36,6 +36,11 @@ class SupportsMultiModal(Protocol):
"""
"""
Returns multimodal embeddings generated from multimodal kwargs
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
to be merged with text embeddings.
The output embeddings must be in one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input image.
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
"""
...
...
...
...
vllm/model_executor/models/internvl.py
View file @
a11f3265
...
@@ -26,7 +26,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
...
@@ -26,7 +26,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel
)
InternVisionPatchModel
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
...
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
...
@@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict):
Shape:
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
"""
patches_per_image
:
List
[
int
]
"""
List of number of total patches for each image in the batch.
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
NestedTensors
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
@@ -349,10 +355,32 @@ class InternVLInputPipeline:
...
@@ -349,10 +355,32 @@ class InternVLInputPipeline:
new_prompt
=
self
.
_expand_image_prompt
(
prompt
,
image_feature_sizes
,
new_prompt
=
self
.
_expand_image_prompt
(
prompt
,
image_feature_sizes
,
num_patches
)
num_patches
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
img_context_token_id
=
tokenizer
.
encode
(
self
.
img_context_token
,
add_special_tokens
=
False
)
assert
len
(
img_context_token_id
)
==
1
,
\
(
f
"Invalid image token '
{
self
.
img_context_token
}
': A valid image "
f
"token encodes to a single token ID, got
{
img_context_token_id
}
."
)
img_context_token_id
=
img_context_token_id
[
0
]
# Get precise tracking of placeholder positions
token_idx
=
image_idx
=
0
placeholder_ranges
=
[]
while
token_idx
<
len
(
new_prompt_token_ids
):
if
new_prompt_token_ids
[
token_idx
]
==
img_context_token_id
:
curr_image_featue_size
=
image_feature_sizes
[
image_idx
]
placeholder_ranges
.
append
(
PlaceholderRange
(
offset
=
token_idx
,
length
=
curr_image_featue_size
))
image_idx
+=
1
token_idx
+=
curr_image_featue_size
else
:
token_idx
+=
1
return
token_inputs
(
prompt
=
prompt
,
return
token_inputs
(
prompt_token_ids
=
new_prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
prompt_token_ids
=
new_prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
placeholder_ranges
})
def
input_mapper
(
def
input_mapper
(
self
,
self
,
...
@@ -614,26 +642,46 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -614,26 +642,46 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
patches_per_image
=
[]
for
request_pixel_values
in
pixel_values
:
for
image_pixel_values
in
request_pixel_values
:
patches_per_image
.
append
(
image_pixel_values
.
shape
[
0
])
# We need to flatten (B, N, P) to (B*N*P),
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
# so we call flatten_bn twice.
return
InternVLImagePixelInputs
(
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
flatten_bn
(
pixel_values
),
concat
=
True
)),
flatten_bn
(
flatten_bn
(
pixel_values
),
concat
=
True
)),
)
patches_per_image
=
patches_per_image
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
InternVLImageInputs
,
image_input
:
InternVLImageInputs
,
)
->
torch
.
Tensor
:
)
->
Tuple
[
torch
.
Tensor
]
:
if
image_input
[
"type"
]
==
"image_embeds"
:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
image_input
[
"data"
]
return
image_input
[
"data"
]
assert
self
.
vision_model
is
not
None
assert
self
.
vision_model
is
not
None
image_embeds
=
self
.
extract_feature
(
image_input
[
"data"
])
image_embeds
=
self
.
extract_feature
(
image_input
[
"data"
])
patches_per_image
=
image_input
[
"patches_per_image"
]
if
len
(
patches_per_image
)
==
1
:
image_embeds
=
image_embeds
.
unsqueeze
(
0
)
return
image_embeds
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size
=
image_embeds
.
shape
[
1
]
image_embeds
=
image_embeds
.
view
(
-
1
,
self
.
config
.
text_config
.
hidden_size
)
image_feature_sizes
=
[
num_patches
*
feature_size
for
num_patches
in
patches_per_image
]
image_embeds
=
image_embeds
.
split
(
image_feature_sizes
)
return
image_embeds
return
image_embeds
def
_set_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_set_visual_token_mask
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -696,13 +744,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -696,13 +744,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
"inputs_embeds"
:
inputs_embeds
,
"inputs_embeds"
:
inputs_embeds
,
}
}
# Only required if the model is mono-architecture
if
self
.
visual_token_mask
is
not
None
:
if
self
.
visual_token_mask
is
not
None
:
# overwrite visual_token_mask and img_context_token_id back to None,
# so that this doesn't need to depend on encoder output
forward_kwargs
.
update
(
forward_kwargs
.
update
(
{
"visual_token_mask"
:
self
.
visual_token_mask
})
{
"visual_token_mask"
:
self
.
visual_token_mask
})
self
.
visual_token_mask
=
None
self
.
visual_token_mask
=
None
self
.
img_context_token_id
=
None
hidden_states
=
self
.
language_model
.
model
(
**
forward_kwargs
)
hidden_states
=
self
.
language_model
.
model
(
**
forward_kwargs
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/molmo.py
View file @
a11f3265
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
NestedTensors
,
PlaceholderRange
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
SequenceData
)
...
@@ -46,12 +46,16 @@ from vllm.transformers_utils.processor import get_processor
...
@@ -46,12 +46,16 @@ from vllm.transformers_utils.processor import get_processor
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
,
merge_multimodal_embeddings
)
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
NUM_PREFIX_TOKENS
=
1
NUM_PREFIX_TOKENS
=
1
ADDITIONAL_VOCAB_SIZE
=
128
ADDITIONAL_VOCAB_SIZE
=
128
DEFAULT_IMAGE_PATCH_TOKEN_ID
=
152066
DEFAULT_IM_START_TOKEN_ID
=
152067
DEFAULT_IM_END_TOKEN_ID
=
152064
DEFAULT_IM_COL_TOKEN_ID
=
152065
class
MolmoImageInputs
(
TypedDict
):
class
MolmoImageInputs
(
TypedDict
):
...
@@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict):
...
@@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict):
`(batch_size, num_crops, num_patch)`
`(batch_size, num_crops, num_patch)`
"""
"""
image_start_end
:
Tuple
[
int
,
int
]
"""Starting and ending index of placeholder
tokens
"""
@
dataclass
@
dataclass
class
VisionBackboneConfig
:
class
VisionBackboneConfig
:
...
@@ -918,6 +927,8 @@ def image_input_mapper_for_molmo(
...
@@ -918,6 +927,8 @@ def image_input_mapper_for_molmo(
ctx
:
InputContext
,
ctx
:
InputContext
,
data
:
object
,
data
:
object
,
):
):
if
isinstance
(
data
,
list
):
data
=
data
[
0
]
return
MultiModalKwargs
(
data
)
return
MultiModalKwargs
(
data
)
...
@@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
...
@@ -967,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
if
"image_masks"
in
out
:
if
"image_masks"
in
out
:
dummy_imgdata
[
"image_masks"
]
=
out
[
"image_masks"
]
dummy_imgdata
[
"image_masks"
]
=
out
[
"image_masks"
]
dummy_imgdata
[
"seq_len"
]
=
torch
.
tensor
(
seq_len
,
dtype
=
torch
.
long
)
dummy_imgdata
[
"seq_len"
]
=
torch
.
tensor
(
seq_len
,
dtype
=
torch
.
long
)
return
DummyData
(
dummy_seqdata
,
{
"image"
:
dummy_imgdata
})
size
=
0
offset
=
-
1
for
i
in
range
(
len
(
token_ids
)):
if
token_ids
[
i
]
in
(
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
DEFAULT_IM_START_TOKEN_ID
,
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
):
if
offset
<
0
:
offset
=
i
size
+=
1
dummy_imgdata
[
"image_start_end"
]
=
(
offset
,
offset
+
size
)
return
DummyData
(
seq_data
=
dummy_seqdata
,
multi_modal_data
=
{
"image"
:
dummy_imgdata
},
multi_modal_placeholders
=
{
"image"
:
[
PlaceholderRange
(
offset
=
offset
,
length
=
size
)]
})
def
pad_images
(
def
pad_images
(
...
@@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
...
@@ -1055,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
if
image_masks
is
not
None
:
if
image_masks
is
not
None
:
image_data
[
"image_masks"
]
=
image_masks
image_data
[
"image_masks"
]
=
image_masks
image_data
[
"seq_len"
]
=
torch
.
tensor
(
len
(
out
[
"input_ids"
]),
new_prompt_token_ids
=
out
[
"input_ids"
].
tolist
()
image_data
[
"seq_len"
]
=
torch
.
tensor
(
len
(
new_prompt_token_ids
),
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
multi_modal_data
=
dict
(
image
=
image_data
)
multi_modal_data
=
dict
(
image
=
image_data
)
size
=
0
offset
=
-
1
for
i
in
range
(
len
(
new_prompt_token_ids
)):
if
new_prompt_token_ids
[
i
]
in
(
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
DEFAULT_IM_START_TOKEN_ID
,
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
):
if
offset
<
0
:
offset
=
i
size
+=
1
image_data
[
"image_start_end"
]
=
(
offset
,
offset
+
size
)
prompt
=
inputs
.
get
(
"prompt"
)
prompt
=
inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
out
[
"input
_ids
"
]
)
prompt
=
tokenizer
.
decode
(
new_prompt_token
_ids
)
return
token_inputs
(
return
token_inputs
(
prompt_token_ids
=
out
[
"input
_ids
"
]
,
prompt_token_ids
=
new_prompt_token
_ids
,
prompt
=
prompt
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
,
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
[
PlaceholderRange
(
offset
=
offset
,
length
=
size
)]
},
)
)
...
@@ -1113,6 +1154,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1113,6 +1154,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
->
Optional
[
MolmoImageInputs
]:
)
->
Optional
[
MolmoImageInputs
]:
images
=
kwargs
.
pop
(
"images"
,
None
)
images
=
kwargs
.
pop
(
"images"
,
None
)
image_masks
=
kwargs
.
pop
(
"image_masks"
,
None
)
image_masks
=
kwargs
.
pop
(
"image_masks"
,
None
)
image_start_end
=
kwargs
.
pop
(
"image_start_end"
,
None
)
if
images
is
None
:
if
images
is
None
:
return
None
return
None
...
@@ -1130,6 +1172,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1130,6 +1172,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input_idx
=
image_input_idx
,
image_input_idx
=
image_input_idx
,
seq_len
=
seq_len
,
seq_len
=
seq_len
,
image_masks
=
image_masks
,
image_masks
=
image_masks
,
image_start_end
=
image_start_end
,
)
)
def
_process_image_input
(
def
_process_image_input
(
...
@@ -1178,9 +1221,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1178,9 +1221,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Note: In this original implementation from AI2, the final
# Note: In this original implementation from AI2, the final
# vision_embeddings will always be the same length
# vision_embeddings will always be the same length
# of input embedddings, which is not very efficient.
# of input embeddings.
# TODO(ywang96): see if this can be optimized.
vision_embeddings
=
torch
.
einsum
(
'nd,nm->md'
,
image_features
,
mat
)
vision_embeddings
=
torch
.
einsum
(
'nd,nm->md'
,
image_features
,
mat
)
# Split by the sizes of the input sequences. For each full embedding,
# extract the actual vision embeddings to be merged.
vision_embeddings
=
list
(
vision_embeddings
.
split
(
seq_len
.
tolist
()))
for
i
in
range
(
len
(
vision_embeddings
)):
start
,
end
=
image_input
[
'image_start_end'
][
i
]
vision_embeddings
[
i
]
=
vision_embeddings
[
i
][
start
:
end
]
return
vision_embeddings
return
vision_embeddings
def
get_input_embeddings
(
def
get_input_embeddings
(
...
@@ -1190,7 +1240,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1190,7 +1240,11 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
inputs_embeds
+
multimodal_embeddings
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
DEFAULT_IMAGE_PATCH_TOKEN_ID
,
DEFAULT_IM_START_TOKEN_ID
,
DEFAULT_IM_END_TOKEN_ID
,
DEFAULT_IM_COL_TOKEN_ID
])
return
inputs_embeds
return
inputs_embeds
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/pixtral.py
View file @
a11f3265
...
@@ -48,6 +48,9 @@ try:
...
@@ -48,6 +48,9 @@ try:
except
ImportError
:
except
ImportError
:
USE_XFORMERS_OPS
=
False
USE_XFORMERS_OPS
=
False
PIXTRAL_IMAGE_BREAK_ID
=
12
PIXTRAL_IMAGE_END_ID
=
13
def
get_max_pixtral_image_tokens
(
ctx
:
InputContext
):
def
get_max_pixtral_image_tokens
(
ctx
:
InputContext
):
tokenizer
=
cached_get_tokenizer
(
tokenizer
=
cached_get_tokenizer
(
...
@@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
...
@@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
patch_size
=
mm_encoder
.
mm_config
.
image_patch_size
image_token_id
=
mm_encoder
.
special_ids
.
img
image_token_id
=
mm_encoder
.
special_ids
.
img
mm_config
=
ctx
.
model_config
.
multimodal_config
mm_config
=
ctx
.
model_config
.
multimodal_config
...
@@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
...
@@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
size
=
256
size
=
256
image
=
Image
.
new
(
"RGB"
,
(
size
,
size
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
size
,
size
),
color
=
0
)
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
encoding
=
tokenizer
.
instruct
.
mm_encoder
(
ImageChunk
(
image
=
image
)
)
image_feature_size
=
len
(
encoding
.
tokens
)
num_image_tokens
=
image_feature_size
*
num_images
num_image_tokens
=
image_feature_size
*
num_images
seq_data
=
SequenceData
.
from_prompt_token_counts
(
seq_data
=
SequenceData
.
from_prompt_token_counts
(
(
image_token_id
,
num_image_tokens
),
(
image_token_id
,
num_image_tokens
),
...
@@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext,
...
@@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext,
Args:
Args:
ctx: Context of the loaded model.
ctx: Context of the loaded model.
data: data potentially containing image
/image embedding
s to be
mapp
ed
data: data potentially containing
PIL
images to be
process
ed
to pixel_values in .forward() for a visual QWenLMHeadModel model
.
and mapped to `images`
.
Returns:
Returns:
MultiModalKwargs containing the stacked normalized images tensor or
MultiModalKwargs containing the stacked normalized images tensor or
image embeddings.
image embeddings.
"""
"""
# Early exit if we have provided an image to a language only Qwen model
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
)
model_config
.
tokenizer
,
tokenizer_mode
=
model_config
.
tokenizer_mode
)
...
@@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext,
...
@@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext,
data_list
=
data
if
isinstance
(
data
,
list
)
else
[
data
]
data_list
=
data
if
isinstance
(
data
,
list
)
else
[
data
]
images
=
[]
images
=
[]
image_tokens_list
=
[]
for
image_data
in
data_list
:
for
image_data
in
data_list
:
image
=
ImageChunk
(
image
=
image_data
)
image
=
ImageChunk
(
image
=
image_data
)
encoding
=
tokenizer
.
instruct
.
mm_encoder
(
image
)
encoding
=
tokenizer
.
instruct
.
mm_encoder
(
image
)
image
=
torch
.
from_numpy
(
encoding
.
image
).
to
(
device
=
"cuda"
,
image
=
torch
.
from_numpy
(
encoding
.
image
).
to
(
device
=
"cuda"
,
dtype
=
torch
.
float16
)
dtype
=
torch
.
float16
)
images
.
append
(
image
)
images
.
append
(
image
)
image_tokens_list
.
append
(
encoding
.
tokens
)
return
MultiModalKwargs
({
"images"
:
images
})
image_tokens
=
torch
.
tensor
([
token_id
for
image_tokens
in
image_tokens_list
for
token_id
in
image_tokens
])
return
MultiModalKwargs
({
"images"
:
images
,
"image_tokens"
:
image_tokens
})
def
input_processor_for_pixtral
(
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
):
def
input_processor_for_pixtral
(
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
not
None
and
"image"
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
tokenizer
=
cached_get_tokenizer
(
return
inputs
ctx
.
model_config
.
tokenizer
,
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
image_token_id
=
mm_encoder
.
special_ids
.
img
if
image_token_id
not
in
inputs
[
'prompt_token_ids'
]:
prompt_token_ids
=
inputs
.
get
(
"prompt_token_ids"
)
raise
ValueError
(
prompt
=
inputs
.
get
(
"prompt"
)
f
"You've passed
{
inputs
=
}
without
{
image_token_id
=
}
"
tokenizer
=
cached_get_tokenizer
(
" Make sure to process your input via mistral_common's"
ctx
.
model_config
.
tokenizer
,
" tokenizer or pass a chat completion request. For more"
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
return
inputs
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
image_token_id
=
mm_encoder
.
special_ids
.
img
image_break_id
=
mm_encoder
.
special_ids
.
img_break
image_end_id
=
mm_encoder
.
special_ids
.
img_end
if
image_token_id
not
in
inputs
[
'prompt_token_ids'
]:
raise
ValueError
(
f
"You've passed
{
inputs
=
}
without
{
image_token_id
=
}
"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
# Get precise tracking of placeholder positions
placeholder_ranges
=
[]
curr_offset
=
-
1
curr_length
=
0
for
i
in
range
(
len
(
prompt_token_ids
)):
if
prompt_token_ids
[
i
]
in
(
image_token_id
,
image_break_id
):
if
curr_offset
<
0
:
curr_offset
=
i
curr_length
+=
1
elif
prompt_token_ids
[
i
]
==
image_end_id
:
curr_length
+=
1
placeholder_ranges
.
append
(
PlaceholderRange
(
offset
=
curr_offset
,
length
=
curr_length
))
curr_offset
=
-
1
curr_length
=
0
else
:
pass
return
token_inputs
(
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
multi_modal_placeholders
=
{
"image"
:
placeholder_ranges
})
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_pixtral
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_pixtral
)
...
@@ -192,11 +225,29 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -192,11 +225,29 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return
get_sampler
()
return
get_sampler
()
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
,
image_tokens
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
# NOTE: We patch the outputs of the vision encoder with embeddings
# from `[IMG_BREAK]` and `[IMG_END]` tokens.
image_embeds
=
self
.
language_model
.
get_input_embeddings
(
image_tokens
)
image_token_mask
=
image_tokens
==
self
.
vision_args
.
image_token_id
image_embeds
[
image_token_mask
]
=
vision_embeddings
# NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token.
split_indices
=
torch
.
where
(
image_tokens
==
PIXTRAL_IMAGE_END_ID
)[
0
]
+
1
if
len
(
split_indices
)
<=
1
:
# Do not split, return as tensor of shape [1, fs, hs]
return
image_embeds
.
unsqueeze
(
0
)
image_embeds
=
image_embeds
.
tensor_split
(
split_indices
.
cpu
())
return
image_embeds
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
...
@@ -206,8 +257,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -206,8 +257,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
[
self
.
vision_args
.
image_token_id
)
self
.
vision_args
.
image_token_id
,
PIXTRAL_IMAGE_END_ID
,
PIXTRAL_IMAGE_BREAK_ID
])
return
inputs_embeds
return
inputs_embeds
def
forward
(
def
forward
(
...
@@ -245,10 +298,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -245,10 +298,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
self
,
images
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
images
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
None
torch
.
Tensor
]]
=
None
,
image_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
List
[
torch
.
Tensor
]]:
)
->
Optional
[
List
[
torch
.
Tensor
]]:
if
images
is
None
:
if
images
is
None
:
return
None
return
None
,
None
if
isinstance
(
images
,
torch
.
Tensor
):
if
isinstance
(
images
,
torch
.
Tensor
):
# if passed as batch take all images
# if passed as batch take all images
...
@@ -267,7 +321,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -267,7 +321,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
images
=
flatten_images
images
=
flatten_images
return
images
if
isinstance
(
image_tokens
,
torch
.
Tensor
):
# image_tokens are batched
image_tokens
=
image_tokens
.
flatten
()
elif
isinstance
(
image_tokens
,
list
):
# image_tokens are of different lengths thus passed as a list
image_tokens
=
torch
.
cat
(
image_tokens
)
assert
image_tokens
.
dim
()
==
1
return
images
,
image_tokens
def
_process_image_input
(
self
,
def
_process_image_input
(
self
,
image_input
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
image_input
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
...
vllm/model_executor/models/utils.py
View file @
a11f3265
...
@@ -409,16 +409,42 @@ def merge_multimodal_embeddings(
...
@@ -409,16 +409,42 @@ def merge_multimodal_embeddings(
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
multimodal_embeddings
:
NestedTensors
,
placeholder_token_id
:
int
,
placeholder_token_id
:
Union
[
int
,
List
[
int
]]
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
positions in ``inputs_embeds`` corresponding to placeholder tokens in
``input_ids``.
``input_ids``.
``placeholder_token_id`` can be a list of token ids (e.g., token ids
of img_start, img_break, and img_end tokens) when needed: This means
the order of these tokens in the ``input_ids`` MUST MATCH the order of
their embeddings in ``multimodal_embeddings`` since we need to
slice-merge instead of individually scattering.
For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
- T is text token
- S is image start token
- I is image embedding token
- B is image break token
- E is image end token.
Then the image embeddings (that correspond to I's) from vision encoder
must be padded with embeddings of S, B, and E in the same order of
input_ids for a correct embedding merge.
Note:
Note:
This updates ``inputs_embeds`` in place.
This updates ``inputs_embeds`` in place.
"""
"""
if
isinstance
(
placeholder_token_id
,
list
):
placeholder_token_id
=
torch
.
tensor
(
placeholder_token_id
,
device
=
input_ids
.
device
)
return
_merge_multimodal_embeddings
(
inputs_embeds
,
torch
.
isin
(
input_ids
,
placeholder_token_id
),
multimodal_embeddings
,
)
return
_merge_multimodal_embeddings
(
return
_merge_multimodal_embeddings
(
inputs_embeds
,
inputs_embeds
,
(
input_ids
==
placeholder_token_id
),
(
input_ids
==
placeholder_token_id
),
...
...
vllm/multimodal/inputs.py
View file @
a11f3265
...
@@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict):
...
@@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict):
"""The length of the placeholder."""
"""The length of the placeholder."""
NestedTensors
=
Union
[
List
[
"NestedTensors"
],
List
[
torch
.
Tensor
],
torch
.
Tensor
]
NestedTensors
=
Union
[
List
[
"NestedTensors"
],
List
[
torch
.
Tensor
],
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
...]]
"""
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
"""
...
...
vllm/multimodal/utils.py
View file @
a11f3265
...
@@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens(
return
new_prompt
,
new_token_ids
,
placeholder_ranges
return
new_prompt
,
new_token_ids
,
placeholder_ranges
def
consecutive_placeholder_ranges
(
num_items
:
int
,
def
consecutive_placeholder_ranges
(
item_size
:
int
)
->
List
[
PlaceholderRange
]:
num_items
:
int
,
item_size
:
int
,
initial_offset
:
int
=
0
)
->
List
[
PlaceholderRange
]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
return
[
return
[
PlaceholderRange
(
offset
=
i
*
item_size
,
length
=
item_size
)
PlaceholderRange
(
offset
=
initial_offset
+
i
*
item_size
,
for
i
in
range
(
num_items
)
length
=
item_size
)
for
i
in
range
(
num_items
)
]
]
vllm/v1/core/scheduler.py
View file @
a11f3265
...
@@ -73,12 +73,12 @@ class Scheduler:
...
@@ -73,12 +73,12 @@ class Scheduler:
# has the Transformer architecture (e.g., ViT).
# has the Transformer architecture (e.g., ViT).
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# FIXME(woosuk): Below are placeholder values. We need to calculate the
# actual values from the configurations.
# actual values from the configurations.
self
.
max_num_encoder_input_tokens
=
2048
self
.
max_num_encoder_input_tokens
=
16384
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# NOTE(woosuk): For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized and used, regardless of
# the encoder cache will not be initialized and used, regardless of
# the cache size. This is because the memory space for the encoder cache
# the cache size. This is because the memory space for the encoder cache
# is preallocated in the profiling run.
# is preallocated in the profiling run.
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
2048
)
self
.
encoder_cache_manager
=
EncoderCacheManager
(
cache_size
=
16384
)
def
schedule
(
self
)
->
"SchedulerOutput"
:
def
schedule
(
self
)
->
"SchedulerOutput"
:
# NOTE(woosuk) on the scheduling algorithm:
# NOTE(woosuk) on the scheduling algorithm:
...
...
vllm/v1/engine/llm_engine.py
View file @
a11f3265
from
typing
import
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
typing
import
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
typing_extensions
import
TypeVar
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
...
@@ -12,7 +14,8 @@ from vllm.outputs import RequestOutput
...
@@ -12,7 +14,8 @@ from vllm.outputs import RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.detokenizer
import
Detokenizer
...
@@ -21,6 +24,8 @@ from vllm.v1.executor.gpu_executor import GPUExecutor
...
@@ -21,6 +24,8 @@ from vllm.v1.executor.gpu_executor import GPUExecutor
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
class
LLMEngine
:
class
LLMEngine
:
"""Legacy LLMEngine for backwards compatibility."""
"""Legacy LLMEngine for backwards compatibility."""
...
@@ -169,5 +174,18 @@ class LLMEngine:
...
@@ -169,5 +174,18 @@ class LLMEngine:
def
stop_profile
(
self
):
def
stop_profile
(
self
):
self
.
engine_core
.
profile
(
False
)
self
.
engine_core
.
profile
(
False
)
def
get_tokenizer_group
(
self
,
group_type
):
def
get_tokenizer_group
(
pass
self
,
group_type
:
Type
[
_G
]
=
BaseTokenizerGroup
,
)
->
_G
:
tokenizer_group
=
self
.
tokenizer
if
tokenizer_group
is
None
:
raise
ValueError
(
"Unable to get tokenizer because "
"skip_tokenizer_init is True"
)
if
not
isinstance
(
tokenizer_group
,
group_type
):
raise
TypeError
(
"Invalid type of tokenizer group. "
f
"Expected type:
{
group_type
}
, but "
f
"found type:
{
type
(
tokenizer_group
)
}
"
)
return
tokenizer_group
vllm/v1/engine/mm_input_mapper.py
View file @
a11f3265
...
@@ -33,7 +33,7 @@ class MMInputMapper:
...
@@ -33,7 +33,7 @@ class MMInputMapper:
num_images
=
len
(
image_inputs
)
num_images
=
len
(
image_inputs
)
for
i
in
range
(
num_images
):
for
i
in
range
(
num_images
):
mm_input
=
self
.
multi_modal_input_mapper
(
mm_input
=
self
.
multi_modal_input_mapper
(
{
"image"
:
[
image_inputs
[
i
]
]
},
{
"image"
:
image_inputs
[
i
]},
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
)
mm_inputs
.
append
(
mm_input
)
mm_inputs
.
append
(
mm_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment