Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c38ee70
Unverified
Commit
8c38ee70
authored
Jan 03, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 02, 2025
Browse files
[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b6087a6b
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
605 additions
and
551 deletions
+605
-551
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
...ly/vision_language/mm_processor_kwargs/test_llava_next.py
+0
-70
tests/multimodal/test_mapper.py
tests/multimodal/test_mapper.py
+0
-118
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+97
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
+1
-3
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+25
-0
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+3
-3
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+215
-119
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+112
-209
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+14
-10
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+52
-14
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+25
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+1
-1
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+52
-0
vllm/multimodal/parse.py
vllm/multimodal/parse.py
+8
-4
No files found.
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
deleted
100644 → 0
View file @
b6087a6b
import
pytest
from
vllm.inputs
import
InputContext
from
....utils
import
build_model_context
@pytest.fixture()
def get_max_llava_next_image_tokens():
    """Fixture that hands back ``get_max_llava_next_image_tokens``.

    The import is performed inside the fixture body, so the llava_next
    model module is only loaded when a test actually requests the fixture.
    """
    import vllm.model_executor.models.llava_next as llava_next_module

    return llava_next_module.get_max_llava_next_image_tokens
@pytest.fixture()
def dummy_data_for_llava_next():
    """Fixture that hands back ``dummy_data_for_llava_next``.

    The import is performed inside the fixture body, so the llava_next
    model module is only loaded when a test actually requests the fixture.
    """
    import vllm.model_executor.models.llava_next as llava_next_module

    return llava_next_module.dummy_data_for_llava_next
@pytest.mark.parametrize(
    "gridpoints,expected_max_tokens",
    [
        ([[336, 336]], 1176),
        (
            [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
            2928,
        ),
    ],
)
def test_get_max_llava_next_image_tokens(gridpoints, expected_max_tokens,
                                         get_max_llava_next_image_tokens):
    """Check the max image token count reported for a set of pinpoints."""
    model_ctx = build_model_context(
        model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Override the config's image_grid_pinpoints, then compute the
    # resulting max tokens from that modified config.
    hf_config = model_ctx.model_config.hf_config
    hf_config.image_grid_pinpoints = gridpoints

    input_ctx = InputContext(model_ctx.model_config)
    actual_max_tokens = get_max_llava_next_image_tokens(input_ctx)

    assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize(
    "gridpoints,expected_size",
    [
        # One point; it has to be the largest
        ([[336, 336]], (336, 336)),
        # Default for most llava next models; the 2x2 tile is the largest
        (
            [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
            (672, 672),
        ),
        # If two rectangular gridpoints are the same, the more vertical
        # one has the higher feature count due to newline features
        ([[336, 672], [672, 336]], (672, 336)),
    ],
)
def test_dummy_data_for_llava_next_feature_size(dummy_data_for_llava_next,
                                                gridpoints, expected_size):
    """Check the dummy image size and dummy sequence length."""
    model_ctx = build_model_context(
        model_name="llava-hf/llava-v1.6-mistral-7b-hf")

    # Update the config image_grid_pinpoints
    hf_config = model_ctx.model_config.hf_config
    hf_config.image_grid_pinpoints = gridpoints

    # bigger than the max feature size for any image
    seq_len = 5000
    dummy = dummy_data_for_llava_next(
        model_ctx,
        seq_len=seq_len,
        mm_counts={"image": 1},
    )
    dummy_seq = dummy.seq_data
    dummy_mm = dummy.multi_modal_data

    # The dummy data dims should match the gridpoint with the biggest
    # feat size; expected_size is indexed as [0]=height, [1]=width.
    dummy_image = dummy_mm["image"]
    assert dummy_image.height == expected_size[0]
    dummy_image = dummy_mm["image"]
    assert dummy_image.width == expected_size[1]

    num_tokens = len(dummy_seq.get_token_ids())
    assert num_tokens >= seq_len
tests/multimodal/test_mapper.py
deleted
100644 → 0
View file @
b6087a6b
from
contextlib
import
nullcontext
import
numpy
as
np
import
pytest
from
transformers
import
LlavaNextImageProcessor
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal.image
import
rescale_image_size
@pytest.fixture
def mm_registry():
    """Fixture providing a newly constructed ``MultiModalRegistry``."""
    registry = MultiModalRegistry()
    return registry
@pytest.mark.parametrize("dtype", ["half", "float"])
@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0])
def test_llava_next_image_processor(image_assets, mm_registry, dtype,
                                    size_factor):
    """Compare the registry's mapped image input with the HF processor."""
    MODEL_NAME = "llava-hf/llava-v1.6-vicuna-7b-hf"

    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
    assert isinstance(hf_processor, LlavaNextImageProcessor)

    config_kwargs = {
        "model": MODEL_NAME,
        "task": "auto",
        "tokenizer": MODEL_NAME,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": dtype,
        "revision": None,
        "limit_mm_per_prompt": {"image": 1},
    }
    model_config = ModelConfig(**config_kwargs)

    mm_registry.init_mm_limits_per_prompt(model_config)

    for asset in image_assets:
        scaled_image = rescale_image_size(asset.pil_image, size_factor)

        # Reference output straight from the HF image processor.
        hf_result = hf_processor.preprocess(scaled_image,
                                            return_tensors="pt")
        # Output of the registry's input mapper for the same image.
        vllm_result = mm_registry.map_input(model_config,
                                            {"image": scaled_image})

        assert hf_result.keys() == vllm_result.keys()
        for key, hf_tensor in hf_result.items():
            hf_arr: np.ndarray = hf_tensor.numpy()
            vllm_arr: np.ndarray = vllm_result[key].numpy()

            same_shape = hf_arr.shape == vllm_arr.shape
            assert same_shape, f"Failed for key={key}"
            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
    """Check that ``map_input`` raises ValueError only for invalid cases."""
    MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

    config_kwargs = {
        "model": MODEL_NAME,
        "task": "auto",
        "tokenizer": MODEL_NAME,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "half",
        "revision": None,
        "limit_mm_per_prompt": {"image": limit},
    }
    model_config = ModelConfig(**config_kwargs)

    mm_registry.init_mm_limits_per_prompt(model_config)

    first_image = image_assets[0].pil_image

    # No images -> empty input; one image -> bare item; otherwise a list.
    if num_images == 1:
        mm_inputs = {"image": first_image}
    elif num_images == 0:
        mm_inputs = {}
    else:
        mm_inputs = {"image": [first_image] * num_images}

    if is_valid:
        expectation = nullcontext()
    else:
        expectation = pytest.raises(ValueError)

    with expectation:
        mm_registry.map_input(model_config, mm_inputs)
# NOTE: We don't test zero images since the HF processor doesn't support it
@pytest.mark.parametrize("num_images", [1, 2])
def test_image_mapper_multi(image_assets, mm_registry, num_images):
    """Check that mapping N images yields N ``pixel_values`` entries."""
    MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"

    config_kwargs = {
        "model": MODEL_NAME,
        "task": "auto",
        "tokenizer": MODEL_NAME,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "half",
        "revision": None,
        "limit_mm_per_prompt": {"image": num_images},
    }
    model_config = ModelConfig(**config_kwargs)

    mm_registry.init_mm_limits_per_prompt(model_config)

    # The same asset image is repeated num_images times.
    first_image = image_assets[0].pil_image
    repeated_images = [first_image] * num_images

    mapped_inputs = mm_registry.map_input(model_config,
                                          {"image": repeated_images})

    num_mapped = len(mapped_inputs["pixel_values"])
    assert num_mapped == num_images
tests/multimodal/test_processing.py
View file @
8c38ee70
from
contextlib
import
nullcontext
from
functools
import
partial
from
typing
import
cast
from
unittest.mock
import
MagicMock
import
numpy
as
np
import
pytest
...
...
@@ -526,6 +528,100 @@ def _rand_audio(
return
rng
.
rand
(
audio_len
),
sr
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check dummy mm-count validation against a mocked supported limit."""
    config_kwargs = {
        "model": model_id,
        "task": "auto",
        "tokenizer": model_id,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "half",
        "revision": None,
        "limit_mm_per_prompt": {"image": limit},
    }
    model_config = ModelConfig(**config_kwargs)

    # Look up the processor factory registered for this model class.
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    make_processor = MULTIMODAL_REGISTRY._processor_factories[model_cls]

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processing_ctx = InputProcessingContext(model_config,
                                            tokenizer=tokenizer)

    processor = make_processor(processing_ctx, cache=None)

    # Replace the processor's supported limits with the parametrized value.
    processor.get_supported_mm_limits = MagicMock(
        return_value={"image": num_supported})

    expectation = (nullcontext() if is_valid else pytest.raises(
        ValueError, match="this model only supports"))

    with expectation:
        processor._get_and_validate_dummy_mm_counts()
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check that ``apply`` rejects more images than the configured limit."""
    config_kwargs = {
        "model": model_id,
        "task": "auto",
        "tokenizer": model_id,
        "tokenizer_mode": "auto",
        "trust_remote_code": False,
        "seed": 0,
        "dtype": "half",
        "revision": None,
        "limit_mm_per_prompt": {"image": limit},
    }
    model_config = ModelConfig(**config_kwargs)

    # Look up the processor factory registered for this model class.
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    make_processor = MULTIMODAL_REGISTRY._processor_factories[model_cls]

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processing_ctx = InputProcessingContext(model_config,
                                            tokenizer=tokenizer)

    processor = make_processor(processing_ctx, cache=None)

    # Fixed seed, so the random image is the same on every run.
    rng = np.random.RandomState(0)
    random_image = _rand_img(rng, min_wh=128, max_wh=256)

    # No images -> empty data; one image -> bare item; otherwise a list.
    if num_images == 1:
        mm_data = {"image": random_image}
    elif num_images == 0:
        mm_data = {}
    else:
        mm_data = {"image": [random_image] * num_images}

    expectation = (nullcontext() if is_valid else pytest.raises(
        ValueError, match=f"passed {num_images} image"))

    prompt = "<image>" * num_images
    with expectation:
        processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
def
_test_processing_cache_correctness
(
model_id
:
str
,
modalities
:
dict
[
str
,
bool
],
...
...
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
(
"facebook/chameleon-7b"
,
{
"image"
:
False
}),
(
"adept/fuyu-8b"
,
{
"image"
:
False
}),
(
"llava-hf/llava-1.5-7b-hf"
,
{
"image"
:
True
}),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
{
"image"
:
True
}),
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
{
"image"
:
True
}),
(
"mistral-community/pixtral-12b"
,
{
"image"
:
True
}),
(
"Qwen/Qwen2-VL-2B-Instruct"
,
{
"image"
:
True
,
"video"
:
True
}),
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
View file @
8c38ee70
...
...
@@ -3,13 +3,11 @@ from typing import Optional
import
torch
from
vllm.model_executor.models.llava
import
(
LlavaForConditionalGeneration
,
LlavaMultiModalProcessor
,
get_max_llava_image_tokens
)
LlavaMultiModalProcessor
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
)
class
MyLlava
(
LlavaForConditionalGeneration
):
...
...
vllm/model_executor/models/clip.py
View file @
8c38ee70
...
...
@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
...
...
@@ -149,6 +151,29 @@ def input_processor_for_clip(
multi_modal_placeholders
=
{
"image"
:
ranges
})
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
    """``VisionEncoderInfo`` implementation backed by a CLIP vision config.

    Every method delegates to the module-level CLIP helpers, reading
    ``self.vision_config``.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the CLIP image feature size for the vision config.

        ``image_width`` and ``image_height`` are accepted but not used.
        """
        config = self.vision_config
        return get_clip_image_feature_size(config)

    def get_max_image_tokens(self) -> int:
        """Return the max CLIP image token count for the vision config."""
        config = self.vision_config
        return get_max_clip_image_tokens(config)

    def get_num_patches(self) -> int:
        """Return the patch grid length for the configured sizes."""
        image_size = self.vision_config.image_size
        patch_size = self.vision_config.patch_size
        return get_clip_patch_grid_length(
            image_size=image_size,
            patch_size=patch_size,
        )

    def get_image_size(self) -> int:
        """Return ``image_size`` from the vision config."""
        config = self.vision_config
        return config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class
CLIPVisionEmbeddings
(
nn
.
Module
):
...
...
vllm/model_executor/models/fuyu.py
View file @
8c38ee70
...
...
@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
return
ImageSize
(
width
=
target_size
[
"width"
],
height
=
target_size
[
"height"
])
def
_get_image_grid_size
(
def
_get_image_
feature_
grid_size
(
self
,
*
,
image_width
:
int
,
...
...
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
target_width
,
target_height
=
self
.
_get_image_target_size
()
max_ncols
,
max_nrows
=
self
.
_get_image_grid_size
(
max_ncols
,
max_nrows
=
self
.
_get_image_
feature_
grid_size
(
image_width
=
target_width
,
image_height
=
target_height
,
)
...
...
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
ncols
,
nrows
=
self
.
_get_image_grid_size
(
ncols
,
nrows
=
self
.
_get_image_
feature_
grid_size
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
...
...
vllm/model_executor/models/llava.py
View file @
8c38ee70
from
abc
import
abstractmethod
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
Tuple
,
TypedDict
,
Union
)
from
typing
import
(
Final
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
...
...
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputContext
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
...
...
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputsV2
,
MultiModalKwargs
,
NestedTensors
)
from
vllm.multimodal.parse
import
ImageProcessorItems
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalDataItems
,
ProcessorInputs
,
PromptReplacement
,
InputProcessingContext
,
MultiModalDataItems
,
ProcessingCache
,
ProcessorInputs
,
PromptReplacement
,
full_groupby_modality
)
from
vllm.sequence
import
IntermediateTensors
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
get_max_clip_image_tokens
)
from
.clip
import
CLIPVisionModel
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.pixtral
import
(
PixtralHFVisionModel
,
dummy_image_for_pixtral_hf
,
get_max_pixtral_hf_image_tokens
,
get_pixtral_hf_image_feature_size
)
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
get_max_siglip_image_tokens
)
from
.pixtral
import
(
PixtralHFVisionModel
,
get_pixtral_hf_image_feature_grid_size
)
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
vision_encoder_info
class
LlavaImagePixelInputs
(
TypedDict
):
...
...
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
return
hidden_states
def
get_max_llava_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
class
LlavaLikeConfig
(
Protocol
):
vision_config
:
Final
[
PretrainedConfig
]
vision_feature_select_strategy
:
Final
[
str
]
vision_feature_layer
:
Final
[
Union
[
int
,
List
[
int
]]]
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
num_image_tokens
=
get_max_clip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_image_tokens
=
get_max_siglip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
PixtralVisionConfig
):
num_image_tokens
=
get_max_pixtral_hf_image_tokens
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
return
num_image_tokens
-
1
elif
strategy
==
"full"
:
return
num_image_tokens
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
class
BaseLlavaMultiModalProcessor
(
BaseMultiModalProcessor
):
def
__init__
(
self
,
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
super
().
__init__
(
ctx
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
)
vision_config
=
self
.
_get_hf_config
().
vision_config
self
.
_vision_encoder_info
=
vision_encoder_info
(
vision_config
)
class
LlavaMultiModalProcessor
(
BaseMultiModalProcessor
):
@
abstractmethod
def
_get_hf_config
(
self
)
->
LlavaLikeConfig
:
raise
NotImplementedError
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
def
_apply_feature_select_strategy
(
self
,
strategy
:
str
,
encoder_num_image_tokens
:
int
,
)
->
int
:
if
strategy
==
"default"
:
return
encoder_num_image_tokens
-
1
if
strategy
==
"full"
:
return
encoder_num_image_tokens
msg
=
f
"Unexpected feature select strategy:
{
strategy
!
r
}
"
raise
NotImplementedError
(
msg
)
def
_get_max_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
return
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_max_image_tokens
(),
)
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
get_max_llava_image_tokens
(
self
.
ctx
)}
return
{
"image"
:
self
.
_get_max_image_tokens
()}
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_dummy_image_size
(
self
)
->
ImageSize
:
image_size
=
self
.
_vision_encoder_info
.
get_image_size
()
return
ImageSize
(
image_size
,
image_size
)
@
abstractmethod
def
_get_image_token
(
self
)
->
str
:
raise
NotImplementedError
def
_get_dummy_mm_inputs
(
self
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
image_token
=
self
.
_get_image_token
()
target_width
,
target_height
=
self
.
_get_dummy_image_size
()
mm_data
=
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
LlavaMultiModalProcessor
(
BaseLlavaMultiModalProcessor
):
def
_get_hf_config
(
self
)
->
LlavaConfig
:
return
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
def
_get_hf_processor
(
self
)
->
LlavaProcessor
:
return
self
.
ctx
.
get_hf_processor
(
LlavaProcessor
)
def
_get_image_token
(
self
)
->
str
:
return
self
.
_get_hf_processor
().
image_token
def
_get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
return
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
),
)
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
_get_hf_config
()
image_token_id
=
hf_config
.
image_token_index
def
_get_hf_processor
(
self
)
->
Union
[
LlavaProcessor
,
PixtralProcessor
]:
return
self
.
ctx
.
get_hf_processor
((
LlavaProcessor
,
PixtralProcessor
))
def
get_replacement
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
if
isinstance
(
images
,
ImageEmbeddingItems
):
num_image_tokens
=
images
.
get_feature_size
(
item_idx
)
else
:
image_size
=
images
.
get_image_size
(
item_idx
)
num_image_tokens
=
self
.
_get_num_image_tokens
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
image_token_id
]
*
num_image_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement
,
),
]
class
PixtralHFMultiModalProcessor
(
BaseLlavaMultiModalProcessor
):
def
_get_hf_config
(
self
)
->
LlavaConfig
:
return
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
def
_get_hf_processor
(
self
)
->
PixtralProcessor
:
return
self
.
ctx
.
get_hf_processor
(
PixtralProcessor
)
def
_get_image_token
(
self
)
->
str
:
return
self
.
_get_hf_processor
().
image_token
def
_call_hf_processor
(
self
,
...
...
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
mm_kwargs
=
mm_kwargs
,
)
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
if
pixel_values
is
not
None
:
images
=
mm_data
[
"images"
]
assert
isinstance
(
images
,
list
)
if
isinstance
(
self
.
_get_hf_processor
(),
PixtralProcessor
):
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert
(
isinstance
(
pixel_values
,
list
)
and
len
(
pixel_values
)
==
1
)
assert
(
isinstance
(
pixel_values
[
0
],
list
)
and
len
(
pixel_values
[
0
])
==
len
(
images
))
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert
(
isinstance
(
pixel_values
,
list
)
and
len
(
pixel_values
)
==
1
)
assert
(
isinstance
(
pixel_values
[
0
],
list
)
and
len
(
pixel_values
[
0
])
==
len
(
images
))
processed_outputs
[
"pixel_values"
]
=
pixel_values
[
0
]
processed_outputs
[
"pixel_values"
]
=
pixel_values
[
0
]
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
hf_config
=
self
.
_
get_hf_config
()
image_token_id
=
hf_config
.
image_token_index
processor
=
self
.
_get_hf_processor
()
if
isinstance
(
processor
,
PixtralProcessor
):
image_token
=
processor
.
image_token
image_break_token
=
processor
.
image_break_token
image_end_token
=
processor
.
image_end_token
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
image_token
=
processor
.
image_token
image_break_token
=
processor
.
image_break_token
image_end_token
=
processor
.
image_end_token
def
get_replacement_pixtral
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
(
num_width_tokens
,
num_height_tokens
,
)
=
get_pixtral_hf_image_feature_size
(
vision_config
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
tokens
=
([
image_token
]
*
nu
m_
w
id
th_tokens
+
[
image_break_token
])
*
num_height_tokens
tokens
[
-
1
]
=
image_end_token
def
get_replacement
(
ite
m_id
x
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
return
""
.
join
(
tokens
)
ncols
,
nrows
=
get_pixtral_hf_image_feature_grid_size
(
vision_config
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement_pixtral
,
),
]
tokens
=
([
image_token
]
*
ncols
+
[
image_break_token
])
*
nrows
tokens
[
-
1
]
=
image_end_token
max_image_tokens
=
get_max_llava_image_tokens
(
self
.
ctx
)
return
""
.
join
(
tokens
)
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
[
image_token_id
]
*
max_image_tok
en
s
,
)
replacement
=
get_replacem
en
t
,
)
,
]
def
_get_dummy_mm_inputs
(
self
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
.
get
(
"image"
,
0
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
data
=
dummy_image_for_clip
(
vision_config
,
num_images
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
elif
isinstance
(
vision_config
,
PixtralVisionConfig
):
data
=
dummy_image_for_pixtral_hf
(
vision_config
,
num_images
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
hf_processor
=
self
.
_get_hf_processor
()
image_token
=
hf_processor
.
image_token
def
_build_llava_or_pixtral_hf_processor
(
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseLlavaMultiModalProcessor
:
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
data
,
if
isinstance
(
hf_config
.
vision_config
,
PixtralVisionConfig
):
return
PixtralHFMultiModalProcessor
(
ctx
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
class
LlavaLikeConfig
(
Protocol
):
vision_config
:
PretrainedConfig
vision_feature_layer
:
Union
[
int
,
List
[
int
]]
return
LlavaMultiModalProcessor
(
ctx
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
def
_get_num_hidden_layers
(
hf_config
:
LlavaLikeConfig
)
->
int
:
...
...
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
raise
NotImplementedError
(
msg
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalP
rocessor
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
_build_llava_or_pixtral_hf_p
rocessor
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{
...
...
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
->
MultiModalInputsV2
:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
image_token_id
=
hf_config
.
image_token_index
max_image_tokens
=
get_max_llava_image_tokens
(
self
.
ctx
)
# Assume that it doesn't depend on the image size
num_image_tokens
=
self
.
_get_num_image_tokens
(
image_width
=-
1
,
image_height
=-
1
,
)
result
=
super
().
apply
(
prompt_text
,
mm_data
,
hf_processor_mm_kwargs
)
...
...
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def
get_replacement_mantis
(
item_idx
:
int
):
return
""
.
join
([
f
"(image
{
item_idx
+
1
}
: <Image>"
,
# 7 tokens
"<image>"
*
max
_image_tokens
,
"<image>"
*
num
_image_tokens
,
"</Image>)"
,
# 3 tokens
])
mantis_repls
=
self
.
_bind_prompt_replacements
([
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
]
*
max
_image_tokens
,
target
=
[
image_token_id
]
*
num
_image_tokens
,
replacement
=
get_replacement_mantis
,
)
])
...
...
vllm/model_executor/models/llava_next.py
View file @
8c38ee70
...
...
@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
import
torch
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
,
LlavaNextConfig
,
SiglipVisionConfig
from
transformers
import
BatchFeature
,
LlavaNextConfig
,
LlavaNextProcessor
from
transformers.models.llava_next.modeling_llava_next
import
(
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
NestedTensors
from
vllm.multimodal.parse
import
ImageSize
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
get_clip_patch_grid_length
,
input_processor_for_clip
)
from
.clip
import
CLIPVisionModel
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.llava
import
LlavaMultiModalProjector
,
init_vision_tower_for_llava
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.llava
import
(
LlavaMultiModalProcessor
,
LlavaMultiModalProjector
,
init_vision_tower_for_llava
)
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
)
...
...
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs
]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    """Return ``(unpadded_features, newline_features)`` for a patch grid.

    The grid starts as ``npatches * num_patch_height`` rows by
    ``npatches * num_patch_width`` columns. Whichever side is longer than the
    original image's aspect ratio allows is trimmed by an even amount
    (``2 * padding``) before the features are counted.

    Args:
        original_height: Height of the original image.
        original_width: Width of the original image.
        npatches: Multiplier applied to both grid dimensions.
        num_patch_height: Number of grid cells along the height.
        num_patch_width: Number of grid cells along the width.

    Returns:
        A tuple of the trimmed grid area (height * width) and the trimmed
        grid height.

    Raises:
        ZeroDivisionError: If ``original_height`` or the initial grid height
            is zero (or ``original_width`` when the width-scaling branch is
            taken).
    """
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Image is relatively wider than the grid: fit the width and trim
        # rows symmetrically.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is relatively taller (or equal): fit the height and trim
        # columns symmetrically.
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def
get_llava_next_image_feature_size
(
hf_config
:
LlavaNextConfig
,
*
,
input_height
:
int
,
input_width
:
int
,
)
->
int
:
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
num_patches
=
get_clip_patch_grid_length
(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
get_clip_image_feature_size
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_patches
=
get_siglip_patch_grid_length
(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
get_siglip_image_feature_size
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
base_feature_size
-=
1
elif
strategy
==
"full"
:
pass
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
class
LlavaNextMultiModalProcessor
(
LlavaMultiModalProcessor
):
num_patch_height
,
num_patch_width
=
get_anyres_image_grid_shape
(
image_size
=
(
input_height
,
input_width
),
grid_pinpoints
=
hf_config
.
image_grid_pinpoints
,
patch_size
=
vision_config
.
image_size
,
)
(
unpadded_feature_size
,
newline_feature_size
,
)
=
_get_llava_next_num_unpadded_features
(
input_height
,
input_width
,
num_patches
,
num_patch_height
,
num_patch_width
)
return
unpadded_feature_size
+
newline_feature_size
+
base_feature_size
def get_max_llava_next_image_tokens(ctx: InputContext) -> int:
    """Compute the max feature size for all possible image grid pinpoints."""
    # Element 0 of the helper's result is the feature size; element 1 is the
    # pinpoint that produced it.
    return _get_pinpoint_with_largest_features(ctx)[0]
def _get_pinpoint_with_largest_features(
        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
    """Get the grid pinpoint with the largest features & its feature size.

    Every ``(height, width)`` entry of ``hf_config.image_grid_pinpoints`` is
    passed to ``get_llava_next_image_feature_size``; the first entry with the
    strictly largest result wins.

    Returns:
        ``(largest_feature_size, (height, width))``.

    Raises:
        ValueError: If no pinpoint yields a positive feature size (including
            when there are no pinpoints at all).
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)

    largest_feature_size = 0
    largest_feature_pinpoint = None
    for (height, width) in hf_config.image_grid_pinpoints:
        feat_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
        if feat_size > largest_feature_size:
            largest_feature_size = feat_size
            largest_feature_pinpoint = (height, width)

    if not largest_feature_size or largest_feature_pinpoint is None:
        raise ValueError("Cannot have a largest feature size of 0!")

    return largest_feature_size, largest_feature_pinpoint
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
,
pinpoint
=
_get_pinpoint_with_largest_features
(
ctx
)
max_feat_height
,
max_feat_width
=
pinpoint
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
seq_data
,
ranges
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
def _get_hf_config(self) -> LlavaNextConfig:
    """Return the HF config from ``self.ctx``, requested as ``LlavaNextConfig``."""
    return self.ctx.get_hf_config(LlavaNextConfig)
def _get_hf_processor(self) -> LlavaNextProcessor:
    """Return the HF processor from ``self.ctx``, requested as ``LlavaNextProcessor``."""
    return self.ctx.get_hf_processor(LlavaNextProcessor)
mm_data
=
dummy_image_for_clip
(
vision_config
,
num_images
,
image_width_override
=
max_feat_width
,
image_height_override
=
max_feat_height
,
def _get_image_token(self) -> str:
    """Return the ``image_token`` attribute of the HF processor."""
    return self._get_hf_processor().image_token
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Declare the multi-modal fields and the modality each is batched by.

    ``pixel_values``, ``image_sizes`` and ``image_embeds`` are all configured
    as batched over the ``"image"`` modality. Neither argument is consulted.
    """
    return dict(
        pixel_values=MultiModalFieldConfig.batched("image"),
        image_sizes=MultiModalFieldConfig.batched("image"),
        image_embeds=MultiModalFieldConfig.batched("image"),
    )
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
seq_data
,
ranges
=
dummy_seq_data_for_siglip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
def _get_max_image_tokens(self) -> int:
    """Return the feature size of the pinpoint with the most features."""
    largest_feature_size, _ = self._get_pinpoint_with_most_features()
    return largest_feature_size
def _get_dummy_image_size(self) -> ImageSize:
    """Return the pinpoint (as an image size) that has the most features."""
    _, pinpoint = self._get_pinpoint_with_most_features()
    return pinpoint
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def
_get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
base_feature_size
=
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
),
)
num_patches
=
self
.
_vision_encoder_info
.
get_num_patches
()
mm_data
=
dummy_image_for_siglip
(
vision_config
,
num_images
,
image_width_override
=
max_feat_width
,
image_height_override
=
max_feat_height
,
num_patch_height
,
num_patch_width
=
get_anyres_image_grid_shape
(
image_size
=
(
image_height
,
image_width
),
grid_pinpoints
=
hf_config
.
image_grid_pinpoints
,
patch_size
=
self
.
_vision_encoder_info
.
get_image_size
(),
)
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
(
unpadded_feature_size
,
newline_feature_size
,
)
=
self
.
_get_num_unpadded_features
(
original_height
=
image_height
,
original_width
=
image_width
,
npatches
=
num_patches
,
num_patch_height
=
num_patch_height
,
num_patch_width
=
num_patch_width
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
return
unpadded_feature_size
+
newline_feature_size
+
base_feature_size
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def
_get_num_unpadded_features
(
self
,
*
,
original_height
:
int
,
original_width
:
int
,
npatches
:
int
,
num_patch_height
:
int
,
num_patch_width
:
int
,
)
->
tuple
[
int
,
int
]:
current_height
=
npatches
*
num_patch_height
current_width
=
npatches
*
num_patch_width
original_aspect_ratio
=
original_width
/
original_height
current_aspect_ratio
=
current_width
/
current_height
if
original_aspect_ratio
>
current_aspect_ratio
:
scale_factor
=
current_width
/
original_width
new_height
=
int
(
original_height
*
scale_factor
)
padding
=
(
current_height
-
new_height
)
//
2
current_height
-=
2
*
padding
else
:
scale_factor
=
current_height
/
original_height
new_width
=
int
(
original_width
*
scale_factor
)
padding
=
(
current_width
-
new_width
)
//
2
current_width
-=
2
*
padding
def
input_processor_for_llava_next
(
ctx
:
InputContext
,
inputs
:
DecoderOnlyInputs
):
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
inputs
unpadded_features
=
current_height
*
current_width
newline_features
=
current_height
return
(
unpadded_features
,
newline_features
)
model_config
=
ctx
.
model_config
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
vision_config
=
hf_config
.
vision_config
def
_get_pinpoint_with_most_features
(
self
)
->
tuple
[
int
,
ImageSize
]:
"""
Get the grid pinpoint with the most features and
the corresponding feature size.
"""
hf_config
=
self
.
_get_hf_config
()
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
largest_feature_size
,
largest_feature_pinpoint
=
0
,
None
for
(
height
,
width
)
in
hf_config
.
image_grid_pinpoints
:
feat_size
=
self
.
_get_num_image_tokens
(
image_width
=
width
,
image_height
=
height
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
largest_feature_pinpoint
=
ImageSize
(
width
=
width
,
height
=
height
)
image_feature_size
=
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
height
,
input_width
=
width
,
)
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
img
.
height
,
input_width
=
img
.
width
)
for
img
in
image_data
]
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
[
item
.
shape
[
1
]
for
item
in
image_data
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
input_processor_for_clip
(
model_config
,
vision_config
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
if
largest_feature_size
==
0
or
largest_feature_pinpoint
is
None
:
raise
ValueError
(
"Cannot have a largest feature size of 0!"
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
return
largest_feature_size
,
largest_feature_pinpoint
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_next_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_next
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_next
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaNextMultiModalProcessor
)
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
...
...
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def
_process_image_pixels
(
self
,
inputs
:
LlavaNextImagePixelInputs
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...
]]:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
...
...
vllm/model_executor/models/phi3v.py
View file @
8c38ee70
...
...
@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputsV2
,
MultiModalKwargs
,
NestedTensors
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
ImageProcessorItems
from
vllm.multimodal.parse
import
ImageEmbeddingItems
,
ImageProcessorItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalDataItems
,
ProcessorInputs
,
PromptReplacement
,
...
...
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
assert
isinstance
(
bos_token_id
,
int
)
def
get_replacement_phi3v
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
num_tokens
=
self
.
_get_num_image_tokens
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
_IMAGE_TOKEN_ID
]
*
num_tokens
+
[
bos_token_id
]
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
if
isinstance
(
images
,
ImageEmbeddingItems
):
num_image_tokens
=
images
.
get_feature_size
(
item_idx
)
else
:
image_size
=
images
.
get_image_size
(
item_idx
)
num_image_tokens
=
self
.
_get_num_image_tokens
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
+
[
bos_token_id
]
num_images
=
mm_items
.
get_count
(
"image"
,
strict
=
False
)
...
...
vllm/model_executor/models/pixtral.py
View file @
8c38ee70
...
...
@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
try
:
from
xformers
import
ops
as
xops
...
...
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
return
image_size
//
patch_size
def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the square of the patch grid length for the given sizes."""
    grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
                                                   patch_size=patch_size)
    return grid_length * grid_length
def get_pixtral_hf_image_feature_size(
    *,
    image_size: int,
    patch_size: int,
) -> int:
    """Return ``(grid_length + 1) * grid_length`` for the patch grid.

    ``grid_length`` comes from ``get_pixtral_hf_patch_grid_length``; the extra
    ``+ 1`` accounts for the image break token.
    """
    grid_length = get_pixtral_hf_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )

    # Consider the image_break_token
    return (grid_length + 1) * grid_length
def
get_max_pixtral_hf_image_tokens
(
hf_config
:
PixtralVisionConfig
)
->
int
:
...
...
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
get_pixtral_hf_image_feature_size
(
hf_config
:
PixtralVisionConfig
,
image_width
:
int
,
image_height
:
int
)
->
Tuple
[
int
,
int
]:
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
max_width
,
max_height
=
hf_config
.
image_size
,
hf_config
.
image_size
patch_width
,
patch_height
=
hf_config
.
patch_size
,
hf_config
.
patch_size
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
def
get_pixtral_hf_image_feature_grid_size
(
hf_config
:
PixtralVisionConfig
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
tuple
[
int
,
int
]:
max_width
=
max_height
=
hf_config
.
image_size
patch_width
=
patch_height
=
hf_config
.
patch_size
ratio
=
max
(
image_width
/
max_width
,
image_height
/
max_height
)
...
...
@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width
=
int
(
math
.
ceil
(
image_width
/
ratio
))
image_height
=
int
(
math
.
ceil
(
image_height
/
ratio
))
n
um_height_tokens
,
num_width_token
s
=
_get_pixtral_hf_num_image_tokens
(
n
rows
,
ncol
s
=
_get_pixtral_hf_num_image_tokens
(
(
image_height
,
image_width
),
(
patch_height
,
patch_width
),
)
)
# type: ignore
return
ncols
,
nrows
class
PixtralHFEncoderInfo
(
VisionEncoderInfo
[
PixtralVisionConfig
]):
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """Return the image feature size for this encoder's configuration.

    The result is computed from ``vision_config.image_size`` and
    ``vision_config.patch_size`` only; ``image_width`` and ``image_height``
    are accepted for interface compatibility but not consulted here.
    """
    return get_pixtral_hf_image_feature_size(
        image_size=self.vision_config.image_size,
        # Must be the patch size, not the image size: passing the image size
        # here would make the grid length image_size // image_size.
        patch_size=self.vision_config.patch_size,
    )
def get_max_image_tokens(self) -> int:
    """Return ``get_max_pixtral_hf_image_tokens`` for ``self.vision_config``."""
    return get_max_pixtral_hf_image_tokens(self.vision_config)
def get_num_patches(self) -> int:
    """Return the patch grid length for the configured image and patch size."""
    return get_pixtral_hf_patch_grid_length(
        image_size=self.vision_config.image_size,
        patch_size=self.vision_config.patch_size,
    )
return
num_width_tokens
,
num_height_tokens
def get_image_size(self) -> int:
    """Return ``vision_config.image_size``."""
    return self.vision_config.image_size
class
PixtralHFMLP
(
nn
.
Module
):
...
...
vllm/model_executor/models/siglip.py
View file @
8c38ee70
...
...
@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
...
...
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
multi_modal_placeholders
=
{
"image"
:
ranges
})
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
    """``VisionEncoderInfo`` implementation backed by a ``SiglipVisionConfig``.

    Every method delegates to a SigLIP helper function or reads an attribute
    of ``self.vision_config``.
    """

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the SigLIP image feature size.

        Depends only on ``self.vision_config``; ``image_width`` and
        ``image_height`` are not consulted here.
        """
        return get_siglip_image_feature_size(self.vision_config)

    def get_max_image_tokens(self) -> int:
        """Return ``get_max_siglip_image_tokens`` for ``self.vision_config``."""
        return get_max_siglip_image_tokens(self.vision_config)

    def get_num_patches(self) -> int:
        """Return the patch grid length for the configured sizes."""
        return get_siglip_patch_grid_length(
            image_size=self.vision_config.image_size,
            patch_size=self.vision_config.patch_size,
        )

    def get_image_size(self) -> int:
        """Return ``vision_config.image_size``."""
        return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class
SiglipVisionEmbeddings
(
nn
.
Module
):
...
...
vllm/model_executor/models/utils.py
View file @
8c38ee70
...
...
@@ -373,7 +373,7 @@ def embed_multimodal(
input_ids
:
torch
.
Tensor
,
multimodal_token_id
:
int
,
get_text_embeds
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
multimodal_embeds
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
,
multimodal_embeds
:
Nested
Tensor
s
,
)
->
torch
.
Tensor
:
"""
Embed token IDs and multimodal inputs and combine their embeddings.
...
...
vllm/model_executor/models/vision.py
0 → 100644
View file @
8c38ee70
from
abc
import
ABC
,
abstractmethod
from
typing
import
Generic
,
TypeVar
from
transformers
import
PretrainedConfig
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface for querying size information of a vision encoder.

    Subclasses are parameterized by their config type ``_C`` and must
    implement all four query methods.
    """

    def __init__(self, vision_config: _C) -> None:
        """Store ``vision_config`` for use by the query methods."""
        super().__init__()

        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum number of image tokens."""
        raise NotImplementedError

    @abstractmethod
    def get_num_patches(self) -> int:
        """Return the number of patches reported by the encoder."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size."""
        raise NotImplementedError
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
    """Build the ``VisionEncoderInfo`` that matches ``vision_config``'s type.

    The config is checked, in order, against the CLIP, Pixtral and SigLIP
    vision config classes, and the corresponding encoder-info object is
    returned for the first match.

    Raises:
        NotImplementedError: If ``vision_config`` is none of the supported
            config types.
    """
    # Avoid circular imports
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPEncoderInfo(vision_config)
    if isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFEncoderInfo(vision_config)
    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipEncoderInfo(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
vllm/multimodal/parse.py
View file @
8c38ee70
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
from
collections.abc
import
Callable
,
Iterator
,
Mapping
,
Sequence
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
NamedTuple
,
Optional
,
TypeVar
from
typing
import
(
TYPE_CHECKING
,
Any
,
Generic
,
NamedTuple
,
Optional
,
TypeVar
,
Union
)
import
numpy
as
np
import
torch
...
...
@@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def get_count(self) -> int:
    """Return the number of items, i.e. ``len(self.data)``."""
    return len(self.data)
def
get
(
self
,
index
:
int
)
->
object
:
def
get
(
self
,
index
:
int
)
->
torch
.
Tensor
:
return
self
.
data
[
index
]
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
...
...
@@ -96,6 +97,9 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def get_passthrough_data(self) -> Mapping[str, object]:
    """Return ``self.data`` keyed by ``"<modality>_embeds"``."""
    return {f"{self.modality}_embeds": self.data}
def get_feature_size(self, item_idx: int) -> int:
    """Return ``len()`` of the item stored at ``item_idx``."""
    return len(self.get(item_idx))
class
AudioProcessorItems
(
ProcessorBatchItems
[
HfAudioItem
]):
...
...
@@ -182,7 +186,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
def
get_items
(
self
,
modality
:
str
,
typ
:
type
[
_D
],
typ
:
Union
[
type
[
_D
],
tuple
[
type
[
_D
],
...]],
)
->
_D
:
"""
Get the data items belonging to a modality,
...
...
@@ -199,7 +203,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
items
)
}
"
)
return
items
return
items
# type: ignore[return-value]
ModalityDataParser
:
TypeAlias
=
Callable
[[
ModalityData
[
Any
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment