Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8ed39b9
Unverified
Commit
c8ed39b9
authored
Jan 09, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 09, 2026
Browse files
[Model] Reorganize pooling layers (#31973)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
02073280
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
75 additions
and
209 deletions
+75
-209
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+5
-15
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+2
-7
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+2
-14
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+18
-29
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-4
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+2
-14
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+2
-14
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+20
-51
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+4
-7
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+6
-18
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+2
-7
vllm/model_executor/models/terratorch.py
vllm/model_executor/models/terratorch.py
+2
-2
vllm/model_executor/models/transformers/pooling.py
vllm/model_executor/models/transformers/pooling.py
+7
-24
vllm/v1/outputs.py
vllm/v1/outputs.py
+1
-3
No files found.
vllm/model_executor/models/bert_with_rope.py
View file @
c8ed39b9
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
...
...
@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
ClassifierPooler
,
DispatchPooler
,
Pooler
from
.bert
import
BertPooler
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces_base
import
default_pooling_type
...
...
@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
,
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/clip.py
View file @
c8ed39b9
...
...
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert
pooler_config
is
not
None
self
.
pooler_config
=
pooler_config
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
# Assumes that self.forward is called after self.embed_input_ids
self
.
_is_text_input
=
True
...
...
vllm/model_executor/models/gpt2.py
View file @
c8ed39b9
...
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
...
...
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
DispatchPooler
,
Pooler
from
.interfaces
import
SupportsCrossEncoding
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
...
...
@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
transformer
.
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/gritlm.py
View file @
c8ed39b9
...
...
@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
PoolerNormalize
,
PoolingMethod
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.activations
import
PoolerNormalize
from
vllm.model_executor.layers.pooler.seqwise
import
(
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_embed
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces_base
import
default_pooling_type
...
...
@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
logger
=
init_logger
(
__name__
)
class
GritLMMeanPool
(
PoolingMethod
):
class
GritLMMeanPool
(
Sequence
PoolingMethod
):
"""As `MeanPool`, but only includes non-instruction tokens."""
def
__init__
(
self
,
model_config
:
ModelConfig
):
...
...
@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolingMethodOutput
:
)
->
Sequence
PoolingMethodOutput
:
prompt_lens
=
pooling_metadata
.
prompt_lens
instr_lens
=
torch
.
tensor
(
[
...
...
@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
return
pooled_data
class
GritLMPooler
(
Pooler
):
class
GritLMPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
GritLMMeanPool
(
model_config
),
head
=
self
.
head
,
)
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
...
...
@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
if
pooler_config
is
not
None
:
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
P
ooler
.
for_token_embed
(
pooler_config
),
"token_embed"
:
p
ooler
_
for_token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
}
)
vllm/model_executor/models/internlm2.py
View file @
c8ed39b9
...
...
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
.tokwise
import
pooler_for_token_classify
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
def
forward
(
self
,
...
...
vllm/model_executor/models/jamba.py
View file @
c8ed39b9
...
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
...
...
@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
vllm/model_executor/models/jina_vl.py
View file @
c8ed39b9
...
...
@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_config
,
prefix
=
maybe_prefix
(
prefix
,
"score"
)
)
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
...
...
vllm/model_executor/models/modernbert.py
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
,
Set
from
collections.abc
import
Iterable
import
torch
from
torch
import
nn
...
...
@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
QKVParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
DispatchPooler
,
Pooler
,
PoolingMethod
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
(
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
...
...
@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
return
norm_outputs
class
ModernBertPooler
(
Pooler
):
class
ModernBertPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
get_seq_pooling_method
(
config
.
classifier_pooling
.
upper
()),
head
=
self
.
head
,
)
pooling_type
=
config
.
classifier_pooling
.
upper
()
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
)
...
...
@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"CLS"
)
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
...
...
@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
)
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_input_ids
(
input_ids
)
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
c8ed39b9
...
...
@@ -14,7 +14,8 @@ from torch import nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
,
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
Pooler
from
vllm.model_executor.layers.pooler.tokwise
import
pooler_for_token_classify
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
...
...
@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
@
default_pooling_type
(
"STEP"
)
...
...
@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
vllm/model_executor/models/roberta.py
View file @
c8ed39b9
...
...
@@ -8,12 +8,8 @@ from torch import nn
from
transformers
import
RobertaConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.models.bert
import
(
TOKEN_TYPE_SHIFT
,
...
...
@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/siglip.py
View file @
c8ed39b9
...
...
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert
pooler_config
is
not
None
self
.
pooler_config
=
pooler_config
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
self
.
_is_text_input
=
True
...
...
vllm/model_executor/models/terratorch.py
View file @
c8ed39b9
...
...
@@ -34,7 +34,7 @@ from transformers import BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Dumm
yPooler
from
vllm.model_executor.layers.pooler
import
Identit
yPooler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
({
"plugin"
:
Dumm
yPooler
()
})
self
.
pooler
=
Identit
yPooler
()
def
embed_input_ids
(
self
,
...
...
vllm/model_executor/models/transformers/pooling.py
View file @
c8ed39b9
...
...
@@ -22,12 +22,8 @@ import torch
from
transformers
import
AutoModelForSequenceClassification
from
vllm.config.utils
import
getattr_iter
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
...
...
@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
class
SequenceClassificationMixin
(
SupportsCrossEncoding
,
VllmModelForPooling
):
...
...
@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self
.
classifier
.
__class__
=
ClassifierWithReshape
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
)
vllm/v1/outputs.py
View file @
c8ed39b9
...
...
@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
TokenPoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokenwisePoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
TokenwisePoolerOutput
PoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
@
dataclass
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment