Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9101dc75
Unverified
Commit
9101dc75
authored
Jan 12, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 11, 2026
Browse files
[Model] Avoid hardcoding pooling type (#32119)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
025a32f9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
22 deletions
+47
-22
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+10
-4
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+9
-1
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+9
-4
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+12
-4
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+1
-3
vllm/model_executor/models/transformers/pooling.py
vllm/model_executor/models/transformers/pooling.py
+6
-6
No files found.
vllm/model_executor/models/bert.py
View file @
9101dc75
...
@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
...
@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
PoolingParamsUpdate
,
PoolingParamsUpdate
,
)
)
from
vllm.model_executor.layers.pooler.seqwise
import
(
from
vllm.model_executor.layers.pooler.seqwise
import
(
CLSPool
,
SequencePooler
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolerHeadOutput
,
SequencePoolerOutput
,
SequencePoolerOutput
,
SequencePoolingMethodOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
)
from
vllm.model_executor.layers.pooler.tokwise
import
(
from
vllm.model_executor.layers.pooler.tokwise
import
(
pooler_for_token_classify
,
pooler_for_token_classify
,
...
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
...
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
class
BertPooler
(
SequencePooler
):
class
BertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
BertConfig
):
def
__init__
(
self
,
config
:
BertConfig
,
pooler_config
:
PoolerConfig
):
super
().
__init__
(
super
().
__init__
(
pooling
=
CLSPool
(
),
pooling
=
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
),
head
=
self
.
head
,
head
=
self
.
head
,
)
)
...
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
...
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
)
)
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
pooler
=
BertPooler
(
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
config
,
pooler_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
other_weights
,
loaded_stacked_params
=
self
.
_load_weights
(
weights
)
other_weights
,
loaded_stacked_params
=
self
.
_load_weights
(
weights
)
...
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
...
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
layer_norm_eps
=
getattr
(
cfg
,
"layer_norm_eps"
,
1e-12
),
layer_norm_eps
=
getattr
(
cfg
,
"layer_norm_eps"
,
1e-12
),
)
)
# None of vLLM's built-in sequence pooling types are
# applicable so it is overwritten by SPLADESparsePooler
pooling_mode
=
getattr
(
self
,
"_splade_pooling"
,
"max"
)
pooling_mode
=
getattr
(
self
,
"_splade_pooling"
,
"max"
)
cls_id
=
getattr
(
cfg
,
"cls_token_id"
,
None
)
cls_id
=
getattr
(
cfg
,
"cls_token_id"
,
None
)
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
9101dc75
...
@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant):
...
@@ -453,6 +453,7 @@ class BertWithRope(nn.Module, SupportsQuant):
add_pooling_layer
:
bool
=
False
,
add_pooling_layer
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
add_pooling_layer
=
add_pooling_layer
self
.
add_pooling_layer
=
add_pooling_layer
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant):
...
@@ -463,7 +464,14 @@ class BertWithRope(nn.Module, SupportsQuant):
rotary_kwargs
=
self
.
config
.
rotary_kwargs
,
rotary_kwargs
=
self
.
config
.
rotary_kwargs
,
prefix
=
f
"
{
prefix
}
.encoder"
,
prefix
=
f
"
{
prefix
}
.encoder"
,
)
)
self
.
pooler
=
BertPooler
(
self
.
config
)
if
add_pooling_layer
else
None
if
add_pooling_layer
:
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
self
.
config
,
pooler_config
)
else
:
self
.
pooler
=
None
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embeddings
(
input_ids
)
return
self
.
embeddings
(
input_ids
)
...
...
vllm/model_executor/models/gritlm.py
View file @
9101dc75
...
@@ -5,7 +5,7 @@ from collections.abc import Set
...
@@ -5,7 +5,7 @@ from collections.abc import Set
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
DispatchPooler
,
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import (
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.pooler.seqwise import (
SequencePoolerHeadOutput
,
SequencePoolerHeadOutput
,
SequencePoolingMethod
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_embed
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_embed
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
...
@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod):
...
@@ -177,9 +178,13 @@ class GritLMMeanPool(SequencePoolingMethod):
class
GritLMPooler
(
SequencePooler
):
class
GritLMPooler
(
SequencePooler
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
pooler_config
:
PoolerConfig
):
super
().
__init__
(
super
().
__init__
(
pooling
=
GritLMMeanPool
(
model_config
),
pooling
=
(
GritLMMeanPool
(
model_config
)
if
pooler_config
.
seq_pooling_type
==
"MEAN"
else
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
)
),
head
=
self
.
head
,
head
=
self
.
head
,
)
)
...
@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM):
...
@@ -235,6 +240,6 @@ class GritLM(LlamaForCausalLM):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
,
pooler_config
),
}
}
)
)
vllm/model_executor/models/modernbert.py
View file @
9101dc75
...
@@ -8,7 +8,7 @@ from transformers import ModernBertConfig
...
@@ -8,7 +8,7 @@ from transformers import ModernBertConfig
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
EncoderOnlyAttention
,
EncoderOnlyAttention
,
...
@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module):
...
@@ -282,9 +282,14 @@ class ModernBertModel(nn.Module):
class
ModernBertPooler
(
SequencePooler
):
class
ModernBertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
def
__init__
(
self
,
config
:
ModernBertConfig
,
pooler_config
:
PoolerConfig
):
hf_pooling_type
=
config
.
classifier_pooling
.
upper
()
# vllm_pooling_type = pooler_config.seq_pooling_type
# Currently we don't have a way to see if the user set the pooling type
# explicitly or not, so we always use the HF pooling type for now.
super
().
__init__
(
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
config
.
classifier_pooling
.
upper
()
),
pooling
=
get_seq_pooling_method
(
hf_pooling_type
),
head
=
self
.
head
,
head
=
self
.
head
,
)
)
...
@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -314,7 +319,9 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
config
=
config
self
.
config
=
config
self
.
model
=
ModernBertModel
(
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
)
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
)
...
@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -324,11 +331,12 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
config
.
num_labels
,
config
.
num_labels
,
dtype
=
vllm_config
.
model_config
.
head_dtype
,
dtype
=
vllm_config
.
model_config
.
head_dtype
,
)
)
self
.
pooling
=
ModernBertPooler
(
config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooling
=
ModernBertPooler
(
config
,
pooler_config
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooler_config
,
pooling
=
self
.
pooling
,
pooling
=
self
.
pooling
,
...
...
vllm/model_executor/models/roberta.py
View file @
9101dc75
...
@@ -9,7 +9,6 @@ from transformers import RobertaConfig
...
@@ -9,7 +9,6 @@ from transformers import RobertaConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.models.bert
import
(
from
vllm.model_executor.models.bert
import
(
TOKEN_TYPE_SHIFT
,
TOKEN_TYPE_SHIFT
,
...
@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module):
...
@@ -86,7 +85,7 @@ class RobertaClassificationHead(nn.Module):
)
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
#
CLSPool
has already been applied in `pooling`
#
Token extraction
has already been applied in `
pooler.
pooling`
x
=
self
.
dense
(
x
)
x
=
self
.
dense
(
x
)
x
=
torch
.
tanh
(
x
)
x
=
torch
.
tanh
(
x
)
x
=
self
.
out_proj
(
x
)
x
=
self
.
out_proj
(
x
)
...
@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -194,7 +193,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooler_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
)
)
...
...
vllm/model_executor/models/transformers/pooling.py
View file @
9101dc75
...
@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
...
@@ -23,7 +23,6 @@ from transformers import AutoModelForSequenceClassification
from
vllm.config.utils
import
getattr_iter
from
vllm.config.utils
import
getattr_iter
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
...
@@ -32,7 +31,7 @@ if TYPE_CHECKING:
...
@@ -32,7 +31,7 @@ if TYPE_CHECKING:
class
EmbeddingMixin
(
VllmModelForPooling
):
class
EmbeddingMixin
(
VllmModelForPooling
):
default_pooling_type
=
"CLS"
default_
seq_
pooling_type
=
"CLS"
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
# Skip VllmModelForPooling.__init__ and call the next class in MRO
...
@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
...
@@ -47,7 +46,7 @@ class EmbeddingMixin(VllmModelForPooling):
class
SequenceClassificationMixin
(
SupportsCrossEncoding
,
VllmModelForPooling
):
class
SequenceClassificationMixin
(
SupportsCrossEncoding
,
VllmModelForPooling
):
default_pooling_type
=
"CLS"
default_
seq_
pooling_type
=
"CLS"
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
# Skip VllmModelForPooling.__init__ and call the next class in MRO
# Skip VllmModelForPooling.__init__ and call the next class in MRO
...
@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
...
@@ -85,8 +84,10 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self
.
init_parameters
(
self
.
classifier
,
dtype
=
self
.
model_config
.
head_dtype
)
self
.
init_parameters
(
self
.
classifier
,
dtype
=
self
.
model_config
.
head_dtype
)
class
ClassifierWithReshape
(
self
.
classifier
.
__class__
):
class
ClassifierWithReshape
(
self
.
classifier
.
__class__
):
"""CLSPool has already been applied in `pooling`.
"""
Add dim to match expected input shape of `classifier.forward`."""
Token extraction has already been applied in `pooler.pooling`.
Add dim to match expected input shape of `classifier.forward`.
"""
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
if
len
(
args
)
>
0
:
if
len
(
args
)
>
0
:
...
@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
...
@@ -97,6 +98,5 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooler_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment