Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82de9b9d
Unverified
Commit
82de9b9d
authored
Aug 01, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 31, 2025
Browse files
[Misc] Automatically resolve HF processor init kwargs (#22005)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
ad57f23f
Changes
40
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
194 additions
and
300 deletions
+194
-300
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+19
-19
tests/lora/test_qwen2vl.py
tests/lora/test_qwen2vl.py
+0
-6
tests/models/multimodal/generation/test_common.py
tests/models/multimodal/generation/test_common.py
+26
-1
tests/models/multimodal/generation/vlm_utils/model_utils.py
tests/models/multimodal/generation/vlm_utils/model_utils.py
+12
-0
tests/models/multimodal/processing/test_transformers.py
tests/models/multimodal/processing/test_transformers.py
+1
-1
tests/models/registry.py
tests/models/registry.py
+1
-2
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+70
-37
vllm/config.py
vllm/config.py
+11
-1
vllm/inputs/registry.py
vllm/inputs/registry.py
+8
-9
vllm/model_executor/models/aya_vision.py
vllm/model_executor/models/aya_vision.py
+3
-9
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+18
-18
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+0
-6
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+2
-2
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+4
-4
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+1
-15
vllm/model_executor/models/hyperclovax_vision.py
vllm/model_executor/models/hyperclovax_vision.py
+1
-19
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+1
-9
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-25
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+2
-82
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+11
-35
No files found.
examples/offline_inference/vision_language.py
View file @
82de9b9d
...
@@ -449,25 +449,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
...
@@ -449,25 +449,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
)
)
# omni-research/Tarsier-7b
def
run_tarsier
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
model_name
=
"omni-research/Tarsier-7b"
engine_args
=
EngineArgs
(
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
modality
:
1
},
)
prompts
=
[(
f
"USER: <image>
\n
{
question
}
ASSISTANT:"
)
for
question
in
questions
]
return
ModelRequestData
(
engine_args
=
engine_args
,
prompts
=
prompts
,
)
# Intern-S1
# Intern-S1
def
run_interns1
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
def
run_interns1
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
model_name
=
"internlm/Intern-S1"
model_name
=
"internlm/Intern-S1"
...
@@ -1293,6 +1274,25 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
...
@@ -1293,6 +1274,25 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
)
)
# omni-research/Tarsier-7b
def
run_tarsier
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
assert
modality
==
"image"
model_name
=
"omni-research/Tarsier-7b"
engine_args
=
EngineArgs
(
model
=
model_name
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
limit_mm_per_prompt
=
{
modality
:
1
},
)
prompts
=
[(
f
"USER: <image>
\n
{
question
}
ASSISTANT:"
)
for
question
in
questions
]
return
ModelRequestData
(
engine_args
=
engine_args
,
prompts
=
prompts
,
)
def
run_tarsier2
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
def
run_tarsier2
(
questions
:
list
[
str
],
modality
:
str
)
->
ModelRequestData
:
model_name
=
"omni-research/Tarsier2-Recap-7b"
model_name
=
"omni-research/Tarsier2-Recap-7b"
...
...
tests/lora/test_qwen2vl.py
View file @
82de9b9d
...
@@ -4,8 +4,6 @@ from dataclasses import dataclass
...
@@ -4,8 +4,6 @@ from dataclasses import dataclass
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
from
packaging.version
import
Version
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
import
vllm
import
vllm
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
...
@@ -185,10 +183,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
...
@@ -185,10 +183,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files):
current_platform
.
is_rocm
(),
current_platform
.
is_rocm
(),
reason
=
"Qwen2.5-VL dependency xformers incompatible with ROCm"
,
reason
=
"Qwen2.5-VL dependency xformers incompatible with ROCm"
,
)
)
@
pytest
.
mark
.
skipif
(
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.49.0"
),
reason
=
"Qwen2.5-VL require transformers version no lower than 4.49.0"
,
)
def
test_qwen25vl_lora
(
qwen25vl_lora_files
):
def
test_qwen25vl_lora
(
qwen25vl_lora_files
):
"""Test Qwen 2.5 VL model with LoRA"""
"""Test Qwen 2.5 VL model with LoRA"""
config
=
TestConfig
(
model_path
=
QWEN25VL_MODEL_PATH
,
config
=
TestConfig
(
model_path
=
QWEN25VL_MODEL_PATH
,
...
...
tests/models/multimodal/generation/test_common.py
View file @
82de9b9d
...
@@ -702,13 +702,38 @@ VLM_TEST_SETTINGS = {
...
@@ -702,13 +702,38 @@ VLM_TEST_SETTINGS = {
"smolvlm"
:
VLMTestInfo
(
"smolvlm"
:
VLMTestInfo
(
models
=
[
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
],
models
=
[
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
f
"<|im_start|>User:
{
img_prompt
}
<end_of_utterance>
\n
Assistant:"
,
# noqa: E501
prompt_formatter
=
lambda
img_prompt
:
f
"<|im_start|>User:
{
img_prompt
}
<end_of_utterance>
\n
Assistant:"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"<image>"
,
img_idx_to_prompt
=
lambda
idx
:
"<image>"
,
max_model_len
=
8192
,
max_model_len
=
8192
,
max_num_seqs
=
2
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForImageTextToText
,
auto_cls
=
AutoModelForImageTextToText
,
hf_output_post_proc
=
model_utils
.
smolvlm_trunc_hf_output
,
hf_output_post_proc
=
model_utils
.
smolvlm_trunc_hf_output
,
),
),
"tarsier"
:
VLMTestInfo
(
models
=
[
"omni-research/Tarsier-7b"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
),
prompt_formatter
=
lambda
img_prompt
:
f
"USER:
{
img_prompt
}
ASSISTANT:"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForImageTextToText
,
patch_hf_runner
=
model_utils
.
tarsier_patch_hf_runner
,
),
"tarsier2"
:
VLMTestInfo
(
models
=
[
"omni-research/Tarsier2-Recap-7b"
],
test_type
=
(
VLMTestType
.
IMAGE
,
VLMTestType
.
MULTI_IMAGE
,
VLMTestType
.
VIDEO
,
),
prompt_formatter
=
lambda
img_prompt
:
f
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
<|im_start|>user
\n
{
img_prompt
}
<|im_end|>
\n
<|im_start|>assistant
\n
"
,
# noqa: E501
img_idx_to_prompt
=
lambda
idx
:
"<|vision_start|><|image_pad|><|vision_end|>"
,
# noqa: E501
video_idx_to_prompt
=
lambda
idx
:
"<|vision_start|><|video_pad|><|vision_end|>"
,
# noqa: E501
max_model_len
=
4096
,
max_num_seqs
=
2
,
auto_cls
=
AutoModelForImageTextToText
,
image_size_factors
=
[(),
(
0.25
,),
(
0.25
,
0.25
,
0.25
),
(
0.25
,
0.2
,
0.15
)],
marks
=
[
pytest
.
mark
.
skip
(
"Model initialization hangs"
)],
),
### Tensor parallel / multi-gpu broadcast tests
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast"
:
VLMTestInfo
(
"chameleon-broadcast"
:
VLMTestInfo
(
models
=
[
"facebook/chameleon-7b"
],
models
=
[
"facebook/chameleon-7b"
],
...
...
tests/models/multimodal/generation/vlm_utils/model_utils.py
View file @
82de9b9d
...
@@ -818,3 +818,15 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
...
@@ -818,3 +818,15 @@ def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
thinker
.
get_output_embeddings
=
lambda
:
thinker
.
lm_head
thinker
.
get_output_embeddings
=
lambda
:
thinker
.
lm_head
hf_model
.
model
=
thinker
hf_model
.
model
=
thinker
return
hf_model
return
hf_model
def
tarsier_patch_hf_runner
(
hf_model
:
HfRunner
)
->
HfRunner
:
from
vllm.model_executor.models.tarsier
import
get_vision_encoder_info
vision_encoder_info
=
get_vision_encoder_info
(
hf_model
.
config
)
hf_processor
=
hf_model
.
processor
if
hf_processor
.
patch_size
is
None
:
hf_processor
.
patch_size
=
vision_encoder_info
.
get_patch_size
()
return
hf_model
tests/models/multimodal/processing/test_transformers.py
View file @
82de9b9d
...
@@ -16,7 +16,7 @@ def test_multimodal_processor(model_id):
...
@@ -16,7 +16,7 @@ def test_multimodal_processor(model_id):
model_impl
=
"transformers"
,
model_impl
=
"transformers"
,
)
)
mm_processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
,
)
mm_processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
)
image_pil
=
ImageAsset
(
'cherry_blossom'
).
pil_image
image_pil
=
ImageAsset
(
'cherry_blossom'
).
pil_image
mm_data
=
{
"image"
:
image_pil
}
mm_data
=
{
"image"
:
image_pil
}
...
...
tests/models/registry.py
View file @
82de9b9d
...
@@ -465,8 +465,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -465,8 +465,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online
=
False
),
is_available_online
=
False
),
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
"UltravoxModel"
:
_HfExamplesInfo
(
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
,
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"TarsierForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier-7b"
,
# noqa: E501
"TarsierForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier-7b"
),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"TarsierForConditionalGeneration"
]}),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier2-Recap-7b"
,
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
_HfExamplesInfo
(
"omni-research/Tarsier2-Recap-7b"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"Tarsier2ForConditionalGeneration"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"Tarsier2ForConditionalGeneration"
]}),
# noqa: E501
"VoxtralForConditionalGeneration"
:
_HfExamplesInfo
(
"VoxtralForConditionalGeneration"
:
_HfExamplesInfo
(
...
...
tests/multimodal/test_processing.py
View file @
82de9b9d
...
@@ -2,16 +2,15 @@
...
@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
nullcontext
from
contextlib
import
nullcontext
from
types
import
MethodType
from
typing
import
Optional
,
cast
from
typing
import
cast
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
transformers
import
ProcessorMixin
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputProcessingContext
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldElem
,
MultiModalKwargs
,
from
vllm.multimodal.inputs
import
(
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalKwargsItem
,
...
@@ -1013,57 +1012,91 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
...
@@ -1013,57 +1012,91 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)
)
class
_
Processor
Proxy
:
class
Dummy
Processor
:
def
__init__
(
self
,
processor
:
ProcessorMixin
)
->
None
:
def
__init__
(
self
,
a
:
int
=
0
,
b
:
int
=
0
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
__processor
=
processor
self
.
a
=
a
self
.
b
=
b
def
__getattr__
(
self
,
key
:
str
):
return
getattr
(
self
.
__processor
,
key
)
def
__call__
(
def
__call__
(
self
,
self
,
text
=
None
,
a
:
int
=
0
,
images
=
None
,
c
:
int
=
0
,
videos
=
None
,
return_tensors
:
Optional
[
str
]
=
None
,
exists
=
None
,
)
->
dict
[
str
,
int
]:
return_tensors
=
None
,
return
dict
(
a
=
a
,
c
=
c
)
):
return
dict
(
exists
=
exists
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"Qwen/Qwen2-VL-2B-Instruct"
])
# Dummy
# yapf: disable
# yapf: disable
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"Qwen/Qwen2-VL-2B-Instruct"
])
# Dummy
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"c
all
_kwargs"
,
"expected_kwargs"
),
(
"c
onfig_kwargs"
,
"inference
_kwargs"
,
"expected_kwargs"
),
[
[
# Should ignore invalid kwargs
({
"a"
:
1
},
{},
{
"a"
:
1
,
"b"
:
0
}),
({
"does_not_exist"
:
100
},
{
"exists"
:
None
}),
({},
{
"a"
:
1
},
{
"a"
:
1
,
"b"
:
0
}),
({
"exists"
:
1
},
{
"exists"
:
1
}),
# inference_kwargs should take precedence
({
"does_not_exist"
:
100
,
"exists"
:
1
},
{
"exists"
:
1
}),
({
"a"
:
1
},
{
"a"
:
2
},
{
"a"
:
2
,
"b"
:
0
}),
# Should ignore extra kwargs
({
"a"
:
1
,
"c"
:
1
},
{},
{
"a"
:
1
,
"b"
:
0
}),
({
"b"
:
1
,
"c"
:
1
},
{},
{
"a"
:
0
,
"b"
:
1
}),
],
],
)
)
# yapf: enable
# yapf: enable
def
test_hf_processor_kwargs
(
model_id
,
call_kwargs
,
expected_kwargs
):
def
test_hf_processor_init_kwargs
(
model_config
=
ModelConfig
(
model_id
)
model_id
,
config_kwargs
,
inference_kwargs
,
expected_kwargs
,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
model_config
)
ctx
=
InputProcessingContext
(
orig_get_hf_processor
=
processor
.
info
.
get_hf_processor
model_config
=
ModelConfig
(
model_id
,
mm_processor_kwargs
=
config_kwargs
),
tokenizer
=
mock_tokenizer
,
)
processor
=
ctx
.
get_hf_processor
(
DummyProcessor
,
# type: ignore[arg-type]
**
inference_kwargs
,
)
for
k
,
v
in
expected_kwargs
.
items
():
assert
getattr
(
processor
,
k
)
==
v
def
get_hf_processor
(
self
,
**
kwargs
):
assert
kwargs
==
call_kwargs
return
_ProcessorProxy
(
orig_get_hf_processor
())
processor
.
info
.
get_hf_processor
=
MethodType
(
get_hf_processor
,
# yapf: disable
processor
.
info
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"Qwen/Qwen2-VL-2B-Instruct"
])
# Dummy
@
pytest
.
mark
.
parametrize
(
(
"config_kwargs"
,
"inference_kwargs"
,
"expected_kwargs"
),
[
({
"a"
:
1
},
{},
{
"a"
:
1
,
"c"
:
0
}),
({},
{
"a"
:
1
},
{
"a"
:
1
,
"c"
:
0
}),
# inference_kwargs should take precedence
({
"a"
:
1
},
{
"a"
:
2
},
{
"a"
:
2
,
"c"
:
0
}),
# Should ignore extra kwargs
({
"a"
:
1
,
"c"
:
1
},
{},
{
"a"
:
1
,
"c"
:
1
}),
({
"b"
:
1
,
"c"
:
1
},
{},
{
"a"
:
0
,
"c"
:
1
}),
],
)
# yapf: enable
def
test_hf_processor_call_kwargs
(
model_id
,
config_kwargs
,
inference_kwargs
,
expected_kwargs
,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
out_kwargs
=
processor
.
_call_hf_processor
(
ctx
=
InputProcessingContext
(
prompt
=
""
,
model_config
=
ModelConfig
(
model_id
,
mm_processor_kwargs
=
config_kwargs
),
mm_data
=
{},
tokenizer
=
mock_tokenizer
,
mm_kwargs
=
call_kwargs
,
tok_kwargs
=
{},
)
)
assert
out_kwargs
==
expected_kwargs
processor
=
ctx
.
get_hf_processor
(
DummyProcessor
)
# type: ignore[arg-type]
result
=
ctx
.
call_hf_processor
(
processor
,
{},
inference_kwargs
)
assert
result
==
expected_kwargs
vllm/config.py
View file @
82de9b9d
...
@@ -11,6 +11,7 @@ import textwrap
...
@@ -11,6 +11,7 @@ import textwrap
import
uuid
import
uuid
import
warnings
import
warnings
from
collections
import
Counter
from
collections
import
Counter
from
collections.abc
import
Mapping
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
(
MISSING
,
Field
,
asdict
,
field
,
fields
,
is_dataclass
,
from
dataclasses
import
(
MISSING
,
Field
,
asdict
,
field
,
fields
,
is_dataclass
,
replace
)
replace
)
...
@@ -3332,7 +3333,16 @@ class MultiModalConfig:
...
@@ -3332,7 +3333,16 @@ class MultiModalConfig:
999
if
envs
.
VLLM_USE_V1
else
1
,
999
if
envs
.
VLLM_USE_V1
else
1
,
)
)
# TODO: Add configs to init vision tower or not.
def
merge_mm_processor_kwargs
(
self
,
inference_kwargs
:
Mapping
[
str
,
object
],
)
->
dict
[
str
,
object
]:
"""
Get the keyword arguments to pass to the multi-modal processor
according to the extra arguments passed during inference.
"""
kwargs
=
self
.
mm_processor_kwargs
or
{}
return
kwargs
|
dict
(
inference_kwargs
)
@
config
@
config
...
...
vllm/inputs/registry.py
View file @
82de9b9d
...
@@ -11,7 +11,7 @@ from typing_extensions import TypeVar
...
@@ -11,7 +11,7 @@ from typing_extensions import TypeVar
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.utils
import
resolve_mm_processor_kwarg
s
from
vllm.utils
import
get_allowed_kwarg_only_override
s
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -154,14 +154,11 @@ class InputProcessingContext(InputContext):
...
@@ -154,14 +154,11 @@ class InputProcessingContext(InputContext):
assert
callable
(
hf_processor
)
assert
callable
(
hf_processor
)
mm_config
=
self
.
model_config
.
get_multimodal_config
()
mm_config
=
self
.
model_config
.
get_multimodal_config
()
base_kwargs
=
mm_config
.
mm_processor_kwargs
merged_kwargs
=
mm_config
.
merge_mm_processor_kwargs
(
kwargs
)
if
base_kwargs
is
None
:
base_kwargs
=
{}
merged_kwargs
=
resolve_mm_processor_kwargs
(
allowed_kwargs
=
get_allowed_kwarg_only_overrides
(
base_kwargs
,
kwargs
,
hf_processor
,
hf_processor
,
merged_kwargs
,
requires_kw_only
=
False
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
allow_var_kwargs
=
True
,
)
)
...
@@ -173,7 +170,9 @@ class InputProcessingContext(InputContext):
...
@@ -173,7 +170,9 @@ class InputProcessingContext(InputContext):
return
x
return
x
try
:
try
:
output
=
hf_processor
(
**
data
,
**
merged_kwargs
,
return_tensors
=
"pt"
)
output
=
hf_processor
(
**
data
,
**
allowed_kwargs
,
return_tensors
=
"pt"
)
# this emulates output.to(dtype=self.model_config.dtype)
# this emulates output.to(dtype=self.model_config.dtype)
if
isinstance
(
output
,
BatchFeature
):
if
isinstance
(
output
,
BatchFeature
):
cast_output
=
json_map_leaves
(
maybe_cast_dtype
,
output
.
data
)
cast_output
=
json_map_leaves
(
maybe_cast_dtype
,
output
.
data
)
...
@@ -189,7 +188,7 @@ class InputProcessingContext(InputContext):
...
@@ -189,7 +188,7 @@ class InputProcessingContext(InputContext):
except
Exception
as
exc
:
except
Exception
as
exc
:
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
f
"on data=
{
data
}
with kwargs=
{
merg
ed_kwargs
}
"
)
f
"on data=
{
data
}
with kwargs=
{
allow
ed_kwargs
}
"
)
raise
ValueError
(
msg
)
from
exc
raise
ValueError
(
msg
)
from
exc
...
...
vllm/model_executor/models/aya_vision.py
View file @
82de9b9d
...
@@ -123,16 +123,10 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
...
@@ -123,16 +123,10 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
return
self
.
ctx
.
get_hf_config
(
AyaVisionConfig
)
return
self
.
ctx
.
get_hf_config
(
AyaVisionConfig
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
AyaVisionProcessor
:
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
AyaVisionProcessor
:
processor
=
self
.
ctx
.
get_hf_processor
(
AyaVisionProcessor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
AyaVisionProcessor
,
**
kwargs
)
# Temporary workaround since this processor has multiple image tokens
def
get_image_processor
(
self
,
**
kwargs
:
object
)
->
GotOcr2ImageProcessor
:
# See https://github.com/huggingface/transformers/issues/38350
return
self
.
get_hf_processor
(
**
kwargs
).
image_processor
processor
.
_check_special_mm_tokens
=
lambda
*
args
,
**
kwargs
:
None
return
processor
def
get_image_processor
(
self
)
->
GotOcr2ImageProcessor
:
return
self
.
get_hf_processor
().
image_processor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
}
return
{
"image"
:
None
}
...
...
vllm/model_executor/models/deepseek_vl2.py
View file @
82de9b9d
...
@@ -214,25 +214,25 @@ class DeepseekVL2MultiModalProcessor(
...
@@ -214,25 +214,25 @@ class DeepseekVL2MultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
if
mm_data
:
if
not
mm_data
:
processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
prompt
=
prompt
,
**
mm_data
),
dict
(
**
mm_kwargs
,
**
tok_kwargs
),
)
pixel_values
=
processed_outputs
[
"pixel_values"
]
# split pixel values into patches corresponding to each image
images_spatial_crop
=
processed_outputs
[
"images_spatial_crop"
]
patches_per_image
=
[
x
.
prod
().
item
()
+
1
for
x
in
images_spatial_crop
]
pixel_values
=
pixel_values
.
split
(
patches_per_image
)
processed_outputs
[
"pixel_values"
]
=
pixel_values
else
:
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
processed_outputs
=
tokenizer
(
prompt
,
return
tokenizer
(
prompt
,
add_special_tokens
=
True
,
add_special_tokens
=
True
,
return_tensors
=
"pt"
)
return_tensors
=
"pt"
)
processed_outputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
pixel_values
=
processed_outputs
[
"pixel_values"
]
# split pixel values into patches corresponding to each image
images_spatial_crop
=
processed_outputs
[
"images_spatial_crop"
]
patches_per_image
=
[
x
.
prod
().
item
()
+
1
for
x
in
images_spatial_crop
]
pixel_values
=
pixel_values
.
split
(
patches_per_image
)
processed_outputs
[
"pixel_values"
]
=
pixel_values
return
processed_outputs
return
processed_outputs
...
...
vllm/model_executor/models/florence2.py
View file @
82de9b9d
...
@@ -761,12 +761,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
...
@@ -761,12 +761,6 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
class
Florence2ProcessingInfo
(
BaseProcessingInfo
):
class
Florence2ProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
()
def
get_hf_processor
(
self
):
return
self
.
ctx
.
get_hf_processor
()
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
...
...
vllm/model_executor/models/fuyu.py
View file @
82de9b9d
...
@@ -83,8 +83,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
...
@@ -83,8 +83,8 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
return
self
.
ctx
.
get_hf_processor
(
FuyuProcessor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
FuyuProcessor
,
**
kwargs
)
def
get_image_processor
(
self
)
->
FuyuImageProcessor
:
def
get_image_processor
(
self
,
**
kwargs
:
object
)
->
FuyuImageProcessor
:
return
self
.
get_hf_processor
().
image_processor
return
self
.
get_hf_processor
(
**
kwargs
).
image_processor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
return
{
"image"
:
1
}
...
...
vllm/model_executor/models/glm4_1v.py
View file @
82de9b9d
...
@@ -809,11 +809,11 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
...
@@ -809,11 +809,11 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
,
"video"
:
1
}
return
{
"image"
:
None
,
"video"
:
1
}
def
get_image_processor
(
self
)
->
Glm4vImageProcessor
:
def
get_image_processor
(
self
,
**
kwargs
:
object
)
->
Glm4vImageProcessor
:
return
self
.
get_hf_processor
().
image_processor
return
self
.
get_hf_processor
(
**
kwargs
).
image_processor
def
get_video_processor
(
self
)
->
Glm4vVideoProcessor
:
def
get_video_processor
(
self
,
**
kwargs
:
object
)
->
Glm4vVideoProcessor
:
return
self
.
get_hf_processor
().
video_processor
return
self
.
get_hf_processor
(
**
kwargs
).
video_processor
def
_get_vision_info
(
def
_get_vision_info
(
self
,
self
,
...
...
vllm/model_executor/models/h2ovl.py
View file @
82de9b9d
...
@@ -392,21 +392,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
...
@@ -392,21 +392,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
class
H2OVLProcessingInfo
(
BaseInternVLProcessingInfo
):
class
H2OVLProcessingInfo
(
BaseInternVLProcessingInfo
):
def
get_hf_processor
(
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
H2OVLProcessor
:
self
,
*
,
min_dynamic_patch
:
Optional
[
int
]
=
None
,
max_dynamic_patch
:
Optional
[
int
]
=
None
,
dynamic_image_size
:
Optional
[
bool
]
=
None
,
**
kwargs
:
object
,
)
->
H2OVLProcessor
:
if
min_dynamic_patch
is
not
None
:
kwargs
[
"min_dynamic_patch"
]
=
min_dynamic_patch
if
max_dynamic_patch
is
not
None
:
kwargs
[
"max_dynamic_patch"
]
=
max_dynamic_patch
if
dynamic_image_size
is
not
None
:
kwargs
[
"dynamic_image_size"
]
=
dynamic_image_size
return
self
.
ctx
.
init_processor
(
return
self
.
ctx
.
init_processor
(
H2OVLProcessor
,
H2OVLProcessor
,
config
=
self
.
get_hf_config
(),
config
=
self
.
get_hf_config
(),
...
...
vllm/model_executor/models/hyperclovax_vision.py
View file @
82de9b9d
...
@@ -25,8 +25,7 @@ import torch
...
@@ -25,8 +25,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.models.regnet
import
RegStage
from
timm.models.regnet
import
RegStage
from
transformers
import
(
AutoProcessor
,
BatchFeature
,
CLIPVisionConfig
,
from
transformers
import
BatchFeature
,
CLIPVisionConfig
,
SiglipVisionConfig
SiglipVisionConfig
)
from
transformers.modeling_utils
import
no_init_weights
from
transformers.modeling_utils
import
no_init_weights
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -80,26 +79,9 @@ HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
...
@@ -80,26 +79,9 @@ HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
class
HCXVisionProcessingInfo
(
BaseProcessingInfo
):
class
HCXVisionProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
()
def
get_vision_encoder_info
(
self
):
def
get_vision_encoder_info
(
self
):
return
get_vision_encoder_info
(
self
.
get_hf_config
())
return
get_vision_encoder_info
(
self
.
get_hf_config
())
def
get_hf_processor
(
self
,
**
kwargs
:
object
,
):
processor_cls
=
type
(
AutoProcessor
.
from_pretrained
(
self
.
ctx
.
model_config
.
model
,
trust_remote_code
=
self
.
ctx
.
model_config
.
trust_remote_code
,
))
return
self
.
ctx
.
get_hf_processor
(
processor_cls
,
**
kwargs
,
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
,
"video"
:
None
}
return
{
"image"
:
None
,
"video"
:
None
}
...
...
vllm/model_executor/models/idefics3.py
View file @
82de9b9d
...
@@ -88,15 +88,7 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...
@@ -88,15 +88,7 @@ ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
class
Idefics3ProcessingInfo
(
BaseProcessingInfo
):
class
Idefics3ProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_processor
(
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
Idefics3Processor
:
self
,
*
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
)
->
Idefics3Processor
:
if
size
is
not
None
:
kwargs
[
"size"
]
=
size
return
self
.
ctx
.
get_hf_processor
(
Idefics3Processor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
Idefics3Processor
,
**
kwargs
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
...
...
vllm/model_executor/models/internvl.py
View file @
82de9b9d
...
@@ -665,14 +665,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
...
@@ -665,14 +665,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
"""Basic image-only ProcessingInfo for InternVL-style models."""
"""Basic image-only ProcessingInfo for InternVL-style models."""
@
abstractmethod
@
abstractmethod
def
get_hf_processor
(
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
BaseInternVLProcessor
:
self
,
*
,
min_dynamic_patch
:
Optional
[
int
]
=
None
,
max_dynamic_patch
:
Optional
[
int
]
=
None
,
dynamic_image_size
:
Optional
[
bool
]
=
None
,
**
kwargs
:
object
,
)
->
BaseInternVLProcessor
:
raise
NotImplementedError
raise
NotImplementedError
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
...
@@ -882,27 +875,12 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
...
@@ -882,27 +875,12 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
return
max
(
max_frames_per_video
,
1
)
return
max
(
max_frames_per_video
,
1
)
def
get_hf_processor
(
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
InternVLProcessor
:
self
,
*
,
min_dynamic_patch
:
Optional
[
int
]
=
None
,
max_dynamic_patch
:
Optional
[
int
]
=
None
,
dynamic_image_size
:
Optional
[
bool
]
=
None
,
**
kwargs
:
object
,
)
->
InternVLProcessor
:
if
min_dynamic_patch
is
not
None
:
kwargs
[
"min_dynamic_patch"
]
=
min_dynamic_patch
if
max_dynamic_patch
is
not
None
:
kwargs
[
"max_dynamic_patch"
]
=
max_dynamic_patch
if
dynamic_image_size
is
not
None
:
kwargs
[
"dynamic_image_size"
]
=
dynamic_image_size
kwargs
[
"video_token"
]
=
self
.
get_video_token
()
return
self
.
ctx
.
init_processor
(
return
self
.
ctx
.
init_processor
(
InternVLProcessor
,
InternVLProcessor
,
config
=
self
.
get_hf_config
(),
config
=
self
.
get_hf_config
(),
tokenizer
=
self
.
get_tokenizer
(),
tokenizer
=
self
.
get_tokenizer
(),
video_token
=
self
.
get_video_token
(),
**
kwargs
,
**
kwargs
,
)
)
...
...
vllm/model_executor/models/keye.py
View file @
82de9b9d
...
@@ -44,8 +44,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
...
@@ -44,8 +44,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.processor
import
(
cached_image_processor_from_config
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -980,72 +978,8 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
...
@@ -980,72 +978,8 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
class
KeyeProcessingInfo
(
BaseProcessingInfo
):
class
KeyeProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_processor
(
def
get_image_processor
(
self
,
**
kwargs
:
object
):
self
,
return
self
.
get_hf_processor
(
**
kwargs
).
image_processor
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
):
return
self
.
ctx
.
get_hf_processor
(
image_processor
=
self
.
get_image_processor
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
size
=
size
,
),
**
kwargs
,
)
def
_get_image_processor_kwargs
(
self
,
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
):
if
self
.
ctx
.
model_config
.
mm_processor_kwargs
:
kwargs
.
update
(
self
.
ctx
.
model_config
.
mm_processor_kwargs
)
if
min_pixels
is
not
None
:
kwargs
[
"min_pixels"
]
=
min_pixels
if
size
is
None
:
size
=
{
"shortest_edge"
:
min_pixels
}
else
:
size
[
"shortest_edge"
]
=
min_pixels
if
max_pixels
is
not
None
:
kwargs
[
"max_pixels"
]
=
max_pixels
if
size
is
None
:
size
=
{
"longest_edge"
:
max_pixels
}
else
:
size
[
"longest_edge"
]
=
max_pixels
if
size
is
not
None
:
kwargs
[
"size"
]
=
size
return
kwargs
def
get_image_processor
(
self
,
*
,
min_pixels
:
Optional
[
int
]
=
None
,
max_pixels
:
Optional
[
int
]
=
None
,
size
:
Optional
[
dict
[
str
,
int
]]
=
None
,
**
kwargs
:
object
,
):
return
cached_image_processor_from_config
(
self
.
ctx
.
model_config
,
**
self
.
_get_image_processor_kwargs
(
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
size
=
size
,
**
kwargs
,
),
)
def
get_supported_mm_limits
(
self
,
)
->
Mapping
[
str
,
Optional
[
int
]]:
def
get_supported_mm_limits
(
self
,
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
None
,
"video"
:
None
}
return
{
"image"
:
None
,
"video"
:
None
}
...
@@ -1246,20 +1180,6 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
...
@@ -1246,20 +1180,6 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
return
KeyeMultiModalDataParser
()
return
KeyeMultiModalDataParser
()
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
mm_kwargs
=
self
.
info
.
_get_image_processor_kwargs
(
**
mm_kwargs
)
return
self
.
info
.
ctx
.
call_hf_processor
(
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
text
=
prompt
,
**
mm_data
),
dict
(
**
mm_kwargs
,
**
tok_kwargs
),
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
...
...
vllm/model_executor/models/llava.py
View file @
82de9b9d
...
@@ -8,11 +8,9 @@ from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
...
@@ -8,11 +8,9 @@ from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
packaging.version
import
Version
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
LlavaConfig
,
from
transformers
import
(
BatchFeature
,
CLIPVisionConfig
,
LlavaConfig
,
PixtralVisionConfig
,
PretrainedConfig
,
PixtralVisionConfig
,
PretrainedConfig
,
SiglipVisionConfig
)
SiglipVisionConfig
)
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.llava
import
LlavaProcessor
from
transformers.models.pixtral
import
PixtralProcessor
from
transformers.models.pixtral
import
PixtralProcessor
...
@@ -307,29 +305,14 @@ class PixtralHFMultiModalProcessor(
...
@@ -307,29 +305,14 @@ class PixtralHFMultiModalProcessor(
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
pixel_values
=
processed_outputs
.
get
(
"pixel_values"
)
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
# Before/after https://github.com/huggingface/transformers/pull/35122
# Avoid padding since we need the output for each image to be
if
Version
(
TRANSFORMERS_VERSION
)
<=
Version
(
"4.48.3"
):
# independent of other images for the cache to work correctly
images
=
mm_data
[
"images"
]
image_sizes
=
processed_outputs
[
"image_sizes"
]
assert
isinstance
(
images
,
list
)
assert
len
(
pixel_values
)
==
len
(
image_sizes
)
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert
(
isinstance
(
pixel_values
,
list
)
and
len
(
pixel_values
)
==
1
)
assert
(
isinstance
(
pixel_values
[
0
],
list
)
and
len
(
pixel_values
[
0
])
==
len
(
images
))
processed_outputs
[
"pixel_values"
]
=
pixel_values
[
0
]
else
:
# Avoid padding since we need the output for each image to be
# independent of other images for the cache to work correctly
image_sizes
=
processed_outputs
[
"image_sizes"
]
assert
len
(
pixel_values
)
==
len
(
image_sizes
)
processed_outputs
[
"pixel_values"
]
=
[
processed_outputs
[
"pixel_values"
]
=
[
p
[:,
:
h
,
:
w
]
p
[:,
:
h
,
:
w
]
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
for
p
,
(
h
,
w
)
in
zip
(
pixel_values
,
image_sizes
)
]
]
return
processed_outputs
return
processed_outputs
...
@@ -784,17 +767,10 @@ class MantisProcessingInfo(LlavaProcessingInfo):
...
@@ -784,17 +767,10 @@ class MantisProcessingInfo(LlavaProcessingInfo):
vision_info
=
self
.
get_vision_encoder_info
()
vision_info
=
self
.
get_vision_encoder_info
()
kwargs
.
setdefault
(
"patch_size"
,
vision_info
.
get_patch_size
())
kwargs
.
setdefault
(
"patch_size"
,
vision_info
.
get_patch_size
())
kwargs
.
setdefault
(
if
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.48"
):
"vision_feature_select_strategy"
,
# BUG: num_additional_image_tokens = 0 but treated as 1,
hf_config
.
vision_feature_select_strategy
,
# so we set vision_feature_select_strategy to None to offset this
)
kwargs
.
setdefault
(
"vision_feature_select_strategy"
,
None
)
else
:
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
kwargs
.
setdefault
(
"vision_feature_select_strategy"
,
hf_config
.
vision_feature_select_strategy
,
)
return
self
.
ctx
.
get_hf_processor
(
LlavaProcessor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
LlavaProcessor
,
**
kwargs
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment