Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8eb1182
Unverified
Commit
a8eb1182
authored
Jan 23, 2026
by
Andreas Karatzas
Committed by
GitHub
Jan 23, 2026
Browse files
[CI][Models] Add VLM Support for Sequence Classification Conversion (#32885)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
fa6e599a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
155 additions
and
39 deletions
+155
-39
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+39
-15
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+113
-23
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+3
-1
No files found.
vllm/model_executor/layers/layernorm.py
View file @
a8eb1182
...
@@ -278,15 +278,29 @@ class GemmaRMSNorm(CustomOp):
...
@@ -278,15 +278,29 @@ class GemmaRMSNorm(CustomOp):
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
@
staticmethod
@
staticmethod
def
forward_static
(
def
_
forward_static
_no_residual
(
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward() without residual."""
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
variance
=
x
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
x
=
x
*
torch
.
rsqrt
(
variance
+
variance_epsilon
)
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
return
x
@
staticmethod
def
_forward_static_with_residual
(
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
,
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward() with residual."""
orig_dtype
=
x
.
dtype
orig_dtype
=
x
.
dtype
if
residual
is
not
None
:
x
=
(
x
=
(
x
.
float
()
+
residual
.
float
()
x
.
float
()
+
residual
.
float
()
if
orig_dtype
==
torch
.
float16
if
orig_dtype
==
torch
.
float16
...
@@ -301,7 +315,7 @@ class GemmaRMSNorm(CustomOp):
...
@@ -301,7 +315,7 @@ class GemmaRMSNorm(CustomOp):
# See https://github.com/huggingface/transformers/pull/29402
# See https://github.com/huggingface/transformers/pull/29402
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
*
(
1.0
+
weight
.
float
())
x
=
x
.
to
(
orig_dtype
)
x
=
x
.
to
(
orig_dtype
)
return
x
if
residual
is
None
else
(
x
,
residual
)
return
x
,
residual
def
forward_native
(
def
forward_native
(
self
,
self
,
...
@@ -309,7 +323,14 @@ class GemmaRMSNorm(CustomOp):
...
@@ -309,7 +323,14 @@ class GemmaRMSNorm(CustomOp):
residual
:
torch
.
Tensor
|
None
=
None
,
residual
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
return
self
.
forward_static
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
)
if
residual
is
None
:
return
self
.
_forward_static_no_residual
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
)
else
:
return
self
.
_forward_static_with_residual
(
self
.
weight
.
data
,
self
.
variance_epsilon
,
x
,
residual
)
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
...
@@ -320,8 +341,11 @@ class GemmaRMSNorm(CustomOp):
...
@@ -320,8 +341,11 @@ class GemmaRMSNorm(CustomOp):
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
if
not
getattr
(
self
,
"_is_compiled"
,
False
):
self
.
forward_static
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_no_residual
=
torch
.
compile
(
# type: ignore
self
.
forward_static
self
.
_forward_static_no_residual
)
self
.
_forward_static_with_residual
=
torch
.
compile
(
# type: ignore
self
.
_forward_static_with_residual
)
)
self
.
_is_compiled
=
True
self
.
_is_compiled
=
True
return
self
.
forward_native
(
x
,
residual
)
return
self
.
forward_native
(
x
,
residual
)
...
...
vllm/model_executor/models/adapters.py
View file @
a8eb1182
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
,
cast
import
torch
import
torch
...
@@ -373,6 +374,76 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
...
@@ -373,6 +374,76 @@ class SequenceClassificationConfig(VerifyAndUpdateConfig):
text_config
.
use_sep_token
=
use_sep_token
text_config
.
use_sep_token
=
use_sep_token
def
_get_language_model_for_seq_cls
(
model
)
->
nn
.
Module
:
"""
Get the language model component for sequence classification conversion.
For VLMs, returns the inner language model. For standard LLMs, returns model itself.
"""
if
supports_multimodal
(
model
):
try
:
lm
=
model
.
get_language_model
()
if
lm
is
not
model
:
return
lm
except
Exception
:
pass
for
attr_name
in
(
"language_model"
,
"lm"
,
"text_model"
):
if
hasattr
(
model
,
attr_name
):
candidate
=
getattr
(
model
,
attr_name
)
if
(
isinstance
(
candidate
,
nn
.
Module
)
and
candidate
is
not
model
and
hasattr
(
candidate
,
"model"
)
):
return
candidate
for
name
,
child
in
model
.
named_children
():
child_name
=
type
(
child
).
__name__
if
(
"ForCausalLM"
in
child_name
or
"LMHead"
in
child_name
)
and
hasattr
(
child
,
"model"
):
return
child
return
model
@
contextmanager
def
_disable_seq_cls_loading_on_inner_model
(
language_model
,
is_vlm
:
bool
):
"""
Context manager to temporarily disable sequence classification loading
on inner VLM models to prevent recursive seq_cls_model_loader calls.
"""
if
not
is_vlm
:
yield
return
inner_hf_config
=
getattr
(
language_model
,
"config"
,
None
)
if
inner_hf_config
is
None
:
yield
return
inner_text_config
=
inner_hf_config
.
get_text_config
()
original_method
=
getattr
(
inner_text_config
,
"method"
,
None
)
original_tokens
=
getattr
(
inner_text_config
,
"classifier_from_token"
,
None
)
original_hf_tokens
=
getattr
(
inner_hf_config
,
"classifier_from_token"
,
None
)
try
:
if
original_method
is
not
None
:
inner_text_config
.
method
=
None
if
original_tokens
is
not
None
:
inner_text_config
.
classifier_from_token
=
None
if
original_hf_tokens
is
not
None
:
inner_hf_config
.
classifier_from_token
=
None
yield
finally
:
if
original_method
is
not
None
:
inner_text_config
.
method
=
original_method
if
original_tokens
is
not
None
:
inner_text_config
.
classifier_from_token
=
original_tokens
if
original_hf_tokens
is
not
None
:
inner_hf_config
.
classifier_from_token
=
original_hf_tokens
def
load_weights_using_from_2_way_softmax
(
def
load_weights_using_from_2_way_softmax
(
model
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
model
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
):
):
...
@@ -393,9 +464,9 @@ def load_weights_using_from_2_way_softmax(
...
@@ -393,9 +464,9 @@ def load_weights_using_from_2_way_softmax(
tokens
=
cast
(
list
[
int
],
tokens
)
tokens
=
cast
(
list
[
int
],
tokens
)
assert
len
(
tokens
)
==
2
assert
len
(
tokens
)
==
2
language_model
=
(
language_model
=
_get_language_model_for_seq_cls
(
model
)
model
.
get_language_model
()
if
hasattr
(
model
,
"get_language_model"
)
else
model
is_vlm
=
language_model
is
not
model
)
language_model
.
lm_head
=
ParallelLMHead
(
language_model
.
lm_head
=
ParallelLMHead
(
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
)
)
...
@@ -411,6 +482,7 @@ def load_weights_using_from_2_way_softmax(
...
@@ -411,6 +482,7 @@ def load_weights_using_from_2_way_softmax(
)
)
language_model
.
lm_head
=
language_model
.
lm_head
.
tie_weights
(
embed_tokens
)
language_model
.
lm_head
=
language_model
.
lm_head
.
tie_weights
(
embed_tokens
)
with
_disable_seq_cls_loading_on_inner_model
(
language_model
,
is_vlm
):
# ModelForPooling is dynamically defined inside the _create_pooling_model_cls
# ModelForPooling is dynamically defined inside the _create_pooling_model_cls
# function, so we need use this hacky method to obtain it.
# function, so we need use this hacky method to obtain it.
pooling_model_cls
=
next
(
pooling_model_cls
=
next
(
...
@@ -434,12 +506,15 @@ def load_weights_using_from_2_way_softmax(
...
@@ -434,12 +506,15 @@ def load_weights_using_from_2_way_softmax(
torch
.
float32
torch
.
float32
)
-
lm_head_weight
.
data
[[
false_id
]].
to
(
torch
.
float32
)
)
-
lm_head_weight
.
data
[[
false_id
]].
to
(
torch
.
float32
)
param
=
model
.
score
.
weight
score_layer
=
language_model
.
score
if
is_vlm
else
model
.
score
param
=
score_layer
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
score_weight
)
weight_loader
(
param
,
score_weight
)
del
language_model
.
lm_head
del
language_model
.
lm_head
loaded_weights
.
add
(
"score.weight"
)
score_weight_name
=
"language_model.score.weight"
if
is_vlm
else
"score.weight"
loaded_weights
.
add
(
score_weight_name
)
lm_head_name
=
"lm_head.weight"
lm_head_name
=
"lm_head.weight"
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
...
@@ -460,22 +535,30 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
...
@@ -460,22 +535,30 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
tokens
=
cast
(
list
[
int
],
tokens
)
tokens
=
cast
(
list
[
int
],
tokens
)
assert
len
(
tokens
)
>
0
assert
len
(
tokens
)
>
0
model
.
lm_head
=
ParallelLMHead
(
language_model
=
_get_language_model_for_seq_cls
(
model
)
is_vlm
=
language_model
is
not
model
language_model
.
lm_head
=
ParallelLMHead
(
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
text_config
.
vocab_size
,
text_config
.
hidden_size
,
quant_config
=
quant_config
)
)
if
text_config
.
tie_word_embeddings
:
if
text_config
.
tie_word_embeddings
:
# embed_tokens is the assumed name for input embeddings. If the model does not
# embed_tokens is the assumed name for input embeddings. If the model does not
# have this attribute, we fall back to get_input_embeddings(), which is used by
# have this attribute, we fall back to get_input_embeddings(), which is used by
# the Transformers modeling backend.
# the Transformers modeling backend.
text_backbone
=
language_model
.
model
embed_tokens
=
(
embed_tokens
=
(
model
.
model
.
embed_tokens
text_backbone
.
embed_tokens
if
hasattr
(
model
.
model
,
"embed_tokens"
)
if
hasattr
(
text_backbone
,
"embed_tokens"
)
else
model
.
model
.
get_input_embeddings
()
else
text_backbone
.
get_input_embeddings
()
)
)
model
.
lm_head
=
model
.
lm_head
.
tie_weights
(
embed_tokens
)
language_
model
.
lm_head
=
language_
model
.
lm_head
.
tie_weights
(
embed_tokens
)
with
_disable_seq_cls_loading_on_inner_model
(
language_model
,
is_vlm
):
pooling_model_cls
=
next
(
x
for
x
in
type
(
model
).
__mro__
if
x
.
__name__
==
"ModelForPooling"
)
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion
# Skip ModelForSequenceClassification in MRO to avoid infinite recursion
loaded_weights
=
type
(
model
).
__mro__
[
1
]
.
load_weights
(
model
,
weights
)
loaded_weights
=
pooling_model_cls
.
load_weights
(
model
,
weights
)
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
get_tokenizer
...
@@ -487,15 +570,22 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
...
@@ -487,15 +570,22 @@ def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Te
)
)
token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
t
)
for
t
in
tokens
]
token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
t
)
for
t
in
tokens
]
score_weight
=
model
.
lm_head
.
weight
.
data
[
token_ids
]
score_weight
=
language_
model
.
lm_head
.
weight
.
data
[
token_ids
]
param
=
model
.
score
.
weight
score_layer
=
language_model
.
score
if
is_vlm
else
model
.
score
param
=
score_layer
.
weight
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
score_weight
)
weight_loader
(
param
,
score_weight
)
del
model
.
lm_head
del
language_model
.
lm_head
loaded_weights
.
add
(
"score.weight"
)
loaded_weights
.
discard
(
"lm_head.weight"
)
score_weight_name
=
"language_model.score.weight"
if
is_vlm
else
"score.weight"
loaded_weights
.
add
(
score_weight_name
)
lm_head_name
=
"lm_head.weight"
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
lm_head_name
=
hf_to_vllm_mapper
.
_map_name
(
lm_head_name
)
loaded_weights
.
discard
(
lm_head_name
)
return
loaded_weights
return
loaded_weights
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
a8eb1182
...
@@ -107,7 +107,9 @@ class TritonAttentionMetadata:
...
@@ -107,7 +107,9 @@ class TritonAttentionMetadata:
for
r
in
range_lists
for
r
in
range_lists
]
]
return
torch
.
nested
.
nested_tensor
(
range_tensors
).
to_padded_tensor
(
0
)
return
torch
.
nested
.
nested_tensor
(
range_tensors
,
layout
=
torch
.
jagged
).
to_padded_tensor
(
0
)
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment