Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8863c2b2
Unverified
Commit
8863c2b2
authored
Jan 13, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 12, 2026
Browse files
[Model] Standardize pooling heads (#32148)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
3f72639d
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
182 additions
and
149 deletions
+182
-149
vllm/model_executor/layers/pooler/common.py
vllm/model_executor/layers/pooler/common.py
+6
-1
vllm/model_executor/layers/pooler/seqwise/heads.py
vllm/model_executor/layers/pooler/seqwise/heads.py
+41
-47
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+25
-4
vllm/model_executor/layers/pooler/tokwise/heads.py
vllm/model_executor/layers/pooler/tokwise/heads.py
+23
-32
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+25
-4
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+25
-24
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+1
-4
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+11
-14
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+25
-19
No files found.
vllm/model_executor/layers/pooler/common.py
View file @
8863c2b2
...
@@ -2,12 +2,17 @@
...
@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TypeVar
import
torch
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
_T
=
TypeVar
(
"_T"
,
bound
=
torch
.
Tensor
|
list
[
torch
.
Tensor
])
ProjectorFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ActivationFn
=
Callable
[[
_T
],
_T
]
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
...
@@ -24,4 +29,4 @@ class PoolingParamsUpdate:
params
.
requires_token_ids
=
self
.
requires_token_ids
params
.
requires_token_ids
=
self
.
requires_token_ids
__all__
=
[
"
Classifie
rFn"
,
"PoolingParamsUpdate"
]
__all__
=
[
"
ActivationFn"
,
"ClassifierFn"
,
"Projecto
rFn"
,
"PoolingParamsUpdate"
]
vllm/model_executor/layers/pooler/seqwise/heads.py
View file @
8863c2b2
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ActivationFn
,
ClassifierFn
,
ProjectorFn
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
...
@@ -38,17 +31,17 @@ class SequencePoolerHead(nn.Module, ABC):
class
EmbeddingPoolerHead
(
SequencePoolerHead
):
class
EmbeddingPoolerHead
(
SequencePoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
projector
:
ProjectorFn
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Load ST projector if available
self
.
projector
=
projector
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
head_dtype
model_config
=
vllm_config
.
model_config
self
.
activation
=
activation
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"embed"
}
return
{
"embed"
}
...
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
...
@@ -65,7 +58,8 @@ class EmbeddingPoolerHead(SequencePoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
# Apply ST projector
if
self
.
projector
is
not
None
:
if
self
.
projector
is
not
None
:
...
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
...
@@ -88,15 +82,16 @@ class EmbeddingPoolerHead(SequencePoolerHead):
]
]
# for normalize
# for normalize
flags
=
[
p
.
normalize
for
p
in
pooling_params
]
if
self
.
activation
is
not
None
:
if
len
(
set
(
flags
))
==
1
:
flags
=
[
p
.
normalize
for
p
in
pooling_params
]
if
flags
[
0
]:
if
len
(
set
(
flags
))
==
1
:
pooled_data
=
self
.
activation
(
pooled_data
)
if
flags
[
0
]:
else
:
pooled_data
=
self
.
activation
(
pooled_data
)
pooled_data
=
[
else
:
self
.
activation
(
vecs
)
if
f
else
vecs
pooled_data
=
[
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
self
.
activation
(
vecs
)
if
f
else
vecs
]
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
# pooled_data shape: [batchsize, embedding_dimension]
# pooled_data shape: [batchsize, embedding_dimension]
return
pooled_data
return
pooled_data
...
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
...
@@ -106,20 +101,16 @@ class ClassifierPoolerHead(SequencePoolerHead):
def
__init__
(
def
__init__
(
self
,
self
,
classifier
:
ClassifierFn
|
None
=
None
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
logit_bias
:
float
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
logit_bias
=
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
return
{
"classify"
,
"score"
}
...
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
...
@@ -136,7 +127,8 @@ class ClassifierPoolerHead(SequencePoolerHead):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
pooled_data
=
self
.
classifier
(
pooled_data
)
...
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
...
@@ -145,13 +137,15 @@ class ClassifierPoolerHead(SequencePoolerHead):
if
self
.
logit_bias
is
not
None
:
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
pooled_data
-=
self
.
logit_bias
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
self
.
activation
is
not
None
:
if
len
(
set
(
flags
))
==
1
:
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
scores
=
self
.
act_fn
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
if
len
(
set
(
flags
))
==
1
:
else
:
pooled_data
=
self
.
activation
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
scores
=
[
else
:
self
.
act_fn
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
pooled_data
=
[
]
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
#
scores
shape: [batchsize, num_labels]
#
pooled_data
shape: [batchsize, num_labels]
return
scores
return
pooled_data
vllm/model_executor/layers/pooler/seqwise/poolers.py
View file @
8863c2b2
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
import
torch
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.config
import
PoolerConfig
,
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
...
@@ -86,7 +91,14 @@ class SequencePooler(Pooler):
def
pooler_for_embed
(
pooler_config
:
PoolerConfig
):
def
pooler_for_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
head
=
EmbeddingPoolerHead
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
EmbeddingPoolerHead
(
projector
=
_load_st_projector
(
model_config
),
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
...
@@ -101,6 +113,15 @@ def pooler_for_classify(
...
@@ -101,6 +113,15 @@ def pooler_for_classify(
if
pooling
is
None
:
if
pooling
is
None
:
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_seq_pooling_type
())
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
logit_bias
=
model_config
.
pooler_config
.
logit_bias
,
head_dtype
=
model_config
.
head_dtype
,
activation
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
),
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/layers/pooler/tokwise/heads.py
View file @
8863c2b2
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
...
@@ -7,14 +7,7 @@ from typing import TypeAlias
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ActivationFn
,
ClassifierFn
,
ProjectorFn
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
...
@@ -49,17 +42,17 @@ class TokenPoolerHead(nn.Module, ABC):
class
TokenEmbeddingPoolerHead
(
TokenPoolerHead
):
class
TokenEmbeddingPoolerHead
(
TokenPoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
projector
:
ProjectorFn
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Load ST projector if available
self
.
projector
=
projector
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
head_dtype
model_config
=
vllm_config
.
model_config
self
.
activation
=
activation
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
}
return
{
"token_embed"
}
...
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
...
@@ -73,7 +66,8 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
if
pooled_data
is
None
:
if
pooled_data
is
None
:
return
None
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# pooled_data shape: [n_tokens, hidden_dimension]
# pooled_data shape: [n_tokens, hidden_dimension]
# Apply ST projector
# Apply ST projector
...
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
...
@@ -85,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
# for normalize
# for normalize
if
pooling_param
.
normalize
:
if
self
.
activation
is
not
None
and
pooling_param
.
normalize
:
pooled_data
=
self
.
activation
(
pooled_data
)
pooled_data
=
self
.
activation
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
# pooled_data shape: [n_tokens, embedding_dimension]
...
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
...
@@ -96,20 +90,16 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
def
__init__
(
def
__init__
(
self
,
self
,
classifier
:
ClassifierFn
|
None
=
None
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
logit_bias
:
float
|
None
=
None
,
head_dtype
:
torch
.
dtype
|
str
|
None
=
None
,
activation
:
ActivationFn
|
None
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
logit_bias
=
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
head_dtype
=
head_dtype
self
.
activation
=
activation
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_classify"
}
return
{
"token_classify"
}
...
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
...
@@ -123,7 +113,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if
pooled_data
is
None
:
if
pooled_data
is
None
:
return
None
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
head_dtype
is
not
None
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
...
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
...
@@ -135,8 +126,8 @@ class TokenClassifierPoolerHead(TokenPoolerHead):
if
self
.
logit_bias
is
not
None
:
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
if
self
.
activation
is
not
None
and
pooling_param
.
use_activation
:
scores
=
self
.
act
_f
n
(
scores
)
scores
=
self
.
act
ivatio
n
(
scores
)
# scores shape: [n_token, num_labels]
# scores shape: [n_token, num_labels]
return
scores
return
scores
vllm/model_executor/layers/pooler/tokwise/poolers.py
View file @
8863c2b2
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
...
@@ -5,10 +5,15 @@ from typing import TypeAlias
import
torch
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.config
import
PoolerConfig
,
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
...
@@ -86,7 +91,14 @@ class TokenPooler(Pooler):
def
pooler_for_token_embed
(
pooler_config
:
PoolerConfig
):
def
pooler_for_token_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
head
=
TokenEmbeddingPoolerHead
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
TokenEmbeddingPoolerHead
(
projector
=
_load_st_projector
(
model_config
),
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
...
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
...
@@ -101,6 +113,15 @@ def pooler_for_token_classify(
if
pooling
is
None
:
if
pooling
is
None
:
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_tok_pooling_type
())
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
logit_bias
=
model_config
.
pooler_config
.
logit_bias
,
head_dtype
=
model_config
.
head_dtype
,
activation
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
),
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/models/bert.py
View file @
8863c2b2
...
@@ -8,7 +8,7 @@ from torch import nn
...
@@ -8,7 +8,7 @@ from torch import nn
from
transformers
import
BertConfig
from
transformers
import
BertConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
...
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
...
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
Pooler
,
Pooler
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
)
)
from
vllm.model_executor.layers.pooler.activations
import
LambdaPoolerActivation
from
vllm.model_executor.layers.pooler.seqwise
import
(
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolerOutput
,
SequencePoolerOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
get_seq_pooling_method
,
)
)
from
vllm.model_executor.layers.pooler.tokwise
import
(
from
vllm.model_executor.layers.pooler.tokwise
import
(
...
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
...
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
class
BertPooler
(
SequencePooler
):
class
BertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
BertConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
config
:
BertConfig
=
model_config
.
hf_config
super
().
__init__
(
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
),
pooling
=
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
),
head
=
self
.
head
,
# We set this dummy to avoid adding parameters to nn.Module too early
head
=
nn
.
Identity
(),
)
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
nn
.
Tanh
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
def
head
(
config
.
hidden_size
,
self
,
dtype
=
head_dtype
,
pooled_data
:
SequencePoolingMethodOutput
,
)
pooling_metadata
:
PoolingMetadata
,
self
.
act_fn
=
nn
.
Tanh
()
)
->
SequencePoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
self
.
dense
(
pooled_data
)
# Use lambdas so that weights are not registered under `self.head`
pooled_data
=
self
.
activation
(
pooled_data
)
self
.
head
=
EmbeddingPoolerHead
(
return
pooled_data
projector
=
lambda
x
:
self
.
dense
(
x
),
head_dtype
=
head_dtype
,
activation
=
LambdaPoolerActivation
(
self
.
act_fn
),
)
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
...
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
...
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
embedding_class
=
embedding_class
,
embedding_class
=
embedding_class
,
)
)
config
=
vllm_config
.
model_config
.
hf_config
self
.
pooler
=
BertPooler
(
vllm_config
.
model_config
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
config
,
pooler_config
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
other_weights
,
loaded_stacked_params
=
self
.
_load_weights
(
weights
)
other_weights
,
loaded_stacked_params
=
self
.
_load_weights
(
weights
)
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
8863c2b2
...
@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
...
@@ -466,10 +466,7 @@ class BertWithRope(nn.Module, SupportsQuant):
)
)
if
add_pooling_layer
:
if
add_pooling_layer
:
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
pooler
=
BertPooler
(
vllm_config
.
model_config
)
assert
pooler_config
is
not
None
self
.
pooler
=
BertPooler
(
self
.
config
,
pooler_config
)
else
:
else
:
self
.
pooler
=
None
self
.
pooler
=
None
...
...
vllm/model_executor/models/gritlm.py
View file @
8863c2b2
...
@@ -5,7 +5,7 @@ from collections.abc import Set
...
@@ -5,7 +5,7 @@ from collections.abc import Set
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
DispatchPooler
,
...
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
...
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.pooler import (
)
)
from
vllm.model_executor.layers.pooler.activations
import
PoolerNormalize
from
vllm.model_executor.layers.pooler.activations
import
PoolerNormalize
from
vllm.model_executor.layers.pooler.seqwise
import
(
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethod
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
get_seq_pooling_method
,
...
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
...
@@ -178,25 +178,22 @@ class GritLMMeanPool(SequencePoolingMethod):
class
GritLMPooler
(
SequencePooler
):
class
GritLMPooler
(
SequencePooler
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
super
().
__init__
(
super
().
__init__
(
pooling
=
(
pooling
=
(
GritLMMeanPool
(
model_config
)
GritLMMeanPool
(
model_config
)
if
pooler_config
.
seq_pooling_type
==
"MEAN"
if
pooler_config
.
seq_pooling_type
==
"MEAN"
else
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
)
else
get_seq_pooling_method
(
pooler_config
.
seq_pooling_type
)
),
),
head
=
self
.
head
,
head
=
EmbeddingPoolerHead
(
head_dtype
=
model_config
.
head_dtype
,
activation
=
PoolerNormalize
(),
),
)
)
self
.
activation
=
PoolerNormalize
()
def
head
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
@
default_pooling_type
(
seq_pooling_type
=
"MEAN"
)
@
default_pooling_type
(
seq_pooling_type
=
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
class
GritLM
(
LlamaForCausalLM
):
...
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
...
@@ -240,6 +237,6 @@ class GritLM(LlamaForCausalLM):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
,
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
}
}
)
)
vllm/model_executor/models/modernbert.py
View file @
8863c2b2
...
@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
...
@@ -8,17 +8,17 @@ from transformers import ModernBertConfig
from
transformers.activations
import
ACT2FN
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
Pooler
Config
,
VllmConfig
from
vllm.config
import
Model
Config
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
from
vllm.model_executor.layers.attention.encoder_only_attention
import
(
EncoderOnlyAttention
,
EncoderOnlyAttention
,
)
)
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.activations
import
LambdaPoolerActivation
from
vllm.model_executor.layers.pooler.seqwise
import
(
from
vllm.model_executor.layers.pooler.seqwise
import
(
EmbeddingPoolerHead
,
SequencePooler
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
get_seq_pooling_method
,
)
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -26,7 +26,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces_base
import
attn_type
,
default_pooling_type
from
.interfaces_base
import
attn_type
,
default_pooling_type
...
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
...
@@ -282,7 +281,11 @@ class ModernBertModel(nn.Module):
class
ModernBertPooler
(
SequencePooler
):
class
ModernBertPooler
(
SequencePooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
,
pooler_config
:
PoolerConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
pooler_config
=
model_config
.
pooler_config
assert
pooler_config
is
not
None
config
:
ModernBertConfig
=
model_config
.
hf_config
hf_pooling_type
=
config
.
classifier_pooling
.
upper
()
hf_pooling_type
=
config
.
classifier_pooling
.
upper
()
# vllm_pooling_type = pooler_config.seq_pooling_type
# vllm_pooling_type = pooler_config.seq_pooling_type
# Currently we don't have a way to see if the user set the pooling type
# Currently we don't have a way to see if the user set the pooling type
...
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
...
@@ -290,27 +293,30 @@ class ModernBertPooler(SequencePooler):
super
().
__init__
(
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
hf_pooling_type
),
pooling
=
get_seq_pooling_method
(
hf_pooling_type
),
head
=
self
.
head
,
# We set this dummy to avoid adding parameters to nn.Module too early
head
=
nn
.
Identity
(),
)
)
head_dtype
=
model_config
.
head_dtype
self
.
dense
=
nn
.
Linear
(
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
,
dtype
=
head_dtype
,
)
)
self
.
act
=
nn
.
GELU
()
self
.
act
=
nn
.
GELU
()
self
.
norm
=
nn
.
LayerNorm
(
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
,
)
)
def
head
(
# Use lambdas so that weights are not registered under `self.head`
self
,
self
.
head
=
EmbeddingPoolerHead
(
pooled_data
:
SequencePoolingMethodOutput
,
projector
=
lambda
x
:
self
.
dense
(
x
),
pooling_metadata
:
PoolingMetadata
,
head_dtype
=
head_dtype
,
)
->
SequencePoolerHeadOutput
:
activation
=
LambdaPoolerActivation
(
lambda
x
:
self
.
norm
(
self
.
act
(
x
))),
if
isinstance
(
pooled_data
,
list
):
)
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
...
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -335,7 +341,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooling
=
ModernBertPooler
(
config
,
pooler
_config
)
self
.
pooling
=
ModernBertPooler
(
vllm_
config
.
model
_config
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooler_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment