Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90bd2ab6
Unverified
Commit
90bd2ab6
authored
Jul 18, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 17, 2025
Browse files
[Model] Update pooling model interface (#21058)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
9fb2d220
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
247 additions
and
345 deletions
+247
-345
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+5
-10
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-29
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+112
-64
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+8
-23
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+18
-19
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+4
-10
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+3
-9
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+12
-74
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+16
-17
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+4
-10
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+4
-10
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+4
-11
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+12
-12
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+8
-12
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+9
-14
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+3
-10
vllm/pooling_params.py
vllm/pooling_params.py
+20
-11
No files found.
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
90bd2ab6
...
...
@@ -11,11 +11,13 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
class
MyGemma2Embedding
(
nn
.
Module
):
is_pooling_model
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
...
...
@@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
90bd2ab6
...
...
@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:embedding-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:embedding-pooling-params]
# --8<-- [start:embedding-extra-params]
add_special_tokens
:
bool
=
Field
(
default
=
True
,
...
...
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# --8<-- [end:embedding-extra-params]
def
to_pooling_params
(
self
):
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
...
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:chat-embedding-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:chat-embedding-pooling-params]
# --8<-- [start:chat-embedding-extra-params]
add_special_tokens
:
bool
=
Field
(
default
=
False
,
...
...
@@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
def
to_pooling_params
(
self
):
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
dimensions
=
self
.
dimensions
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
...
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
text_2
:
Union
[
list
[
str
],
str
,
ScoreMultiModalParam
]
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:score-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:score-pooling-params]
# --8<-- [start:score-extra-params]
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
...
...
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
,
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
class
RerankRequest
(
OpenAIBaseModel
):
...
...
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:rerank-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:rerank-pooling-params]
# --8<-- [start:rerank-extra-params]
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
...
...
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
,
additional_data
=
self
.
additional_data
)
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
class
RerankDocument
(
BaseModel
):
...
...
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
truncate_prompt_tokens
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
# --8<-- [start:classification-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:classification-pooling-params]
# --8<-- [start:classification-extra-params]
priority
:
int
=
Field
(
default
=
0
,
...
...
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
# --8<-- [end:classification-extra-params]
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
()
class
ClassificationData
(
OpenAIBaseModel
):
...
...
vllm/model_executor/layers/pooler.py
View file @
90bd2ab6
...
...
@@ -3,22 +3,25 @@
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
PoolingMetadata
as
V0PoolingMetadata
)
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingTask
=
Literal
[
"encode"
,
"embed"
,
"classify"
,
"score"
]
class
PoolingType
(
IntEnum
):
...
...
@@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
)
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
from_config_with_defaults
(
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
"Pooler"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
,
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPooler
.
from_config
(
resolved_config
)
return
SimplePooler
.
from_config
(
resolved_config
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return
None
@
abstractmethod
def
forward
(
self
,
hidden_states
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
def
get_prompt_lens
(
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
...
...
@@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
return
PoolerOutput
(
outputs
=
all_outputs
)
class
BasePooler
(
nn
.
Module
):
@
abstractmethod
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
...
...
@@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
@
abstractmethod
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
raise
NotImplementedError
@
abstractmethod
def
forward_one
(
self
,
...
...
@@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
class
CLSPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
class
LastPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
class
AllPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
()
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
return
None
assert_never
(
task
)
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
class
MeanPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
class
PoolerHead
(
nn
.
Module
):
@
classmethod
def
from_config_with_defaults
(
cls
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
)
->
"PoolerHead"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
None
,
returned_token_ids
=
None
,
)
return
cls
.
from_config
(
resolved_config
)
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"PoolerHead"
:
if
pooler_config
.
normalize
and
pooler_config
.
softmax
:
...
...
@@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
return
self
.
activation
(
pooled_data
)
class
SimplePooler
(
Base
Pooler
):
class
SimplePooler
(
Pooler
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
@
classmethod
def
from_config_with_defaults
(
def
from_config_with_defaults
(
# type: ignore[override]
cls
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
...
...
@@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
self
.
pooling
=
pooling
self
.
head
=
head
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
@@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
return
build_output
(
pooled_data
)
class
StepPooler
(
Base
Pooler
):
class
StepPooler
(
Pooler
):
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"StepPooler"
:
...
...
@@ -543,6 +595,16 @@ class StepPooler(BasePooler):
return
pooled_data
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
(
logits_processing_needs_token_ids
=
True
)
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
return
None
assert_never
(
task
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
@@ -553,32 +615,6 @@ class StepPooler(BasePooler):
return
build_output
(
pooled_data
)
class
Pooler
(
nn
.
Module
):
@
staticmethod
def
from_config_with_defaults
(
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
BasePooler
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
,
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPooler
.
from_config
(
resolved_config
)
return
SimplePooler
.
from_config
(
resolved_config
)
PoolingFn
=
Callable
[
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
...
...
@@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
return
(
self
.
cross_encoder_act_fn
if
use_cross_encoder
else
self
.
classification_act_fn
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
()
if
task
==
"embed"
:
return
None
if
task
==
"classify"
:
return
PoolingParams
()
if
task
==
"score"
:
return
PoolingParams
(
use_cross_encoder
=
True
)
assert_never
(
task
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
vllm/model_executor/models/adapters.py
View file @
90bd2ab6
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
cast
import
torch
import
torch.nn
as
nn
...
...
@@ -42,13 +42,14 @@ def _create_pooling_model_cls(
default_softmax
:
bool
,
)
->
_T
:
# Lazy import
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForPooling
(
orig_cls
,
VllmModelForPooling
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
...
...
@@ -66,27 +67,20 @@ def _create_pooling_model_cls(
delattr
(
self
,
attr
)
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"
_
pooler"
,
None
):
if
not
getattr
(
self
,
"pooler"
,
None
):
self
.
_init_pooler
(
vllm_config
,
prefix
=
prefix
)
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
default_pooling_type
,
normalize
=
default_normalize
,
softmax
=
default_softmax
,
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
# TODO: Support uninitialized params tracking
...
...
@@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
PoolerOutput
,
PoolingType
,
SimplePooler
)
PoolingType
,
SimplePooler
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
...
...
@@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
softmax
=
True
,
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
_classifier
,
...
...
@@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
method
=
getattr
(
self
.
config
,
"method"
,
None
)
...
...
vllm/model_executor/models/bert.py
View file @
90bd2ab6
...
...
@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingType
)
PoolingMethod
,
PoolingTask
,
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
...
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
return
embeddings
class
BertPooler
(
nn
.
Modu
le
):
class
BertPooler
(
Poo
le
r
):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
...
...
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
is_pooling_model
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
def
__init__
(
self
,
...
...
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_
pooler
=
self
.
_build_pooler
(
pooler_config
)
self
.
pooler
=
self
.
_build_pooler
(
pooler_config
)
def
forward
(
self
,
...
...
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
weights_list
=
list
(
weights
)
...
...
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class
=
BertEmbedding
,
add_pooling_layer
=
True
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
...
...
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
loaded_params
=
loader
.
load_weights
(
weights
)
return
loaded_params
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/models/gpt2.py
View file @
90bd2ab6
...
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
Pooler
,
PoolingType
from
.interfaces
import
SupportsPP
...
...
@@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
...
...
@@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gritlm.py
View file @
90bd2ab6
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
array
import
array
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
- "<|user|>
\n
PROMPT
\n
<|assistant|>
\n
"
"""
is_pooling_model
=
True
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
...
...
@@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
self
.
pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
vllm/model_executor/models/interfaces.py
View file @
90bd2ab6
...
...
@@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_SupportsMultiModalType
(
Protocol
):
supports_multimodal
:
Literal
[
True
]
@
overload
def
supports_multimodal
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsMultiModal
]]:
...
...
@@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def
supports_multimodal
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsMultiModal
]],
TypeIs
[
SupportsMultiModal
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsMultiModalType
)
return
isinstance
(
model
,
SupportsMultiModal
)
return
getattr
(
model
,
"supports_multimodal"
,
False
)
@
runtime_checkable
...
...
@@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_SupportsScoreTemplateType
(
Protocol
):
supports_score_template
:
Literal
[
True
]
@
overload
def
supports_score_template
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsScoreTemplate
]]:
...
...
@@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
def
supports_score_template
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsScoreTemplate
]],
TypeIs
[
SupportsScoreTemplate
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsScoreTemplateType
)
return
isinstance
(
model
,
SupportsScoreTemplate
)
return
getattr
(
model
,
"supports_score_template"
,
False
)
@
runtime_checkable
...
...
@@ -409,11 +388,6 @@ class HasInnerState(Protocol):
"""
@
runtime_checkable
class
_HasInnerStateType
(
Protocol
):
has_inner_state
:
ClassVar
[
Literal
[
True
]]
@
overload
def
has_inner_state
(
model
:
object
)
->
TypeIs
[
HasInnerState
]:
...
...
...
@@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
def
has_inner_state
(
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
HasInnerState
]],
TypeIs
[
HasInnerState
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_HasInnerStateType
)
return
isinstance
(
model
,
HasInnerState
)
return
getattr
(
model
,
"has_inner_state"
,
False
)
@
runtime_checkable
...
...
@@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
"""
@
runtime_checkable
class
_IsAttentionFreeType
(
Protocol
):
is_attention_free
:
ClassVar
[
Literal
[
True
]]
@
overload
def
is_attention_free
(
model
:
object
)
->
TypeIs
[
IsAttentionFree
]:
...
...
...
@@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
def
is_attention_free
(
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
IsAttentionFree
]],
TypeIs
[
IsAttentionFree
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_IsAttentionFreeType
)
return
isinstance
(
model
,
IsAttentionFree
)
return
getattr
(
model
,
"is_attention_free"
,
False
)
@
runtime_checkable
...
...
@@ -502,11 +465,6 @@ class IsHybrid(Protocol):
...
@
runtime_checkable
class
_IsHybridType
(
Protocol
):
is_hybrid
:
ClassVar
[
Literal
[
True
]]
@
overload
def
is_hybrid
(
model
:
object
)
->
TypeIs
[
IsHybrid
]:
...
...
...
@@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
def
is_hybrid
(
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
IsHybrid
]],
TypeIs
[
IsHybrid
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_IsHybridType
)
return
isinstance
(
model
,
IsHybrid
)
return
getattr
(
model
,
"is_hybrid"
,
False
)
@
runtime_checkable
...
...
@@ -598,11 +553,6 @@ class HasNoOps(Protocol):
has_noops
:
ClassVar
[
Literal
[
True
]]
=
True
@
runtime_checkable
class
_HasNoOpsType
(
Protocol
):
has_noops
:
ClassVar
[
Literal
[
True
]]
@
overload
def
has_noops
(
model
:
object
)
->
TypeIs
[
HasNoOps
]:
...
...
...
@@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
def
has_noops
(
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
HasNoOps
]],
TypeIs
[
HasNoOps
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_HasNoOpsType
)
return
isinstance
(
model
,
HasNoOps
)
return
getattr
(
model
,
"has_noops"
,
False
)
@
runtime_checkable
...
...
@@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def
_supports_cross_encoding
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
SupportsCrossEncoding
)
return
isinstance
(
model
,
SupportsCrossEncoding
)
return
getattr
(
model
,
"supports_cross_encoding"
,
False
)
def
supports_cross_encoding
(
...
...
@@ -658,8 +601,9 @@ def supports_cross_encoding(
def
has_step_pooler
(
model
:
Union
[
type
[
object
],
object
])
->
bool
:
"""Check if the model uses step pooler."""
return
is_pooling_model
(
model
)
and
any
(
type
(
module
).
__name__
==
"StepPooler"
for
module
in
model
.
modules
())
from
vllm.model_executor.layers.pooler
import
StepPooler
return
is_pooling_model
(
model
)
and
isinstance
(
model
.
pooler
,
StepPooler
)
class
SupportsQuant
:
...
...
@@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def
supports_transcription
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsTranscription
]],
TypeIs
[
SupportsTranscription
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
SupportsTranscription
)
return
isinstance
(
model
,
SupportsTranscription
)
return
getattr
(
model
,
"supports_transcription"
,
False
)
@
runtime_checkable
...
...
@@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def
supports_v0_only
(
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsV0Only
]],
TypeIs
[
SupportsV0Only
]]:
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
SupportsV0Only
)
return
isinstance
(
model
,
SupportsV0Only
)
return
getattr
(
model
,
"supports_v0_only"
,
False
)
vllm/model_executor/models/interfaces_base.py
View file @
90bd2ab6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
(
TYPE_CHECKING
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
Union
,
overload
,
runtime_checkable
)
import
torch
import
torch.nn
as
nn
...
...
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
PoolerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.layers.pooler
import
Pooler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
...
...
@@ -130,16 +128,20 @@ def is_text_generation_model(
@
runtime_checkable
class
VllmModelForPooling
(
VllmModel
[
T
],
Protocol
[
T
]):
class
VllmModelForPooling
(
VllmModel
[
T
_co
],
Protocol
[
T
_co
]):
"""The interface required for all pooling models in vLLM."""
def
pooler
(
self
,
hidden_states
:
T
,
pooling_metadata
:
"PoolingMetadata"
,
)
->
"PoolerOutput"
:
"""Only called on TP rank 0."""
...
is_pooling_model
:
ClassVar
[
Literal
[
True
]]
=
True
"""
A flag that indicates this model supports pooling.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
pooler
:
"Pooler"
"""The pooler is only called on TP rank 0."""
@
overload
...
...
@@ -158,7 +160,4 @@ def is_pooling_model(
if
not
is_vllm_model
(
model
):
return
False
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
VllmModelForPooling
)
return
isinstance
(
model
,
VllmModelForPooling
)
return
getattr
(
model
,
"is_pooling_model"
,
False
)
vllm/model_executor/models/internlm2.py
View file @
90bd2ab6
...
...
@@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
...
...
@@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class
InternLM2ForRewardModel
(
InternLM2ForCausalLM
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
...
...
@@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
ALL
,
normalize
=
False
,
...
...
@@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/jamba.py
View file @
90bd2ab6
...
...
@@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
...
...
@@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
class
JambaForSequenceClassification
(
JambaForCausalLM
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
...
...
@@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
softmax
=
False
,
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
score
,
act_fn
=
pooler
.
head
.
activation
,
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/jina_vl.py
View file @
90bd2ab6
...
...
@@ -13,9 +13,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsMultiModal
,
SupportsScoreTemplate
)
...
...
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
SupportsCrossEncoding
,
SupportsMultiModal
,
SupportsScoreTemplate
):
is_pooling_model
=
True
weight_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"score.0."
:
"score.dense."
,
...
...
@@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self
.
score
=
JinaVLScorer
(
config
)
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
...
...
@@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
logits
=
self
.
score
(
hidden_states
)
-
self
.
LOGIT_BIAS
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
weight_mapper
)
vllm/model_executor/models/modernbert.py
View file @
90bd2ab6
...
...
@@ -13,14 +13,16 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
BasePooler
,
ClassifierPooler
,
PoolingMethod
,
PoolingType
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingTask
,
PoolingType
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.utils
import
WeightsMapper
,
maybe_prefix
...
...
@@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
return
norm_outputs
class
ModernBertPooler
(
Base
Pooler
):
class
ModernBertPooler
(
Pooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
...
...
@@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
@@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
SupportsCrossEncoding
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
@@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
ModernBertPooler
(
config
),
classifier
=
self
.
classifier
,
...
...
@@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
],
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
90bd2ab6
...
...
@@ -24,12 +24,13 @@ import torch.nn as nn
from
transformers
import
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
AllPool
,
PoolerHead
,
PoolerIdentity
,
SimplePooler
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
)
...
...
@@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
PoolingSequenceGroupOutput
)
from
vllm.sequence
import
IntermediateTensors
class
PrithviGeoSpatialMAEProcessingInfo
(
BaseProcessingInfo
):
...
...
@@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
):
""" Prithvi Masked Autoencoder"""
"""Prithvi Masked Autoencoder"""
is_pooling_model
=
True
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
...
@@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE."
)
self
.
pooler
=
SimplePooler
(
AllPool
(),
PoolerHead
(
PoolerIdentity
()))
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
...
...
@@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
):
pixel_values
,
location_coords
=
(
self
.
_parse_and_validate_multimodal_data
(
**
kwargs
))
model_output
=
self
.
model
(
pixel_values
,
...
...
@@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return
model_output
.
output
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
PoolerOutput
([
PoolingSequenceGroupOutput
(
hidden_states
)])
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
params_list
=
[]
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
90bd2ab6
...
...
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
,
SimplePooler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2Model
...
...
@@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
class
Qwen2RewardBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
is_pooling_model
=
True
pooler
:
SimplePooler
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
@@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
quant_config
=
quant_config
,
return_bias
=
False
),
)
self
.
_pooler
:
SimplePooler
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
...
...
@@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
logits
=
self
.
score
(
hidden_states
)
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
...
...
@@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class
Qwen2ForRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
,
prefix
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
1
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
ALL
,
normalize
=
False
,
...
...
@@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
class
Qwen2ForProcessRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
,
prefix
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
2
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
STEP
,
normalize
=
False
,
...
...
vllm/model_executor/models/roberta.py
View file @
90bd2ab6
...
...
@@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
...
...
@@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model
=
True
jina_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
'emb_ln'
:
"embeddings.LayerNorm"
,
...
...
@@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer
=
False
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
...
...
@@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
jina_to_vllm_mapper
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/pooling_params.py
View file @
90bd2ab6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
msgspec
...
...
@@ -15,24 +15,31 @@ class PoolingParams(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""API parameters for pooling models. This
is currently a placeholder.
"""API parameters for pooling models.
Attributes:
dimensions: Reduce the dimensions of embeddings
if the model supports Matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions
:
Optional
[
int
]
=
None
use_cross_encoder
:
bool
=
False
additional_data
:
Optional
[
Any
]
=
None
"""Internal use only."""
logits_processing_needs_token_ids
:
bool
=
False
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
use_cross_encoder
=
self
.
use_cross_encoder
,
additional_data
=
self
.
additional_data
)
logits_processing_needs_token_ids
=
self
.
logits_processing_needs_token_ids
,
)
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
self
.
dimensions
is
not
None
:
...
...
@@ -54,10 +61,12 @@ class PoolingParams(
raise
ValueError
(
"Dimensions must be greater than 0"
)
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
return
(
f
"PoolingParams("
f
"dimensions=
{
self
.
dimensions
}
, "
f
"use_cross_encoder=
{
self
.
use_cross_encoder
}
, "
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
f
"logits_processing_needs_token_ids=
{
self
.
logits_processing_needs_token_ids
}
)"
)
def
__post_init__
(
self
)
->
None
:
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment