Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c38ee70
Unverified
Commit
8c38ee70
authored
Jan 03, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 02, 2025
Browse files
[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
b6087a6b
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
605 additions
and
551 deletions
+605
-551
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
...ly/vision_language/mm_processor_kwargs/test_llava_next.py
+0
-70
tests/multimodal/test_mapper.py
tests/multimodal/test_mapper.py
+0
-118
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+97
-0
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
...ins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
+1
-3
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+25
-0
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+3
-3
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+215
-119
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+112
-209
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+14
-10
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+52
-14
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+25
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+1
-1
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+52
-0
vllm/multimodal/parse.py
vllm/multimodal/parse.py
+8
-4
No files found.
tests/models/decoder_only/vision_language/mm_processor_kwargs/test_llava_next.py
deleted
100644 → 0
View file @
b6087a6b
import
pytest
from
vllm.inputs
import
InputContext
from
....utils
import
build_model_context
@
pytest
.
fixture
()
def
get_max_llava_next_image_tokens
():
from
vllm.model_executor.models.llava_next
import
(
get_max_llava_next_image_tokens
)
return
get_max_llava_next_image_tokens
@
pytest
.
fixture
()
def
dummy_data_for_llava_next
():
from
vllm.model_executor.models.llava_next
import
dummy_data_for_llava_next
return
dummy_data_for_llava_next
@
pytest
.
mark
.
parametrize
(
"gridpoints,expected_max_tokens"
,
[
([[
336
,
336
]],
1176
),
([[
336
,
672
],
[
672
,
336
],
[
672
,
672
],
[
1008
,
336
],
[
336
,
1008
]],
2928
),
])
def
test_get_max_llava_next_image_tokens
(
gridpoints
,
expected_max_tokens
,
get_max_llava_next_image_tokens
):
ctx
=
build_model_context
(
model_name
=
"llava-hf/llava-v1.6-mistral-7b-hf"
)
# Update the config image_grid_pinpoints
# and calculate the resulting max tokens
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
actual_max_tokens
=
get_max_llava_next_image_tokens
(
InputContext
(
ctx
.
model_config
))
assert
expected_max_tokens
==
actual_max_tokens
@
pytest
.
mark
.
parametrize
(
"gridpoints,expected_size"
,
[
# One point; it has to be the largest
([[
336
,
336
]],
(
336
,
336
)),
# Default for most llava next models; the 2x2 tile is the largest
([[
336
,
672
],
[
672
,
336
],
[
672
,
672
],
[
1008
,
336
],
[
336
,
1008
]],
(
672
,
672
)),
# If two rectangular gridpoints are the same, the more vertical
# one has the higher feature count due to newline features
([[
336
,
672
],
[
672
,
336
]],
(
672
,
336
))
])
def
test_dummy_data_for_llava_next_feature_size
(
dummy_data_for_llava_next
,
gridpoints
,
expected_size
):
ctx
=
build_model_context
(
model_name
=
"llava-hf/llava-v1.6-mistral-7b-hf"
)
# Update the config image_grid_pinpoints
ctx
.
model_config
.
hf_config
.
image_grid_pinpoints
=
gridpoints
seq_len
=
5000
# bigger than the max feature size for any image
dummy_data
=
dummy_data_for_llava_next
(
ctx
,
seq_len
=
seq_len
,
mm_counts
=
{
"image"
:
1
},
)
seq_data
=
dummy_data
.
seq_data
mm_data
=
dummy_data
.
multi_modal_data
# The dummy data dims should match the gridpoint with the biggest feat size
assert
mm_data
[
"image"
].
height
==
expected_size
[
0
]
assert
mm_data
[
"image"
].
width
==
expected_size
[
1
]
assert
len
(
seq_data
.
get_token_ids
())
>=
seq_len
tests/multimodal/test_mapper.py
deleted
100644 → 0
View file @
b6087a6b
from
contextlib
import
nullcontext
import
numpy
as
np
import
pytest
from
transformers
import
LlavaNextImageProcessor
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal.image
import
rescale_image_size
@
pytest
.
fixture
def
mm_registry
():
return
MultiModalRegistry
()
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
,
"float"
])
@
pytest
.
mark
.
parametrize
(
"size_factor"
,
[
0.25
,
0.5
,
1.0
])
def
test_llava_next_image_processor
(
image_assets
,
mm_registry
,
dtype
,
size_factor
):
MODEL_NAME
=
"llava-hf/llava-v1.6-vicuna-7b-hf"
hf_processor
=
LlavaNextImageProcessor
.
from_pretrained
(
MODEL_NAME
)
assert
isinstance
(
hf_processor
,
LlavaNextImageProcessor
)
model_config
=
ModelConfig
(
model
=
MODEL_NAME
,
task
=
"auto"
,
tokenizer
=
MODEL_NAME
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
dtype
,
revision
=
None
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
for
asset
in
image_assets
:
image
=
rescale_image_size
(
asset
.
pil_image
,
size_factor
)
hf_result
=
hf_processor
.
preprocess
(
image
,
return_tensors
=
"pt"
,
)
vllm_result
=
mm_registry
.
map_input
(
model_config
,
{
"image"
:
image
},
)
assert
hf_result
.
keys
()
==
vllm_result
.
keys
()
for
key
,
hf_tensor
in
hf_result
.
items
():
hf_arr
:
np
.
ndarray
=
hf_tensor
.
numpy
()
vllm_arr
:
np
.
ndarray
=
vllm_result
[
key
].
numpy
()
assert
hf_arr
.
shape
==
vllm_arr
.
shape
,
f
"Failed for key=
{
key
}
"
assert
np
.
allclose
(
hf_arr
,
vllm_arr
),
f
"Failed for key=
{
key
}
"
@
pytest
.
mark
.
parametrize
(
(
"num_images"
,
"limit"
,
"is_valid"
),
[(
0
,
0
,
True
),
(
0
,
1
,
True
),
(
1
,
0
,
False
),
(
1
,
1
,
True
),
(
1
,
2
,
True
),
(
2
,
1
,
False
),
(
2
,
2
,
True
)],
)
def
test_mm_limits
(
image_assets
,
mm_registry
,
num_images
,
limit
,
is_valid
):
MODEL_NAME
=
"llava-hf/llava-v1.6-mistral-7b-hf"
model_config
=
ModelConfig
(
model
=
MODEL_NAME
,
task
=
"auto"
,
tokenizer
=
MODEL_NAME
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"half"
,
revision
=
None
,
limit_mm_per_prompt
=
{
"image"
:
limit
},
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
image
=
image_assets
[
0
].
pil_image
if
num_images
==
0
:
mm_inputs
=
{}
elif
num_images
==
1
:
mm_inputs
=
{
"image"
:
image
}
else
:
mm_inputs
=
{
"image"
:
[
image
]
*
num_images
}
with
nullcontext
()
if
is_valid
else
pytest
.
raises
(
ValueError
):
mm_registry
.
map_input
(
model_config
,
mm_inputs
)
# NOTE: We don't test zero images since the HF processor doesn't support it
@
pytest
.
mark
.
parametrize
(
"num_images"
,
[
1
,
2
])
def
test_image_mapper_multi
(
image_assets
,
mm_registry
,
num_images
):
MODEL_NAME
=
"llava-hf/llava-v1.6-mistral-7b-hf"
model_config
=
ModelConfig
(
model
=
MODEL_NAME
,
task
=
"auto"
,
tokenizer
=
MODEL_NAME
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"half"
,
revision
=
None
,
limit_mm_per_prompt
=
{
"image"
:
num_images
},
)
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
image
=
image_assets
[
0
].
pil_image
mm_inputs
=
{
"image"
:
[
image
]
*
num_images
}
mapped_inputs
=
mm_registry
.
map_input
(
model_config
,
mm_inputs
)
assert
len
(
mapped_inputs
[
"pixel_values"
])
==
num_images
tests/multimodal/test_processing.py
View file @
8c38ee70
from
contextlib
import
nullcontext
from
functools
import
partial
from
functools
import
partial
from
typing
import
cast
from
typing
import
cast
from
unittest.mock
import
MagicMock
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -526,6 +528,100 @@ def _rand_audio(
...
@@ -526,6 +528,100 @@ def _rand_audio(
return
rng
.
rand
(
audio_len
),
sr
return
rng
.
rand
(
audio_len
),
sr
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
(
"limit"
,
"num_supported"
,
"is_valid"
),
[(
0
,
0
,
True
),
(
0
,
1
,
True
),
(
1
,
0
,
False
),
(
1
,
1
,
True
),
(
1
,
2
,
True
),
(
2
,
1
,
False
),
(
2
,
2
,
True
)],
)
def
test_limit_mm_per_prompt_dummy
(
model_id
,
limit
,
num_supported
,
is_valid
):
limit_mm_per_prompt
=
{
"image"
:
limit
}
model_config
=
ModelConfig
(
model
=
model_id
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"half"
,
revision
=
None
,
limit_mm_per_prompt
=
limit_mm_per_prompt
,
)
model_cls
=
MULTIMODAL_REGISTRY
.
_get_model_cls
(
model_config
)
processor_factory
=
MULTIMODAL_REGISTRY
.
_processor_factories
[
model_cls
]
ctx
=
InputProcessingContext
(
model_config
,
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
),
)
processor
=
processor_factory
(
ctx
,
cache
=
None
)
mock_supported_mm_limits
=
MagicMock
(
return_value
=
{
"image"
:
num_supported
})
processor
.
get_supported_mm_limits
=
mock_supported_mm_limits
if
is_valid
:
exc_ctx
=
nullcontext
()
else
:
exc_ctx
=
pytest
.
raises
(
ValueError
,
match
=
"this model only supports"
)
with
exc_ctx
:
processor
.
_get_and_validate_dummy_mm_counts
()
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
(
"num_images"
,
"limit"
,
"is_valid"
),
[(
0
,
0
,
True
),
(
0
,
1
,
True
),
(
1
,
0
,
False
),
(
1
,
1
,
True
),
(
1
,
2
,
True
),
(
2
,
1
,
False
),
(
2
,
2
,
True
)],
)
def
test_limit_mm_per_prompt_apply
(
model_id
,
num_images
,
limit
,
is_valid
):
limit_mm_per_prompt
=
{
"image"
:
limit
}
model_config
=
ModelConfig
(
model
=
model_id
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
False
,
seed
=
0
,
dtype
=
"half"
,
revision
=
None
,
limit_mm_per_prompt
=
limit_mm_per_prompt
,
)
model_cls
=
MULTIMODAL_REGISTRY
.
_get_model_cls
(
model_config
)
processor_factory
=
MULTIMODAL_REGISTRY
.
_processor_factories
[
model_cls
]
ctx
=
InputProcessingContext
(
model_config
,
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
),
)
processor
=
processor_factory
(
ctx
,
cache
=
None
)
rng
=
np
.
random
.
RandomState
(
0
)
image
=
_rand_img
(
rng
,
min_wh
=
128
,
max_wh
=
256
)
if
num_images
==
0
:
mm_data
=
{}
elif
num_images
==
1
:
mm_data
=
{
"image"
:
image
}
else
:
mm_data
=
{
"image"
:
[
image
]
*
num_images
}
if
is_valid
:
exc_ctx
=
nullcontext
()
else
:
exc_ctx
=
pytest
.
raises
(
ValueError
,
match
=
f
"passed
{
num_images
}
image"
)
with
exc_ctx
:
processor
.
apply
(
"<image>"
*
num_images
,
mm_data
=
mm_data
,
hf_processor_mm_kwargs
=
{},
)
def
_test_processing_cache_correctness
(
def
_test_processing_cache_correctness
(
model_id
:
str
,
model_id
:
str
,
modalities
:
dict
[
str
,
bool
],
modalities
:
dict
[
str
,
bool
],
...
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
...
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
(
"facebook/chameleon-7b"
,
{
"image"
:
False
}),
(
"facebook/chameleon-7b"
,
{
"image"
:
False
}),
(
"adept/fuyu-8b"
,
{
"image"
:
False
}),
(
"adept/fuyu-8b"
,
{
"image"
:
False
}),
(
"llava-hf/llava-1.5-7b-hf"
,
{
"image"
:
True
}),
(
"llava-hf/llava-1.5-7b-hf"
,
{
"image"
:
True
}),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
{
"image"
:
True
}),
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
{
"image"
:
True
}),
(
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
{
"image"
:
True
}),
(
"mistral-community/pixtral-12b"
,
{
"image"
:
True
}),
(
"mistral-community/pixtral-12b"
,
{
"image"
:
True
}),
(
"Qwen/Qwen2-VL-2B-Instruct"
,
{
"image"
:
True
,
"video"
:
True
}),
(
"Qwen/Qwen2-VL-2B-Instruct"
,
{
"image"
:
True
,
"video"
:
True
}),
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py
View file @
8c38ee70
...
@@ -3,13 +3,11 @@ from typing import Optional
...
@@ -3,13 +3,11 @@ from typing import Optional
import
torch
import
torch
from
vllm.model_executor.models.llava
import
(
LlavaForConditionalGeneration
,
from
vllm.model_executor.models.llava
import
(
LlavaForConditionalGeneration
,
LlavaMultiModalProcessor
,
LlavaMultiModalProcessor
)
get_max_llava_image_tokens
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalProcessor
)
class
MyLlava
(
LlavaForConditionalGeneration
):
class
MyLlava
(
LlavaForConditionalGeneration
):
...
...
vllm/model_executor/models/clip.py
View file @
8c38ee70
...
@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -24,6 +24,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
assert
image_size
%
patch_size
==
0
...
@@ -149,6 +151,29 @@ def input_processor_for_clip(
...
@@ -149,6 +151,29 @@ def input_processor_for_clip(
multi_modal_placeholders
=
{
"image"
:
ranges
})
multi_modal_placeholders
=
{
"image"
:
ranges
})
class
CLIPEncoderInfo
(
VisionEncoderInfo
[
CLIPVisionConfig
]):
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
return
get_clip_image_feature_size
(
self
.
vision_config
)
def
get_max_image_tokens
(
self
)
->
int
:
return
get_max_clip_image_tokens
(
self
.
vision_config
)
def
get_num_patches
(
self
)
->
int
:
return
get_clip_patch_grid_length
(
image_size
=
self
.
vision_config
.
image_size
,
patch_size
=
self
.
vision_config
.
patch_size
,
)
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class
CLIPVisionEmbeddings
(
nn
.
Module
):
class
CLIPVisionEmbeddings
(
nn
.
Module
):
...
...
vllm/model_executor/models/fuyu.py
View file @
8c38ee70
...
@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -76,7 +76,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
return
ImageSize
(
width
=
target_size
[
"width"
],
return
ImageSize
(
width
=
target_size
[
"width"
],
height
=
target_size
[
"height"
])
height
=
target_size
[
"height"
])
def
_get_image_grid_size
(
def
_get_image_
feature_
grid_size
(
self
,
self
,
*
,
*
,
image_width
:
int
,
image_width
:
int
,
...
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -99,7 +99,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
target_width
,
target_height
=
self
.
_get_image_target_size
()
target_width
,
target_height
=
self
.
_get_image_target_size
()
max_ncols
,
max_nrows
=
self
.
_get_image_grid_size
(
max_ncols
,
max_nrows
=
self
.
_get_image_
feature_
grid_size
(
image_width
=
target_width
,
image_width
=
target_width
,
image_height
=
target_height
,
image_height
=
target_height
,
)
)
...
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -172,7 +172,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
image_size
=
images
.
get_image_size
(
item_idx
)
image_size
=
images
.
get_image_size
(
item_idx
)
ncols
,
nrows
=
self
.
_get_image_grid_size
(
ncols
,
nrows
=
self
.
_get_image_
feature_
grid_size
(
image_width
=
image_size
.
width
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
image_height
=
image_size
.
height
,
)
)
...
...
vllm/model_executor/models/llava.py
View file @
8c38ee70
from
abc
import
abstractmethod
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Set
,
from
typing
import
(
Final
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
Protocol
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
...
@@ -12,7 +13,6 @@ from transformers.models.pixtral import PixtralProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputContext
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -23,23 +23,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputsV2
,
MultiModalKwargs
,
MultiModalInputsV2
,
MultiModalKwargs
,
NestedTensors
)
NestedTensors
)
from
vllm.multimodal.parse
import
ImageProcessorItems
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalDataItems
,
ProcessorInputs
,
InputProcessingContext
,
PromptReplacement
,
MultiModalDataItems
,
ProcessingCache
,
ProcessorInputs
,
PromptReplacement
,
full_groupby_modality
)
full_groupby_modality
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
from
.clip
import
CLIPVisionModel
get_max_clip_image_tokens
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.pixtral
import
(
PixtralHFVisionModel
,
dummy_image_for_pixtral_hf
,
from
.pixtral
import
(
PixtralHFVisionModel
,
get_max_pixtral_hf_image_tokens
,
get_pixtral_hf_image_feature_grid_size
)
get_pixtral_hf_image_feature_size
)
from
.siglip
import
SiglipVisionModel
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
vision_encoder_info
class
LlavaImagePixelInputs
(
TypedDict
):
class
LlavaImagePixelInputs
(
TypedDict
):
...
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
...
@@ -94,39 +94,167 @@ class LlavaMultiModalProjector(nn.Module):
return
hidden_states
return
hidden_states
def
get_max_llava_image_tokens
(
ctx
:
InputContext
):
class
LlavaLikeConfig
(
Protocol
):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
:
Final
[
PretrainedConfig
]
vision_config
=
hf_config
.
vision_config
vision_feature_select_strategy
:
Final
[
str
]
vision_feature_layer
:
Final
[
Union
[
int
,
List
[
int
]]]
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
num_image_tokens
=
get_max_clip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_image_tokens
=
get_max_siglip_image_tokens
(
vision_config
)
elif
isinstance
(
vision_config
,
PixtralVisionConfig
):
num_image_tokens
=
get_max_pixtral_hf_image_tokens
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
class
BaseLlavaMultiModalProcessor
(
BaseMultiModalProcessor
):
if
strategy
==
"default"
:
return
num_image_tokens
-
1
elif
strategy
==
"full"
:
return
num_image_tokens
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def
__init__
(
self
,
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
super
().
__init__
(
ctx
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
)
vision_config
=
self
.
_get_hf_config
().
vision_config
self
.
_vision_encoder_info
=
vision_encoder_info
(
vision_config
)
class
LlavaMultiModalProcessor
(
BaseMultiModalProcessor
):
@
abstractmethod
def
_get_hf_config
(
self
)
->
LlavaLikeConfig
:
raise
NotImplementedError
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
def
_apply_feature_select_strategy
(
self
,
strategy
:
str
,
encoder_num_image_tokens
:
int
,
)
->
int
:
if
strategy
==
"default"
:
return
encoder_num_image_tokens
-
1
if
strategy
==
"full"
:
return
encoder_num_image_tokens
msg
=
f
"Unexpected feature select strategy:
{
strategy
!
r
}
"
raise
NotImplementedError
(
msg
)
def
_get_max_image_tokens
(
self
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
return
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_max_image_tokens
(),
)
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
def
get_mm_max_tokens_per_item
(
self
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
get_max_llava_image_tokens
(
self
.
ctx
)}
return
{
"image"
:
self
.
_get_max_image_tokens
()}
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_dummy_image_size
(
self
)
->
ImageSize
:
image_size
=
self
.
_vision_encoder_info
.
get_image_size
()
return
ImageSize
(
image_size
,
image_size
)
@
abstractmethod
def
_get_image_token
(
self
)
->
str
:
raise
NotImplementedError
def
_get_dummy_mm_inputs
(
self
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
image_token
=
self
.
_get_image_token
()
target_width
,
target_height
=
self
.
_get_dummy_image_size
()
mm_data
=
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
)
class
LlavaMultiModalProcessor
(
BaseLlavaMultiModalProcessor
):
def
_get_hf_config
(
self
)
->
LlavaConfig
:
return
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
def
_get_hf_processor
(
self
)
->
LlavaProcessor
:
return
self
.
ctx
.
get_hf_processor
(
LlavaProcessor
)
def
_get_image_token
(
self
)
->
str
:
return
self
.
_get_hf_processor
().
image_token
def
_get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
return
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
),
)
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
_get_hf_config
()
image_token_id
=
hf_config
.
image_token_index
def
_get_hf_processor
(
self
)
->
Union
[
LlavaProcessor
,
PixtralProcessor
]:
def
get_replacement
(
item_idx
:
int
):
return
self
.
ctx
.
get_hf_processor
((
LlavaProcessor
,
PixtralProcessor
))
images
=
mm_items
.
get_items
(
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
if
isinstance
(
images
,
ImageEmbeddingItems
):
num_image_tokens
=
images
.
get_feature_size
(
item_idx
)
else
:
image_size
=
images
.
get_image_size
(
item_idx
)
num_image_tokens
=
self
.
_get_num_image_tokens
(
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
image_token_id
]
*
num_image_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement
,
),
]
class
PixtralHFMultiModalProcessor
(
BaseLlavaMultiModalProcessor
):
def
_get_hf_config
(
self
)
->
LlavaConfig
:
return
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
def
_get_hf_processor
(
self
)
->
PixtralProcessor
:
return
self
.
ctx
.
get_hf_processor
(
PixtralProcessor
)
def
_get_image_token
(
self
)
->
str
:
return
self
.
_get_hf_processor
().
image_token
def
_call_hf_processor
(
def
_call_hf_processor
(
self
,
self
,
...
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -140,119 +268,82 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
mm_kwargs
=
mm_kwargs
,
mm_kwargs
=
mm_kwargs
,
)
)
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
images
=
mm_data
[
"images"
]
images
=
mm_data
[
"images"
]
assert
isinstance
(
images
,
list
)
assert
isinstance
(
images
,
list
)
if
isinstance
(
self
.
_get_hf_processor
(),
PixtralProcessor
):
# Original output: (1, num_images, C, H, W)
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
# New output: (num_images, C, H, W)
assert
(
isinstance
(
pixel_values
,
list
)
and
len
(
pixel_values
)
==
1
)
assert
(
isinstance
(
pixel_values
,
list
)
assert
(
isinstance
(
pixel_values
[
0
],
list
)
and
len
(
pixel_values
)
==
1
)
and
len
(
pixel_values
[
0
])
==
len
(
images
))
assert
(
isinstance
(
pixel_values
[
0
],
list
)
and
len
(
pixel_values
[
0
])
==
len
(
images
))
processed_outputs
[
"pixel_values"
]
=
pixel_values
[
0
]
processed_outputs
[
"pixel_values"
]
=
pixel_values
[
0
]
return
processed_outputs
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
def
_get_prompt_replacements
(
def
_get_prompt_replacements
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
hf_config
=
self
.
_
get_hf_config
()
image_token_id
=
hf_config
.
image_token_index
image_token_id
=
hf_config
.
image_token_index
processor
=
self
.
_get_hf_processor
()
processor
=
self
.
_get_hf_processor
()
if
isinstance
(
processor
,
PixtralProcessor
):
image_token
=
processor
.
image_token
image_token
=
processor
.
image_token
image_break_token
=
processor
.
image_break_token
image_break_token
=
processor
.
image_break_token
image_end_token
=
processor
.
image_end_token
image_end_token
=
processor
.
image_end_token
vision_config
=
hf_config
.
vision_config
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
def
get_replacement_pixtral
(
item_idx
:
int
):
vision_config
=
hf_config
.
vision_config
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
assert
isinstance
(
vision_config
,
PixtralVisionConfig
)
image_size
=
images
.
get_image_size
(
item_idx
)
(
num_width_tokens
,
num_height_tokens
,
)
=
get_pixtral_hf_image_feature_size
(
vision_config
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
tokens
=
([
image_token
]
*
nu
m_
w
id
th_tokens
+
def
get_replacement
(
ite
m_id
x
:
int
):
[
image_break_token
])
*
num_height_tokens
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
tokens
[
-
1
]
=
image_end_token
image_size
=
images
.
get_image_size
(
item_idx
)
return
""
.
join
(
tokens
)
ncols
,
nrows
=
get_pixtral_hf_image_feature_grid_size
(
vision_config
,
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
tokens
=
([
image_token
]
*
ncols
+
[
image_break_token
])
*
nrows
PromptReplacement
(
tokens
[
-
1
]
=
image_end_token
modality
=
"image"
,
target
=
[
image_token_id
],
replacement
=
get_replacement_pixtral
,
),
]
max_image_tokens
=
get_max_llava_image_tokens
(
self
.
ctx
)
return
""
.
join
(
tokens
)
return
[
return
[
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
[
image_token_id
],
target
=
[
image_token_id
],
replacement
=
[
image_token_id
]
*
max_image_tok
en
s
,
replacement
=
get_replacem
en
t
,
)
)
,
]
]
def
_get_dummy_mm_inputs
(
self
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
.
get
(
"image"
,
0
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
data
=
dummy_image_for_clip
(
vision_config
,
num_images
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
data
=
dummy_image_for_siglip
(
vision_config
,
num_images
)
elif
isinstance
(
vision_config
,
PixtralVisionConfig
):
data
=
dummy_image_for_pixtral_hf
(
vision_config
,
num_images
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
hf_processor
=
self
.
_get_hf_processor
()
def
_build_llava_or_pixtral_hf_processor
(
image_token
=
hf_processor
.
image_token
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseLlavaMultiModalProcessor
:
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
return
ProcessorInputs
(
if
isinstance
(
hf_config
.
vision_config
,
PixtralVisionConfig
):
prompt_text
=
image_token
*
num_images
,
return
PixtralHFMultiModalProcessor
(
mm_data
=
data
,
ctx
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
)
return
LlavaMultiModalProcessor
(
class
LlavaLikeConfig
(
Protocol
):
ctx
,
vision_config
:
PretrainedConfig
cache
=
cache
,
vision_feature_layer
:
Union
[
int
,
List
[
int
]]
enable_sanity_checks
=
enable_sanity_checks
,
)
def
_get_num_hidden_layers
(
hf_config
:
LlavaLikeConfig
)
->
int
:
def
_get_num_hidden_layers
(
hf_config
:
LlavaLikeConfig
)
->
int
:
...
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
...
@@ -330,7 +421,7 @@ def init_vision_tower_for_llava(
raise
NotImplementedError
(
msg
)
raise
NotImplementedError
(
msg
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaMultiModalP
rocessor
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
_build_llava_or_pixtral_hf_p
rocessor
)
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
LlavaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
# BitandBytes specific attributes
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
...
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
...
@@ -596,7 +687,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
->
MultiModalInputsV2
:
)
->
MultiModalInputsV2
:
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
hf_config
=
self
.
ctx
.
get_hf_config
(
LlavaConfig
)
image_token_id
=
hf_config
.
image_token_index
image_token_id
=
hf_config
.
image_token_index
max_image_tokens
=
get_max_llava_image_tokens
(
self
.
ctx
)
# Assume that it doesn't depend on the image size
num_image_tokens
=
self
.
_get_num_image_tokens
(
image_width
=-
1
,
image_height
=-
1
,
)
result
=
super
().
apply
(
prompt_text
,
mm_data
,
hf_processor_mm_kwargs
)
result
=
super
().
apply
(
prompt_text
,
mm_data
,
hf_processor_mm_kwargs
)
...
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
...
@@ -609,14 +705,14 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def
get_replacement_mantis
(
item_idx
:
int
):
def
get_replacement_mantis
(
item_idx
:
int
):
return
""
.
join
([
return
""
.
join
([
f
"(image
{
item_idx
+
1
}
: <Image>"
,
# 7 tokens
f
"(image
{
item_idx
+
1
}
: <Image>"
,
# 7 tokens
"<image>"
*
max
_image_tokens
,
"<image>"
*
num
_image_tokens
,
"</Image>)"
,
# 3 tokens
"</Image>)"
,
# 3 tokens
])
])
mantis_repls
=
self
.
_bind_prompt_replacements
([
mantis_repls
=
self
.
_bind_prompt_replacements
([
PromptReplacement
(
PromptReplacement
(
modality
=
"image"
,
modality
=
"image"
,
target
=
[
image_token_id
]
*
max
_image_tokens
,
target
=
[
image_token_id
]
*
num
_image_tokens
,
replacement
=
get_replacement_mantis
,
replacement
=
get_replacement_mantis
,
)
)
])
])
...
...
vllm/model_executor/models/llava_next.py
View file @
8c38ee70
...
@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
...
@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
BatchFeature
,
LlavaNextConfig
,
LlavaNextProcessor
from
transformers
import
CLIPVisionConfig
,
LlavaNextConfig
,
SiglipVisionConfig
from
transformers.models.llava_next.modeling_llava_next
import
(
from
transformers.models.llava_next.modeling_llava_next
import
(
get_anyres_image_grid_shape
,
unpad_image
)
get_anyres_image_grid_shape
,
unpad_image
)
from
typing_extensions
import
NotRequired
from
typing_extensions
import
NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
NestedTensors
from
vllm.multimodal.parse
import
ImageSize
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
from
.clip
import
CLIPVisionModel
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
get_clip_patch_grid_length
,
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.llava
import
LlavaMultiModalProjector
,
init_vision_tower_for_llava
from
.llava
import
(
LlavaMultiModalProcessor
,
LlavaMultiModalProjector
,
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
init_vision_tower_for_llava
)
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
from
.siglip
import
SiglipVisionModel
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
)
init_vllm_registered_model
,
maybe_prefix
)
...
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
...
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageEmbeddingInputs
]
LlavaNextImageEmbeddingInputs
]
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
class
LlavaNextMultiModalProcessor
(
LlavaMultiModalProcessor
):
def
_get_llava_next_num_unpadded_features
(
original_height
:
int
,
original_width
:
int
,
npatches
:
int
,
num_patch_height
:
int
,
num_patch_width
:
int
,
)
->
Tuple
[
int
,
int
]:
current_height
=
npatches
*
num_patch_height
current_width
=
npatches
*
num_patch_width
original_aspect_ratio
=
original_width
/
original_height
current_aspect_ratio
=
current_width
/
current_height
if
original_aspect_ratio
>
current_aspect_ratio
:
scale_factor
=
current_width
/
original_width
new_height
=
int
(
original_height
*
scale_factor
)
padding
=
(
current_height
-
new_height
)
//
2
current_height
-=
2
*
padding
else
:
scale_factor
=
current_height
/
original_height
new_width
=
int
(
original_width
*
scale_factor
)
padding
=
(
current_width
-
new_width
)
//
2
current_width
-=
2
*
padding
unpadded_features
=
current_height
*
current_width
newline_features
=
current_height
return
(
unpadded_features
,
newline_features
)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def
get_llava_next_image_feature_size
(
hf_config
:
LlavaNextConfig
,
*
,
input_height
:
int
,
input_width
:
int
,
)
->
int
:
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
num_patches
=
get_clip_patch_grid_length
(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
get_clip_image_feature_size
(
vision_config
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
num_patches
=
get_siglip_patch_grid_length
(
image_size
=
vision_config
.
image_size
,
patch_size
=
vision_config
.
patch_size
,
)
base_feature_size
=
get_siglip_image_feature_size
(
vision_config
)
else
:
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
strategy
=
hf_config
.
vision_feature_select_strategy
if
strategy
==
"default"
:
base_feature_size
-=
1
elif
strategy
==
"full"
:
pass
else
:
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
num_patch_height
,
num_patch_width
=
get_anyres_image_grid_shape
(
def
_get_hf_config
(
self
)
->
LlavaNextConfig
:
image_size
=
(
input_height
,
input_width
),
return
self
.
ctx
.
get_hf_config
(
LlavaNextConfig
)
grid_pinpoints
=
hf_config
.
image_grid_pinpoints
,
patch_size
=
vision_config
.
image_size
,
def
_get_hf_processor
(
self
)
->
LlavaNextProcessor
:
)
return
self
.
ctx
.
get_hf_processor
(
LlavaNextProcessor
)
(
unpadded_feature_size
,
newline_feature_size
,
)
=
_get_llava_next_num_unpadded_features
(
input_height
,
input_width
,
num_patches
,
num_patch_height
,
num_patch_width
)
return
unpadded_feature_size
+
newline_feature_size
+
base_feature_size
def
get_max_llava_next_image_tokens
(
ctx
:
InputContext
):
"""Compute the max feature size for all possible image grid pinpoints."""
return
_get_pinpoint_with_largest_features
(
ctx
)[
0
]
def
_get_pinpoint_with_largest_features
(
ctx
:
InputContext
)
->
Tuple
[
int
,
Tuple
[
int
,
int
]]:
"""Get the grid pinpoint with the largest features & its feature size."""
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
largest_feature_size
=
0
largest_feature_pinpoint
=
None
for
(
height
,
width
)
in
hf_config
.
image_grid_pinpoints
:
feat_size
=
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
height
,
input_width
=
width
,
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
largest_feature_pinpoint
=
(
height
,
width
)
if
not
largest_feature_size
or
largest_feature_pinpoint
is
None
:
raise
ValueError
(
"Cannot have a largest feature size of 0!"
)
return
largest_feature_size
,
largest_feature_pinpoint
def
dummy_data_for_llava_next
(
ctx
:
InputContext
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
]):
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
vision_config
=
hf_config
.
vision_config
num_images
=
mm_counts
[
"image"
]
image_feature_size
,
pinpoint
=
_get_pinpoint_with_largest_features
(
ctx
)
max_feat_height
,
max_feat_width
=
pinpoint
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
seq_data
,
ranges
=
dummy_seq_data_for_clip
(
vision_config
,
seq_len
,
num_images
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
mm_data
=
dummy_image_for_clip
(
def
_get_image_token
(
self
)
->
str
:
vision_config
,
return
self
.
_get_hf_processor
().
image_token
num_images
,
image_width_override
=
max_feat_width
,
def
_get_mm_fields_config
(
image_height_override
=
max_feat_height
,
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_sizes
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
def
_get_max_image_tokens
(
self
)
->
int
:
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
largest_feature_size
,
_
=
self
.
_get_pinpoint_with_most_features
()
seq_data
,
ranges
=
dummy_seq_data_for_siglip
(
return
largest_feature_size
vision_config
,
seq_len
,
def
_get_dummy_image_size
(
self
)
->
ImageSize
:
num_images
,
_
,
pinpoint
=
self
.
_get_pinpoint_with_most_features
()
image_token_id
=
hf_config
.
image_token_index
,
return
pinpoint
image_feature_size_override
=
image_feature_size
,
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def
_get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
hf_config
=
self
.
_get_hf_config
()
base_feature_size
=
self
.
_apply_feature_select_strategy
(
hf_config
.
vision_feature_select_strategy
,
self
.
_vision_encoder_info
.
get_num_image_tokens
(
image_width
=
image_width
,
image_height
=
image_height
,
),
)
)
num_patches
=
self
.
_vision_encoder_info
.
get_num_patches
()
mm_data
=
dummy_image_for_siglip
(
num_patch_height
,
num_patch_width
=
get_anyres_image_grid_shape
(
vision_config
,
image_size
=
(
image_height
,
image_width
),
num_images
,
grid_pinpoints
=
hf_config
.
image_grid_pinpoints
,
image_width_override
=
max_feat_width
,
patch_size
=
self
.
_vision_encoder_info
.
get_image_size
(),
image_height_override
=
max_feat_height
,
)
)
return
DummyData
(
seq_data
,
mm_data
,
ranges
)
(
unpadded_feature_size
,
newline_feature_size
,
)
=
self
.
_get_num_unpadded_features
(
original_height
=
image_height
,
original_width
=
image_width
,
npatches
=
num_patches
,
num_patch_height
=
num_patch_height
,
num_patch_width
=
num_patch_width
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
return
unpadded_feature_size
+
newline_feature_size
+
base_feature_size
raise
NotImplementedError
(
msg
)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def
_get_num_unpadded_features
(
self
,
*
,
original_height
:
int
,
original_width
:
int
,
npatches
:
int
,
num_patch_height
:
int
,
num_patch_width
:
int
,
)
->
tuple
[
int
,
int
]:
current_height
=
npatches
*
num_patch_height
current_width
=
npatches
*
num_patch_width
original_aspect_ratio
=
original_width
/
original_height
current_aspect_ratio
=
current_width
/
current_height
if
original_aspect_ratio
>
current_aspect_ratio
:
scale_factor
=
current_width
/
original_width
new_height
=
int
(
original_height
*
scale_factor
)
padding
=
(
current_height
-
new_height
)
//
2
current_height
-=
2
*
padding
else
:
scale_factor
=
current_height
/
original_height
new_width
=
int
(
original_width
*
scale_factor
)
padding
=
(
current_width
-
new_width
)
//
2
current_width
-=
2
*
padding
def
input_processor_for_llava_next
(
ctx
:
InputContext
,
unpadded_features
=
current_height
*
current_width
inputs
:
DecoderOnlyInputs
):
newline_features
=
current_height
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
return
(
unpadded_features
,
newline_features
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
inputs
model_config
=
ctx
.
model_config
def
_get_pinpoint_with_most_features
(
self
)
->
tuple
[
int
,
ImageSize
]:
hf_config
=
ctx
.
get_hf_config
(
LlavaNextConfig
)
"""
vision_config
=
hf_config
.
vision_config
Get the grid pinpoint with the most features and
the corresponding feature size.
"""
hf_config
=
self
.
_get_hf_config
()
image_data
=
multi_modal_data
[
"image"
]
largest_feature_size
,
largest_feature_pinpoint
=
0
,
None
if
isinstance
(
image_data
,
Image
.
Image
):
for
(
height
,
width
)
in
hf_config
.
image_grid_pinpoints
:
width
,
height
=
image_data
.
size
feat_size
=
self
.
_get_num_image_tokens
(
image_width
=
width
,
image_height
=
height
)
if
feat_size
>
largest_feature_size
:
largest_feature_size
=
feat_size
largest_feature_pinpoint
=
ImageSize
(
width
=
width
,
height
=
height
)
image_feature_size
=
get_llava_next_image_feature_size
(
if
largest_feature_size
==
0
or
largest_feature_pinpoint
is
None
:
hf_config
,
raise
ValueError
(
"Cannot have a largest feature size of 0!"
)
input_height
=
height
,
input_width
=
width
,
)
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
img
.
height
,
input_width
=
img
.
width
)
for
img
in
image_data
]
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
[
item
.
shape
[
1
]
for
item
in
image_data
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
vision_config
=
hf_config
.
vision_config
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
input_processor_for_clip
(
model_config
,
vision_config
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
elif
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
input_processor_for_siglip
(
model_config
,
vision_config
,
inputs
,
image_token_id
=
hf_config
.
image_token_index
,
image_feature_size_override
=
image_feature_size
,
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
return
largest_feature_size
,
largest_feature_pinpoint
raise
NotImplementedError
(
msg
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_processor
(
LlavaNextMultiModalProcessor
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_llava_next_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_llava_next
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_llava_next
)
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
class
LlavaNextForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
SupportsPP
):
...
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def
_process_image_pixels
(
def
_process_image_pixels
(
self
,
self
,
inputs
:
LlavaNextImagePixelInputs
,
inputs
:
LlavaNextImagePixelInputs
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...
]]:
assert
self
.
vision_tower
is
not
None
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
pixel_values
=
inputs
[
"data"
]
...
...
vllm/model_executor/models/phi3v.py
View file @
8c38ee70
...
@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputsV2
,
MultiModalKwargs
,
MultiModalInputsV2
,
MultiModalKwargs
,
NestedTensors
,
PlaceholderRange
)
NestedTensors
,
PlaceholderRange
)
from
vllm.multimodal.parse
import
ImageProcessorItems
from
vllm.multimodal.parse
import
ImageEmbeddingItems
,
ImageProcessorItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalDataItems
,
ProcessorInputs
,
MultiModalDataItems
,
ProcessorInputs
,
PromptReplacement
,
PromptReplacement
,
...
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -388,15 +388,19 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
assert
isinstance
(
bos_token_id
,
int
)
assert
isinstance
(
bos_token_id
,
int
)
def
get_replacement_phi3v
(
item_idx
:
int
):
def
get_replacement_phi3v
(
item_idx
:
int
):
images
=
mm_items
.
get_items
(
"image"
,
ImageProcessorItems
)
images
=
mm_items
.
get_items
(
image_size
=
images
.
get_image_size
(
item_idx
)
"image"
,
(
ImageEmbeddingItems
,
ImageProcessorItems
))
num_tokens
=
self
.
_get_num_image_tokens
(
if
isinstance
(
images
,
ImageEmbeddingItems
):
image_width
=
image_size
.
width
,
num_image_tokens
=
images
.
get_feature_size
(
item_idx
)
image_height
=
image_size
.
height
,
else
:
)
image_size
=
images
.
get_image_size
(
item_idx
)
num_image_tokens
=
self
.
_get_num_image_tokens
(
return
[
_IMAGE_TOKEN_ID
]
*
num_tokens
+
[
bos_token_id
]
image_width
=
image_size
.
width
,
image_height
=
image_size
.
height
,
)
return
[
_IMAGE_TOKEN_ID
]
*
num_image_tokens
+
[
bos_token_id
]
num_images
=
mm_items
.
get_count
(
"image"
,
strict
=
False
)
num_images
=
mm_items
.
get_count
(
"image"
,
strict
=
False
)
...
...
vllm/model_executor/models/pixtral.py
View file @
8c38ee70
...
@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
...
@@ -38,6 +38,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
init_vllm_registered_model
,
maybe_prefix
,
from
.utils
import
(
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
try
:
try
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
...
@@ -697,10 +698,18 @@ def get_pixtral_hf_patch_grid_length(*, image_size: int,
return
image_size
//
patch_size
return
image_size
//
patch_size
def
get_pixtral_hf_num_patches
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_pixtral_hf_image_feature_size
(
grid_length
=
get_pixtral_hf_patch_grid_length
(
image_size
=
image_size
,
*
,
patch_size
=
patch_size
)
image_size
:
int
,
return
grid_length
*
grid_length
patch_size
:
int
,
)
->
int
:
grid_length
=
get_pixtral_hf_patch_grid_length
(
image_size
=
image_size
,
patch_size
=
patch_size
,
)
# Consider the image_break_token
return
(
grid_length
+
1
)
*
grid_length
def
get_max_pixtral_hf_image_tokens
(
hf_config
:
PixtralVisionConfig
)
->
int
:
def
get_max_pixtral_hf_image_tokens
(
hf_config
:
PixtralVisionConfig
)
->
int
:
...
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
...
@@ -730,13 +739,16 @@ def dummy_image_for_pixtral_hf(
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
return
{
"image"
:
image
if
num_images
==
1
else
[
image
]
*
num_images
}
def
get_pixtral_hf_image_feature_size
(
hf_config
:
PixtralVisionConfig
,
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
image_width
:
int
,
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180
image_height
:
int
)
->
Tuple
[
int
,
int
]:
def
get_pixtral_hf_image_feature_grid_size
(
# Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
hf_config
:
PixtralVisionConfig
,
# https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
*
,
max_width
,
max_height
=
hf_config
.
image_size
,
hf_config
.
image_size
image_width
:
int
,
patch_width
,
patch_height
=
hf_config
.
patch_size
,
hf_config
.
patch_size
image_height
:
int
,
)
->
tuple
[
int
,
int
]:
max_width
=
max_height
=
hf_config
.
image_size
patch_width
=
patch_height
=
hf_config
.
patch_size
ratio
=
max
(
image_width
/
max_width
,
image_height
/
max_height
)
ratio
=
max
(
image_width
/
max_width
,
image_height
/
max_height
)
...
@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
...
@@ -744,12 +756,38 @@ def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
image_width
=
int
(
math
.
ceil
(
image_width
/
ratio
))
image_width
=
int
(
math
.
ceil
(
image_width
/
ratio
))
image_height
=
int
(
math
.
ceil
(
image_height
/
ratio
))
image_height
=
int
(
math
.
ceil
(
image_height
/
ratio
))
n
um_height_tokens
,
num_width_token
s
=
_get_pixtral_hf_num_image_tokens
(
n
rows
,
ncol
s
=
_get_pixtral_hf_num_image_tokens
(
(
image_height
,
image_width
),
(
image_height
,
image_width
),
(
patch_height
,
patch_width
),
(
patch_height
,
patch_width
),
)
)
# type: ignore
return
ncols
,
nrows
class
PixtralHFEncoderInfo
(
VisionEncoderInfo
[
PixtralVisionConfig
]):
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
return
get_pixtral_hf_image_feature_size
(
image_size
=
self
.
vision_config
.
image_size
,
patch_size
=
self
.
get_image_size
(),
)
def
get_max_image_tokens
(
self
)
->
int
:
return
get_max_pixtral_hf_image_tokens
(
self
.
vision_config
)
def
get_num_patches
(
self
)
->
int
:
return
get_pixtral_hf_patch_grid_length
(
image_size
=
self
.
vision_config
.
image_size
,
patch_size
=
self
.
vision_config
.
patch_size
,
)
return
num_width_tokens
,
num_height_tokens
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
class
PixtralHFMLP
(
nn
.
Module
):
class
PixtralHFMLP
(
nn
.
Module
):
...
...
vllm/model_executor/models/siglip.py
View file @
8c38ee70
...
@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -28,6 +28,8 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
resolve_visual_encoder_outputs
)
resolve_visual_encoder_outputs
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
.vision
import
VisionEncoderInfo
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
# Since interpolation is applied, the image size need not be divisible
...
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
...
@@ -156,6 +158,29 @@ def input_processor_for_siglip(
multi_modal_placeholders
=
{
"image"
:
ranges
})
multi_modal_placeholders
=
{
"image"
:
ranges
})
class
SiglipEncoderInfo
(
VisionEncoderInfo
[
SiglipVisionConfig
]):
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
return
get_siglip_image_feature_size
(
self
.
vision_config
)
def
get_max_image_tokens
(
self
)
->
int
:
return
get_max_siglip_image_tokens
(
self
.
vision_config
)
def
get_num_patches
(
self
)
->
int
:
return
get_siglip_patch_grid_length
(
image_size
=
self
.
vision_config
.
image_size
,
patch_size
=
self
.
vision_config
.
patch_size
,
)
def
get_image_size
(
self
)
->
int
:
return
self
.
vision_config
.
image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class
SiglipVisionEmbeddings
(
nn
.
Module
):
class
SiglipVisionEmbeddings
(
nn
.
Module
):
...
...
vllm/model_executor/models/utils.py
View file @
8c38ee70
...
@@ -373,7 +373,7 @@ def embed_multimodal(
...
@@ -373,7 +373,7 @@ def embed_multimodal(
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
multimodal_token_id
:
int
,
multimodal_token_id
:
int
,
get_text_embeds
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
get_text_embeds
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
],
multimodal_embeds
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
,
multimodal_embeds
:
Nested
Tensor
s
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Embed token IDs and multimodal inputs and combine their embeddings.
Embed token IDs and multimodal inputs and combine their embeddings.
...
...
vllm/model_executor/models/vision.py
0 → 100644
View file @
8c38ee70
from
abc
import
ABC
,
abstractmethod
from
typing
import
Generic
,
TypeVar
from
transformers
import
PretrainedConfig
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
class
VisionEncoderInfo
(
ABC
,
Generic
[
_C
]):
def
__init__
(
self
,
vision_config
:
_C
)
->
None
:
super
().
__init__
()
self
.
vision_config
=
vision_config
@
abstractmethod
def
get_num_image_tokens
(
self
,
*
,
image_width
:
int
,
image_height
:
int
,
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
get_max_image_tokens
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
get_num_patches
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
get_image_size
(
self
)
->
int
:
raise
NotImplementedError
def
vision_encoder_info
(
vision_config
:
PretrainedConfig
)
->
VisionEncoderInfo
:
# Avoid circular imports
from
.clip
import
CLIPEncoderInfo
,
CLIPVisionConfig
from
.pixtral
import
PixtralHFEncoderInfo
,
PixtralVisionConfig
from
.siglip
import
SiglipEncoderInfo
,
SiglipVisionConfig
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
CLIPEncoderInfo
(
vision_config
)
if
isinstance
(
vision_config
,
PixtralVisionConfig
):
return
PixtralHFEncoderInfo
(
vision_config
)
if
isinstance
(
vision_config
,
SiglipVisionConfig
):
return
SiglipEncoderInfo
(
vision_config
)
msg
=
f
"Unsupported vision config:
{
type
(
vision_config
)
}
"
raise
NotImplementedError
(
msg
)
vllm/multimodal/parse.py
View file @
8c38ee70
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
from
collections
import
UserDict
from
collections.abc
import
Callable
,
Iterator
,
Mapping
,
Sequence
from
collections.abc
import
Callable
,
Iterator
,
Mapping
,
Sequence
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
NamedTuple
,
Optional
,
TypeVar
from
typing
import
(
TYPE_CHECKING
,
Any
,
Generic
,
NamedTuple
,
Optional
,
TypeVar
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
...
@@ -87,7 +88,7 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def
get_count
(
self
)
->
int
:
def
get_count
(
self
)
->
int
:
return
len
(
self
.
data
)
return
len
(
self
.
data
)
def
get
(
self
,
index
:
int
)
->
object
:
def
get
(
self
,
index
:
int
)
->
torch
.
Tensor
:
return
self
.
data
[
index
]
return
self
.
data
[
index
]
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
def
get_processor_data
(
self
)
->
Mapping
[
str
,
object
]:
...
@@ -96,6 +97,9 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
...
@@ -96,6 +97,9 @@ class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
def
get_passthrough_data
(
self
)
->
Mapping
[
str
,
object
]:
return
{
f
"
{
self
.
modality
}
_embeds"
:
self
.
data
}
return
{
f
"
{
self
.
modality
}
_embeds"
:
self
.
data
}
def
get_feature_size
(
self
,
item_idx
:
int
)
->
int
:
return
len
(
self
.
get
(
item_idx
))
class
AudioProcessorItems
(
ProcessorBatchItems
[
HfAudioItem
]):
class
AudioProcessorItems
(
ProcessorBatchItems
[
HfAudioItem
]):
...
@@ -182,7 +186,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
...
@@ -182,7 +186,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
def
get_items
(
def
get_items
(
self
,
self
,
modality
:
str
,
modality
:
str
,
typ
:
type
[
_D
],
typ
:
Union
[
type
[
_D
],
tuple
[
type
[
_D
],
...]],
)
->
_D
:
)
->
_D
:
"""
"""
Get the data items belonging to a modality,
Get the data items belonging to a modality,
...
@@ -199,7 +203,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
...
@@ -199,7 +203,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
f
"Expected type:
{
typ
}
, but "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
items
)
}
"
)
f
"found type:
{
type
(
items
)
}
"
)
return
items
return
items
# type: ignore[return-value]
ModalityDataParser
:
TypeAlias
=
Callable
[[
ModalityData
[
Any
]],
ModalityDataParser
:
TypeAlias
=
Callable
[[
ModalityData
[
Any
]],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment