Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c781fbba
Unverified
Commit
c781fbba
authored
Mar 17, 2026
by
Cyrus Leung
Committed by
GitHub
Mar 17, 2026
Browse files
[Bugfix] Standardize custom HF Processor init (#37289)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
979ff44c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
39 additions
and
33 deletions
+39
-33
vllm/model_executor/models/deepseek_ocr.py
vllm/model_executor/models/deepseek_ocr.py
+3
-1
vllm/model_executor/models/deepseek_ocr2.py
vllm/model_executor/models/deepseek_ocr2.py
+3
-1
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+11
-3
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+11
-3
vllm/tokenizers/qwen_vl.py
vllm/tokenizers/qwen_vl.py
+4
-0
vllm/transformers_utils/processors/glm4v.py
vllm/transformers_utils/processors/glm4v.py
+2
-7
vllm/transformers_utils/processors/qwen_vl.py
vllm/transformers_utils/processors/qwen_vl.py
+5
-18
No files found.
vllm/model_executor/models/deepseek_ocr.py
View file @
c781fbba
...
...
@@ -196,8 +196,10 @@ class DeepseekOCRProcessingInfo(BaseProcessingInfo):
crop_mode
=
CROP_MODE
,
strategy
=
"v1"
,
)
return
self
.
ctx
.
get_hf_processor
(
DeepseekOCRProcessor
,
**
{
**
kwargs
,
**
v1_processor_config
}
DeepseekOCRProcessor
,
**
{
**
v1_processor_config
,
**
kwargs
},
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/model_executor/models/deepseek_ocr2.py
View file @
c781fbba
...
...
@@ -76,8 +76,10 @@ class DeepseekOCR2ProcessingInfo(BaseProcessingInfo):
crop_mode
=
CROP_MODE
,
strategy
=
"v2"
,
)
return
self
.
ctx
.
get_hf_processor
(
DeepseekOCRProcessor
,
**
{
**
kwargs
,
**
v2_processor_config
}
DeepseekOCRProcessor
,
**
{
**
v2_processor_config
,
**
kwargs
},
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/model_executor/models/glm4v.py
View file @
c781fbba
...
...
@@ -47,7 +47,10 @@ from vllm.multimodal.processing import (
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.processors.glm4v
import
GLM4VProcessor
from
vllm.transformers_utils.processors.glm4v
import
(
GLM4VImageProcessorFast
,
GLM4VProcessor
,
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.chatglm
import
ChatGLMBaseModel
,
ChatGLMModel
,
GLMTransformer
...
...
@@ -387,15 +390,20 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
ChatGLMConfig
)
def
get_
hf
_processor
(
self
,
**
kwargs
:
object
)
->
GLM4VProcessor
:
def
get_
image
_processor
(
self
,
**
kwargs
)
:
config
=
self
.
get_hf_config
()
vision_config
=
config
.
vision_config
image_size
=
vision_config
[
"image_size"
]
kwargs
.
setdefault
(
"size"
,
{
"width"
:
image_size
,
"height"
:
image_size
})
return
GLM4VImageProcessorFast
(
**
kwargs
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
GLM4VProcessor
:
return
self
.
ctx
.
init_processor
(
GLM4VProcessor
,
tokenizer
=
self
.
get_tokenizer
(),
**
{
**
kwargs
,
"image_size"
:
image_size
}
,
image_processor
=
self
.
get_image_processor
(
**
kwargs
)
,
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/model_executor/models/qwen_vl.py
View file @
c781fbba
...
...
@@ -44,7 +44,10 @@ from vllm.multimodal.processing import (
PromptUpdateDetails
,
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.processors.qwen_vl
import
QwenVLProcessor
from
vllm.transformers_utils.processors.qwen_vl
import
(
QwenVLImageProcessorFast
,
QwenVLProcessor
,
)
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
...
...
@@ -432,15 +435,20 @@ class QwenVLModel(QWenModel):
class
QwenVLProcessingInfo
(
BaseProcessingInfo
):
def
get_
hf
_processor
(
self
,
**
kwargs
:
object
)
->
QwenVLProcessor
:
def
get_
image
_processor
(
self
,
**
kwargs
)
:
config
=
self
.
get_hf_config
()
vision_config
=
config
.
visual
image_size
=
vision_config
[
"image_size"
]
kwargs
.
setdefault
(
"size"
,
{
"width"
:
image_size
,
"height"
:
image_size
})
return
QwenVLImageProcessorFast
(
**
kwargs
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
QwenVLProcessor
:
return
self
.
ctx
.
init_processor
(
QwenVLProcessor
,
tokenizer
=
self
.
get_tokenizer
(),
**
{
**
kwargs
,
"image_size"
:
image_size
}
,
image_processor
=
self
.
get_image_processor
(
**
kwargs
)
,
)
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
...
...
vllm/tokenizers/qwen_vl.py
View file @
c781fbba
...
...
@@ -61,6 +61,10 @@ def get_qwen_vl_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
class
QwenVLTokenizer
(
TokenizerLike
):
image_start_tag
:
str
image_end_tag
:
str
image_pad_tag
:
str
@
classmethod
def
from_pretrained
(
cls
,
*
args
,
**
kwargs
)
->
HfTokenizer
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
*
args
,
**
kwargs
)
...
...
vllm/transformers_utils/processors/glm4v.py
View file @
c781fbba
...
...
@@ -29,13 +29,8 @@ class GLM4VProcessor(ProcessorMixin):
def
__init__
(
self
,
image_processor
:
GLM4VImageProcessorFast
,
tokenizer
:
PreTrainedTokenizer
,
image_size
:
int
,
image_processor
:
GLM4VImageProcessorFast
|
None
=
None
,
)
->
None
:
self
.
tokenizer
=
tokenizer
if
image_processor
is
None
:
image_processor
=
GLM4VImageProcessorFast
(
size
=
{
"width"
:
image_size
,
"height"
:
image_size
}
)
self
.
image_processor
=
image_processor
self
.
tokenizer
=
tokenizer
vllm/transformers_utils/processors/qwen_vl.py
View file @
c781fbba
...
...
@@ -31,25 +31,12 @@ class QwenVLProcessor(ProcessorMixin):
def
__init__
(
self
,
image_processor
:
QwenVLImageProcessorFast
,
tokenizer
:
QwenVLTokenizer
,
image_size
:
int
,
image_processor
:
QwenVLImageProcessorFast
|
None
=
None
,
)
->
None
:
self
.
tokenizer
=
tokenizer
if
image_processor
is
None
:
image_processor
=
QwenVLImageProcessorFast
(
size
=
{
"width"
:
image_size
,
"height"
:
image_size
}
)
self
.
image_processor
=
image_processor
self
.
tokenizer
=
tokenizer
@
property
def
image_start_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_start_tag
# type: ignore[attr-defined]
@
property
def
image_end_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_end_tag
# type: ignore[attr-defined]
@
property
def
image_pad_tag
(
self
)
->
str
:
return
self
.
tokenizer
.
image_pad_tag
# type: ignore[attr-defined]
self
.
image_start_tag
=
tokenizer
.
image_start_tag
self
.
image_end_tag
=
tokenizer
.
image_end_tag
self
.
image_pad_tag
=
tokenizer
.
image_pad_tag
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment