Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8863c2b2
Unverified
Commit
8863c2b2
authored
Jan 13, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 12, 2026
Browse files
[Model] Standardize pooling heads (#32148)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3f72639d
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
182 additions
and
149 deletions
+182
-149
vllm/model_executor/layers/pooler/common.py
vllm/model_executor/layers/pooler/common.py
+6
-1
vllm/model_executor/layers/pooler/seqwise/heads.py
vllm/model_executor/layers/pooler/seqwise/heads.py
+41
-47
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+25
-4
vllm/model_executor/layers/pooler/tokwise/heads.py
vllm/model_executor/layers/pooler/tokwise/heads.py
+23
-32
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+25
-4
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+25
-24
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+1
-4
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+11
-14
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+25
-19
No files found.
vllm/model_executor/layers/pooler/common.py
View file @
8863c2b2
...
...
@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
typing
import
TypeVar
import
torch
from
vllm.pooling_params
import
PoolingParams
_T
=
TypeVar
(
"_T"
,
bound
=
torch
.
Tensor
|
list
[
torch
.
Tensor
])
ProjectorFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ActivationFn
=
Callable
[[
_T
],
_T
]
@
dataclass
(
frozen
=
True
)
...
...
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
params
.
requires_token_ids
=
self
.
requires_token_ids
__all__
=
[
"
Classifie
rFn"
,
"PoolingParamsUpdate"
]
__all__
=
[
"
ActivationFn"
,
"ClassifierFn"
,
"Projecto
rFn"
,
"PoolingParamsUpdate"
]
vllm/model_executor/layers/pooler/seqwise/heads.py
View file @
8863c2b2
...
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.model_executor.layers.pooler
import
ActivationFn
,
ClassifierFn
,
ProjectorFn
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
class
EmbeddingPoolerHead
(
SequencePoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
projector
:
ProjectorFn
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
self
.
projector
=
projector
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"embed"
}
...
...
@@ -65,6 +58,7 @@ class EmbeddingPoolerHead(SequencePoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
...
...
@@ -88,6 +82,7 @@ class EmbeddingPoolerHead(SequencePoolerHead):
]
# for normalize
if
self
.
activation
is
not
None
:
flags
=
[
p
.
normalize
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
if
flags
[
0
]:
...
...
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
logit_bias
:
float
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
)
self
.
logit_bias
=
logit_bias
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
...
...
@@ -136,6 +127,7 @@ class ClassifierPoolerHead(SequencePoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
...
...
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
if
self
.
activation
is
not
None
:
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
scores
=
self
.
act
_f
n
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
pooled_data
=
self
.
act
ivatio
n
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
else
:
scores
=
[
self
.
act_fn
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
pooled_data
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
#
scores
shape: [batchsize, num_labels]
return
scores
#
pooled_data
shape: [batchsize, num_labels]
return
pooled_data
vllm/model_executor/layers/pooler/seqwise/poolers.py
View file @
8863c2b2
...
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.config
import
PoolerConfig
,
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
def
pooler_for_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
head
=
EmbeddingPoolerHead
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
EmbeddingPoolerHead
(
projector
=
_load_st_projector
(
model_config
),
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
...
...
@@ -101,6 +113,15 @@ def pooler_for_classify(
if
pooling
is
None
:
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
logit_bias
=
model_config
.
pooler_config
.
logit_bias
,
head_dtype
=
model_config
.
head_dtype
,
activation
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
),
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/layers/pooler/tokwise/heads.py
View file @
8863c2b2
...
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.model_executor.layers.pooler
import
ActivationFn
,
ClassifierFn
,
ProjectorFn
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
class
TokenEmbeddingPoolerHead
(
TokenPoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
projector
:
ProjectorFn
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
self
.
projector
=
projector
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
}
...
...
@@ -73,6 +66,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
if
pooled_data
is
None
:
return
None
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# pooled_data shape: [n_tokens, hidden_dimension]
...
...
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
# for normalize
if
pooling_param
.
normalize
:
if
self
.
activation
is
not
None
and
pooling_param
.
normalize
:
pooled_data
=
self
.
activation
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
...
...
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
logit_bias
:
float
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
self
.
logit_bias
=
logit_bias
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_classify"
}
...
...
@@ -123,6 +113,7 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if
pooled_data
is
None
:
return
None
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
...
...
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
scores
=
self
.
act
_f
n
(
scores
)
if
self
.
activation
is
not
None
and
pooling_param
.
use_activation
:
scores
=
self
.
act
ivatio
n
(
scores
)
# scores shape: [n_token, num_labels]
return
scores
vllm/model_executor/layers/pooler/tokwise/poolers.py
View file @
8863c2b2
...
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.config
import
PoolerConfig
,
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
def
pooler_for_token_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
head
=
TokenEmbeddingPoolerHead
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
TokenEmbeddingPoolerHead
(
projector
=
_load_st_projector
(
model_config
),
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
...
...
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
if
pooling
is
None
:
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
logit_bias
=
model_config
.
pooler_config
.
logit_bias
,
head_dtype
=
model_config
.
head_dtype
,
activation
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
),
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/models/bert.py
View file @
8863c2b2
...
...
@@ -8,7 +8,7 @@ from torch import nn
from
transformers
import
BertConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
...
...
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
Pooler
,
PoolingParamsUpdate
,
)
from
vllm.model_executor.layers.pooler.activations
import
LambdaPoolerActivation
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolerOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
(
...
...
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
class
BertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
BertConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
config
:
BertConfig
=
model_config
.
hf_config
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
),
head
=
self
.
head
,
# We set this dummy to avoid adding parameters to nn.Module too early
head
=
nn
.
Identity
(),
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
def
head
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
dtype
=
head_dtype
,
)
self
.
act_fn
=
nn
.
Tanh
()
pooled_data
=
self
.
dense
(
pooled_data
)
pooled_data
=
self
.
activation
(
pooled_data
)
return
pooled_data
# Use lambdas so that weights are not registered under `self.head`
self
.
head
=
EmbeddingPoolerHead
(
projector
=
lambda
x
:
self
.
dense
(
x
),
head_dtype
=
head_dtype
,
activation
=
LambdaPoolerActivation
(
self
.
act_fn
),
)
class
BertEncoder
(
nn
.
Module
):
...
...
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
embedding_class
=
embedding_class
,
)
config
=
vllm_config
.
model_config
.
hf_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
config
,
pooler_config
)
self
.
pooler
=
BertPooler
(
vllm_config
.
model_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
other_weights
,
loaded_stacked_params
=
self
.
_load_weights
(
weights
)
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
8863c2b2
...
...
@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
)
if
add_pooling_layer
:
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
self
.
config
,
pooler_config
)
self
.
pooler
=
BertPooler
(
vllm_config
.
model_config
)
else
:
self
.
pooler
=
None
...
...
vllm/model_executor/models/gritlm.py
View file @
8863c2b2
...
...
@@ -5,7 +5,7 @@ from collections.abc import Set
import
numpy
as
np
import
torch
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
...
...
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
)
from
vllm.model_executor.layers.pooler.activations
import
PoolerNormalize
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
...
...
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
class
GritLMPooler
(
SequencePooler
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
super
().
__init__
(
pooling
=
(
GritLMMeanPool
(
model_config
)
if
pooler_config
.
seq_pooling_type
==
"MEAN"
else
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
)
),
head
=
self
.
head
,
head
=
EmbeddingPoolerHead
(
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
),
)
self
.
activation
=
PoolerNormalize
()
def
head
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
@
default_pooling_type
(
seq_pooling_type
=
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
...
...
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
,
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
}
)
vllm/model_executor/models/modernbert.py
View file @
8863c2b2
...
...
@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
Pooler
Config
,
VllmConfig
from
vllm.config
import
Model
Config
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
EncoderOnlyAttention
,
)
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.activations
import
LambdaPoolerActivation
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
...
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces_base
import
attn_type
,
default_pooling_type
...
...
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
class
ModernBertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
config
:
ModernBertConfig
=
model_config
.
hf_config
hf_pooling_type
=
config
.
classifier_pooling
.
upper
()
# vllm_pooling_type = pooler_config.seq_pooling_type
# Currently we don't have a way to see if the user set the pooling type
...
...
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
hf_pooling_type
),
head
=
self
.
head
,
# We set this dummy to avoid adding parameters to nn.Module too early
head
=
nn
.
Identity
(),
)
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
,
dtype
=
head_dtype
,
)
self
.
act
=
nn
.
GELU
()
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
,
)
def
head
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
# Use lambdas so that weights are not registered under `self.head`
self
.
head
=
EmbeddingPoolerHead
(
projector
=
lambda
x
:
self
.
dense
(
x
),
head_dtype
=
head_dtype
,
activation
=
LambdaPoolerActivation
(
lambda
x
:
self
.
norm
(
self
.
act
(
x
))),
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
...
...
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooling
=
ModernBertPooler
(
config
,
pooler
_config
)
self
.
pooling
=
ModernBertPooler
(
vllm_
config
.
model
_config
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment