Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8ed39b9
Unverified
Commit
c8ed39b9
authored
Jan 09, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 09, 2026
Browse files
[Model] Reorganize pooling layers (#31973)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
02073280
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
75 additions
and
209 deletions
+75
-209
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+5
-15
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+2
-7
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+2
-14
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+18
-29
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-4
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+2
-14
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+2
-14
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+20
-51
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+4
-7
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+6
-18
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+2
-7
vllm/model_executor/models/terratorch.py
vllm/model_executor/models/terratorch.py
+2
-2
vllm/model_executor/models/transformers/pooling.py
vllm/model_executor/models/transformers/pooling.py
+7
-24
vllm/v1/outputs.py
vllm/v1/outputs.py
+1
-3
No files found.
vllm/model_executor/models/bert_with_rope.py
View file @
c8ed39b9
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
...
@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
ClassifierPooler
,
DispatchPooler
,
Pooler
from
.bert
import
BertPooler
from
.bert
import
BertPooler
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces_base
import
default_pooling_type
from
.interfaces_base
import
default_pooling_type
...
@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
pooler_config
,
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
,
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/clip.py
View file @
c8ed39b9
...
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler_config
=
pooler_config
self
.
pooler_config
=
pooler_config
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
# Assumes that self.forward is called after self.embed_input_ids
# Assumes that self.forward is called after self.embed_input_ids
self
.
_is_text_input
=
True
self
.
_is_text_input
=
True
...
...
vllm/model_executor/models/gpt2.py
View file @
c8ed39b9
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
...
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
DispatchPooler
,
Pooler
from
.interfaces
import
SupportsCrossEncoding
,
SupportsPP
from
.interfaces
import
SupportsCrossEncoding
,
SupportsPP
from
.utils
import
(
from
.utils
import
(
AutoWeightsLoader
,
AutoWeightsLoader
,
...
@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
embed_input_ids
(
input_ids
)
return
self
.
transformer
.
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/gritlm.py
View file @
c8ed39b9
...
@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
...
@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
DispatchPooler
,
Pooler
,
PoolerNormalize
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
)
from
vllm.model_executor.layers.pooler.activations
import
PoolerNormalize
from
vllm.model_executor.layers.pooler.seqwise
import
(
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_embed
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces_base
import
default_pooling_type
from
.interfaces_base
import
default_pooling_type
...
@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
...
@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
GritLMMeanPool
(
PoolingMethod
):
class
GritLMMeanPool
(
Sequence
PoolingMethod
):
"""As `MeanPool`, but only includes non-instruction tokens."""
"""As `MeanPool`, but only includes non-instruction tokens."""
def
__init__
(
self
,
model_config
:
ModelConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
...
@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
...
@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolingMethodOutput
:
)
->
Sequence
PoolingMethodOutput
:
prompt_lens
=
pooling_metadata
.
prompt_lens
prompt_lens
=
pooling_metadata
.
prompt_lens
instr_lens
=
torch
.
tensor
(
instr_lens
=
torch
.
tensor
(
[
[
...
@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
...
@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
return
pooled_data
return
pooled_data
class
GritLMPooler
(
Pooler
):
class
GritLMPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
GritLMMeanPool
(
model_config
),
head
=
self
.
head
,
)
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
activation
=
PoolerNormalize
()
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
def
head
(
self
,
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
return
self
.
activation
(
pooled_data
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"MEAN"
)
@
default_pooling_type
(
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
class
GritLM
(
LlamaForCausalLM
):
...
@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
...
@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
if
pooler_config
is
not
None
:
if
pooler_config
is
not
None
:
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"token_embed"
:
P
ooler
.
for_token_embed
(
pooler_config
),
"token_embed"
:
p
ooler
_
for_token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
}
}
)
)
vllm/model_executor/models/internlm2.py
View file @
c8ed39b9
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
.tokwise
import
pooler_for_token_classify
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/model_executor/models/jamba.py
View file @
c8ed39b9
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
MambaStateShapeCalculator
,
)
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
ParallelLMHead
,
...
@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
...
@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
vllm/model_executor/models/jina_vl.py
View file @
c8ed39b9
...
@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
...
@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
...
@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
self
.
score
=
JinaVLScorer
(
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_config
,
prefix
=
maybe_prefix
(
prefix
,
"score"
)
vllm_config
.
model_config
,
prefix
=
maybe_prefix
(
prefix
,
"score"
)
)
)
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
...
...
vllm/model_executor/models/modernbert.py
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Set
from
collections.abc
import
Iterable
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
...
@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
DispatchPooler
ClassifierPooler
,
from
vllm.model_executor.layers.pooler.seqwise
import
(
DispatchPooler
,
SequencePooler
,
Pooler
,
SequencePoolerHeadOutput
,
PoolingMethod
,
SequencePoolingMethodOutput
,
PoolingParamsUpdate
,
get_seq_pooling_method
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
...
@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
...
@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
return
norm_outputs
return
norm_outputs
class
ModernBertPooler
(
Pooler
):
class
ModernBertPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
config
.
classifier_pooling
.
upper
()),
head
=
self
.
head
,
)
pooling_type
=
config
.
classifier_pooling
.
upper
()
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
)
self
.
dense
=
nn
.
Linear
(
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
)
)
...
@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
...
@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
def
head
(
self
,
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
"CLS"
)
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
...
@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
pooler_config
,
"token_classify"
:
Pooler
.
for_token_classify
(
pooling
=
self
.
pooling
,
pooler_config
,
classifier
=
self
.
classifier
classifier
=
self
.
classifier
,
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
)
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
...
@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
return
self
.
model
.
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
c8ed39b9
...
@@ -14,7 +14,8 @@ from torch import nn
...
@@ -14,7 +14,8 @@ from torch import nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
Pooler
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
...
@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
...
@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
@
default_pooling_type
(
"STEP"
)
@
default_pooling_type
(
"STEP"
)
...
@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
...
@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
vllm/model_executor/models/roberta.py
View file @
c8ed39b9
...
@@ -8,12 +8,8 @@ from torch import nn
...
@@ -8,12 +8,8 @@ from torch import nn
from
transformers
import
RobertaConfig
from
transformers
import
RobertaConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
DispatchPooler
ClassifierPooler
,
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
CLSPool
,
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.models.bert
import
(
from
vllm.model_executor.models.bert
import
(
TOKEN_TYPE_SHIFT
,
TOKEN_TYPE_SHIFT
,
...
@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
pooler_config
,
"token_classify"
:
Pooler
.
for_token_classify
(
pooling
=
CLSPool
(),
pooler_config
=
pooler_config
,
classifier
=
self
.
classifier
classifier
=
self
.
classifier
,
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/siglip.py
View file @
c8ed39b9
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler_config
=
pooler_config
self
.
pooler_config
=
pooler_config
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
_is_text_input
=
True
self
.
_is_text_input
=
True
...
...
vllm/model_executor/models/terratorch.py
View file @
c8ed39b9
...
@@ -34,7 +34,7 @@ from transformers import BatchFeature
...
@@ -34,7 +34,7 @@ from transformers import BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Dumm
yPooler
from
vllm.model_executor.layers.pooler
import
Identit
yPooler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
...
@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
({
"plugin"
:
Dumm
yPooler
()
})
self
.
pooler
=
Identit
yPooler
()
def
embed_input_ids
(
def
embed_input_ids
(
self
,
self
,
...
...
vllm/model_executor/models/transformers/pooling.py
View file @
c8ed39b9
...
@@ -22,12 +22,8 @@ import torch
...
@@ -22,12 +22,8 @@ import torch
from
transformers
import
AutoModelForSequenceClassification
from
transformers
import
AutoModelForSequenceClassification
from
vllm.config.utils
import
getattr_iter
from
vllm.config.utils
import
getattr_iter
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
DispatchPooler
ClassifierPooler
,
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
CLSPool
,
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
...
@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
...
@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
class
SequenceClassificationMixin
(
SupportsCrossEncoding
,
VllmModelForPooling
):
class
SequenceClassificationMixin
(
SupportsCrossEncoding
,
VllmModelForPooling
):
...
@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
...
@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self
.
classifier
.
__class__
=
ClassifierWithReshape
self
.
classifier
.
__class__
=
ClassifierWithReshape
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
pooler_config
,
"token_classify"
:
Pooler
.
for_token_classify
(
pooling
=
CLSPool
(),
pooler_config
,
classifier
=
self
.
classifier
classifier
=
self
.
classifier
,
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
)
)
vllm/v1/outputs.py
View file @
c8ed39b9
...
@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
...
@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
# The shape of each element depends on the pooler used
TokenPoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
PoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
TokenwisePoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
TokenwisePoolerOutput
@
dataclass
@
dataclass
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment