Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90bd2ab6
Unverified
Commit
90bd2ab6
authored
Jul 18, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 17, 2025
Browse files
[Model] Update pooling model interface (#21058)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
9fb2d220
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
247 additions
and
345 deletions
+247
-345
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+5
-10
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+5
-29
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+112
-64
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+8
-23
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+18
-19
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+4
-10
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+3
-9
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+12
-74
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+16
-17
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+4
-10
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+4
-10
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+4
-11
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+12
-12
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+8
-12
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+9
-14
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+3
-10
vllm/pooling_params.py
vllm/pooling_params.py
+20
-11
No files found.
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
90bd2ab6
...
@@ -11,11 +11,13 @@ from vllm.config import VllmConfig
...
@@ -11,11 +11,13 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
class
MyGemma2Embedding
(
nn
.
Module
):
class
MyGemma2Embedding
(
nn
.
Module
):
is_pooling_model
=
True
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
...
@@ -24,7 +26,7 @@ class MyGemma2Embedding(nn.Module):
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
self
.
model
=
Gemma2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
vllm_config
.
model_config
.
pooler_config
,
vllm_config
.
model_config
.
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
,
normalize
=
True
,
...
@@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
...
@@ -54,13 +56,6 @@ class MyGemma2Embedding(nn.Module):
# Return all-zero embeddings
# Return all-zero embeddings
return
torch
.
zeros_like
(
hidden_states
)
return
torch
.
zeros_like
(
hidden_states
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
weights
=
self
.
hf_to_vllm_mapper
.
apply
(
weights
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
90bd2ab6
...
@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
...
@@ -1237,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
user
:
Optional
[
str
]
=
None
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:embedding-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:embedding-pooling-params]
# --8<-- [start:embedding-extra-params]
# --8<-- [start:embedding-extra-params]
add_special_tokens
:
bool
=
Field
(
add_special_tokens
:
bool
=
Field
(
default
=
True
,
default
=
True
,
...
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
...
@@ -1259,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# --8<-- [end:embedding-extra-params]
# --8<-- [end:embedding-extra-params]
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
return
PoolingParams
(
dimensions
=
self
.
dimensions
)
additional_data
=
self
.
additional_data
)
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
class
EmbeddingChatRequest
(
OpenAIBaseModel
):
...
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1272,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
user
:
Optional
[
str
]
=
None
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:chat-embedding-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:chat-embedding-pooling-params]
# --8<-- [start:chat-embedding-extra-params]
# --8<-- [start:chat-embedding-extra-params]
add_special_tokens
:
bool
=
Field
(
add_special_tokens
:
bool
=
Field
(
default
=
False
,
default
=
False
,
...
@@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1323,8 +1314,7 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return
data
return
data
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
return
PoolingParams
(
dimensions
=
self
.
dimensions
)
additional_data
=
self
.
additional_data
)
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
EmbeddingRequest
=
Union
[
EmbeddingCompletionRequest
,
EmbeddingChatRequest
]
...
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
...
@@ -1340,10 +1330,6 @@ class ScoreRequest(OpenAIBaseModel):
text_2
:
Union
[
list
[
str
],
str
,
ScoreMultiModalParam
]
text_2
:
Union
[
list
[
str
],
str
,
ScoreMultiModalParam
]
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:score-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:score-pooling-params]
# --8<-- [start:score-extra-params]
# --8<-- [start:score-extra-params]
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
...
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
...
@@ -1362,8 +1348,7 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
# --8<-- [end:score-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
,
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
additional_data
=
self
.
additional_data
)
class
RerankRequest
(
OpenAIBaseModel
):
class
RerankRequest
(
OpenAIBaseModel
):
...
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
...
@@ -1373,10 +1358,6 @@ class RerankRequest(OpenAIBaseModel):
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
top_n
:
int
=
Field
(
default_factory
=
lambda
:
0
)
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:rerank-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:rerank-pooling-params]
# --8<-- [start:rerank-extra-params]
# --8<-- [start:rerank-extra-params]
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
...
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
...
@@ -1395,8 +1376,7 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
# --8<-- [end:rerank-extra-params]
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
def
to_pooling_params
(
self
,
*
,
use_cross_encoder
:
bool
=
False
):
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
,
return
PoolingParams
(
use_cross_encoder
=
use_cross_encoder
)
additional_data
=
self
.
additional_data
)
class
RerankDocument
(
BaseModel
):
class
RerankDocument
(
BaseModel
):
...
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
...
@@ -1534,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel):
truncate_prompt_tokens
:
Optional
[
int
]
=
None
truncate_prompt_tokens
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
user
:
Optional
[
str
]
=
None
# --8<-- [start:classification-pooling-params]
additional_data
:
Optional
[
Any
]
=
None
# --8<-- [end:classification-pooling-params]
# --8<-- [start:classification-extra-params]
# --8<-- [start:classification-extra-params]
priority
:
int
=
Field
(
priority
:
int
=
Field
(
default
=
0
,
default
=
0
,
...
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
...
@@ -1550,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel):
# --8<-- [end:classification-extra-params]
# --8<-- [end:classification-extra-params]
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
return
PoolingParams
()
class
ClassificationData
(
OpenAIBaseModel
):
class
ClassificationData
(
OpenAIBaseModel
):
...
...
vllm/model_executor/layers/pooler.py
View file @
90bd2ab6
...
@@ -3,22 +3,25 @@
...
@@ -3,22 +3,25 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
enum
import
IntEnum
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
from
typing
import
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
typing_extensions
import
assert_never
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
PoolingMetadata
as
V0PoolingMetadata
)
PoolingMetadata
as
V0PoolingMetadata
)
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingTask
=
Literal
[
"encode"
,
"embed"
,
"classify"
,
"score"
]
class
PoolingType
(
IntEnum
):
class
PoolingType
(
IntEnum
):
...
@@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
...
@@ -64,6 +67,48 @@ class ResolvedPoolingConfig:
)
)
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
from_config_with_defaults
(
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
"Pooler"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
,
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPooler
.
from_config
(
resolved_config
)
return
SimplePooler
.
from_config
(
resolved_config
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
"""
Construct the pooling parameters to use for a task,
or `None` if the task is not supported.
"""
return
None
@
abstractmethod
def
forward
(
self
,
hidden_states
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
def
get_prompt_lens
(
def
get_prompt_lens
(
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
...
@@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
...
@@ -104,17 +149,6 @@ def build_output(all_data: torch.Tensor) -> PoolerOutput:
return
PoolerOutput
(
outputs
=
all_outputs
)
return
PoolerOutput
(
outputs
=
all_outputs
)
class
BasePooler
(
nn
.
Module
):
@
abstractmethod
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
class
PoolingMethod
(
nn
.
Module
,
ABC
):
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
@
staticmethod
...
@@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -130,6 +164,10 @@ class PoolingMethod(nn.Module, ABC):
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
@
abstractmethod
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
forward_one
(
def
forward_one
(
self
,
self
,
...
@@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -168,6 +206,14 @@ class PoolingMethod(nn.Module, ABC):
class
CLSPool
(
PoolingMethod
):
class
CLSPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
def
forward_one
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
...
@@ -190,6 +236,14 @@ class CLSPool(PoolingMethod):
class
LastPool
(
PoolingMethod
):
class
LastPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
def
forward_one
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
...
@@ -208,6 +262,16 @@ class LastPool(PoolingMethod):
class
AllPool
(
PoolingMethod
):
class
AllPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
()
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
return
None
assert_never
(
task
)
def
forward_one
(
def
forward_one
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
...
@@ -235,6 +299,14 @@ class AllPool(PoolingMethod):
class
MeanPool
(
PoolingMethod
):
class
MeanPool
(
PoolingMethod
):
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
# The equalities are split up to keep mypy happy
if
(
task
==
"encode"
or
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
):
return
PoolingParams
()
assert_never
(
task
)
def
forward_one
(
def
forward_one
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
...
@@ -345,25 +417,6 @@ class LambdaPoolerActivation(PoolerActivation):
class
PoolerHead
(
nn
.
Module
):
class
PoolerHead
(
nn
.
Module
):
@
classmethod
def
from_config_with_defaults
(
cls
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
)
->
"PoolerHead"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
None
,
returned_token_ids
=
None
,
)
return
cls
.
from_config
(
resolved_config
)
@
classmethod
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"PoolerHead"
:
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"PoolerHead"
:
if
pooler_config
.
normalize
and
pooler_config
.
softmax
:
if
pooler_config
.
normalize
and
pooler_config
.
softmax
:
...
@@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
...
@@ -424,21 +477,17 @@ class PoolerHead(nn.Module):
return
self
.
activation
(
pooled_data
)
return
self
.
activation
(
pooled_data
)
class
SimplePooler
(
Base
Pooler
):
class
SimplePooler
(
Pooler
):
"""A layer that pools specific information from hidden states.
"""A layer that pools specific information from hidden states.
This layer does the following:
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
"""
@
classmethod
@
classmethod
def
from_config_with_defaults
(
def
from_config_with_defaults
(
# type: ignore[override]
cls
,
cls
,
pooler_config
:
PoolerConfig
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
pooling_type
:
PoolingType
,
...
@@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
...
@@ -471,6 +520,9 @@ class SimplePooler(BasePooler):
self
.
pooling
=
pooling
self
.
pooling
=
pooling
self
.
head
=
head
self
.
head
=
head
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
@@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
...
@@ -481,7 +533,7 @@ class SimplePooler(BasePooler):
return
build_output
(
pooled_data
)
return
build_output
(
pooled_data
)
class
StepPooler
(
Base
Pooler
):
class
StepPooler
(
Pooler
):
@
classmethod
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"StepPooler"
:
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"StepPooler"
:
...
@@ -543,6 +595,16 @@ class StepPooler(BasePooler):
...
@@ -543,6 +595,16 @@ class StepPooler(BasePooler):
return
pooled_data
return
pooled_data
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
(
logits_processing_needs_token_ids
=
True
)
# The equalities are split up to keep mypy happy
if
task
==
"embed"
or
task
==
"classify"
or
task
==
"score"
:
return
None
assert_never
(
task
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
@@ -553,32 +615,6 @@ class StepPooler(BasePooler):
...
@@ -553,32 +615,6 @@ class StepPooler(BasePooler):
return
build_output
(
pooled_data
)
return
build_output
(
pooled_data
)
class
Pooler
(
nn
.
Module
):
@
staticmethod
def
from_config_with_defaults
(
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
BasePooler
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
,
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPooler
.
from_config
(
resolved_config
)
return
SimplePooler
.
from_config
(
resolved_config
)
PoolingFn
=
Callable
[
PoolingFn
=
Callable
[
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
...
@@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
...
@@ -618,6 +654,18 @@ class ClassifierPooler(nn.Module):
return
(
self
.
cross_encoder_act_fn
return
(
self
.
cross_encoder_act_fn
if
use_cross_encoder
else
self
.
classification_act_fn
)
if
use_cross_encoder
else
self
.
classification_act_fn
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
if
task
==
"encode"
:
return
PoolingParams
()
if
task
==
"embed"
:
return
None
if
task
==
"classify"
:
return
PoolingParams
()
if
task
==
"score"
:
return
PoolingParams
(
use_cross_encoder
=
True
)
assert_never
(
task
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
...
vllm/model_executor/models/adapters.py
View file @
90bd2ab6
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
TypeVar
,
cast
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -42,13 +42,14 @@ def _create_pooling_model_cls(
...
@@ -42,13 +42,14 @@ def _create_pooling_model_cls(
default_softmax
:
bool
,
default_softmax
:
bool
,
)
->
_T
:
)
->
_T
:
# Lazy import
# Lazy import
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolerOutput
from
vllm.model_executor.layers.pooler
import
Pooler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
class
ModelForPooling
(
orig_cls
,
VllmModelForPooling
):
class
ModelForPooling
(
orig_cls
,
VllmModelForPooling
):
is_pooling_model
=
True
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
...
@@ -66,27 +67,20 @@ def _create_pooling_model_cls(
...
@@ -66,27 +67,20 @@ def _create_pooling_model_cls(
delattr
(
self
,
attr
)
delattr
(
self
,
attr
)
# If the model already defines a pooler instance, don't overwrite it
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"
_
pooler"
,
None
):
if
not
getattr
(
self
,
"pooler"
,
None
):
self
.
_init_pooler
(
vllm_config
,
prefix
=
prefix
)
self
.
_init_pooler
(
vllm_config
,
prefix
=
prefix
)
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
default_pooling_type
,
pooling_type
=
default_pooling_type
,
normalize
=
default_normalize
,
normalize
=
default_normalize
,
softmax
=
default_softmax
,
softmax
=
default_softmax
,
)
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
# TODO: Support uninitialized params tracking
# TODO: Support uninitialized params tracking
...
@@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -171,10 +165,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
# Lazy import
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
PoolerOutput
,
PoolingType
,
PoolingType
,
SimplePooler
)
SimplePooler
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
...
@@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -213,7 +205,7 @@ def as_seq_cls_model(cls: _T) -> _T:
softmax
=
True
,
softmax
=
True
,
)
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
_classifier
,
classifier
=
self
.
_classifier
,
...
@@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -234,13 +226,6 @@ def as_seq_cls_model(cls: _T) -> _T:
return
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
return
super
().
forward
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
)
inputs_embeds
)
def
pooler
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
method
=
getattr
(
self
.
config
,
"method"
,
None
)
method
=
getattr
(
self
.
config
,
"method"
,
None
)
...
...
vllm/model_executor/models/bert.py
View file @
90bd2ab6
...
@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingType
)
PoolingMethod
,
PoolingTask
,
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
,
SupportsV0Only
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
...
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
return
embeddings
return
embeddings
class
BertPooler
(
nn
.
Modu
le
):
class
BertPooler
(
Poo
le
r
):
def
__init__
(
self
,
config
:
BertConfig
):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
super
().
__init__
()
...
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
...
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
self
.
activation
=
nn
.
Tanh
()
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
...
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
is_pooling_model
=
True
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
packed_modules_mapping
=
{
"qkv_proj"
:
[
"query"
,
"key"
,
"value"
]}
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
self
.
model
=
self
.
_build_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
prefix
=
maybe_prefix
(
prefix
,
"model"
))
self
.
_
pooler
=
self
.
_build_pooler
(
pooler_config
)
self
.
pooler
=
self
.
_build_pooler
(
pooler_config
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
...
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
inputs_embeds
=
inputs_embeds
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
weights_list
=
list
(
weights
)
weights_list
=
list
(
weights
)
...
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class
=
BertEmbedding
,
embedding_class
=
BertEmbedding
,
add_pooling_layer
=
True
)
add_pooling_layer
=
True
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
vllm_config
.
model_config
,
pooling
=
self
.
bert
.
pooler
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
...
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
loaded_params
=
loader
.
load_weights
(
weights
)
loaded_params
=
loader
.
load_weights
(
weights
)
return
loaded_params
return
loaded_params
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/model_executor/models/gpt2.py
View file @
90bd2ab6
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
..layers.pooler
import
Pooler
,
PoolingType
from
..layers.pooler
import
Pooler
,
PoolingType
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
...
@@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
prefix
=
maybe_prefix
(
prefix
,
"gpt2"
))
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
)
self
.
score
=
nn
.
Linear
(
config
.
n_embd
,
config
.
num_labels
,
bias
=
False
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
normalize
=
False
,
...
@@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
...
@@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/gritlm.py
View file @
90bd2ab6
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
array
import
array
from
array
import
array
from
typing
import
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
...
@@ -195,6 +194,8 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
- "<|user|>
\n
PROMPT
\n
<|assistant|>
\n
"
- "<|user|>
\n
PROMPT
\n
<|assistant|>
\n
"
"""
"""
is_pooling_model
=
True
def
__init__
(
def
__init__
(
self
,
self
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
...
@@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
...
@@ -214,11 +215,4 @@ class GritLM(LlamaForCausalLM, SupportsV0Only):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
_pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
self
.
pooler
=
GritLMPooler
(
vllm_config
.
model_config
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/interfaces.py
View file @
90bd2ab6
...
@@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
...
@@ -119,13 +119,6 @@ class SupportsMultiModal(Protocol):
...
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_SupportsMultiModalType
(
Protocol
):
supports_multimodal
:
Literal
[
True
]
@
overload
@
overload
def
supports_multimodal
(
def
supports_multimodal
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsMultiModal
]]:
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsMultiModal
]]:
...
@@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
...
@@ -140,10 +133,7 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def
supports_multimodal
(
def
supports_multimodal
(
model
:
Union
[
type
[
object
],
object
],
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsMultiModal
]],
TypeIs
[
SupportsMultiModal
]]:
)
->
Union
[
TypeIs
[
type
[
SupportsMultiModal
]],
TypeIs
[
SupportsMultiModal
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"supports_multimodal"
,
False
)
return
isinstance
(
model
,
_SupportsMultiModalType
)
return
isinstance
(
model
,
SupportsMultiModal
)
@
runtime_checkable
@
runtime_checkable
...
@@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
...
@@ -174,13 +164,6 @@ class SupportsScoreTemplate(Protocol):
...
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@
runtime_checkable
class
_SupportsScoreTemplateType
(
Protocol
):
supports_score_template
:
Literal
[
True
]
@
overload
@
overload
def
supports_score_template
(
def
supports_score_template
(
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsScoreTemplate
]]:
model
:
type
[
object
])
->
TypeIs
[
type
[
SupportsScoreTemplate
]]:
...
@@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
...
@@ -195,11 +178,7 @@ def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
def
supports_score_template
(
def
supports_score_template
(
model
:
Union
[
type
[
object
],
object
],
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsScoreTemplate
]],
TypeIs
[
SupportsScoreTemplate
]]:
)
->
Union
[
TypeIs
[
type
[
SupportsScoreTemplate
]],
TypeIs
[
SupportsScoreTemplate
]]:
return
getattr
(
model
,
"supports_score_template"
,
False
)
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
_SupportsScoreTemplateType
)
return
isinstance
(
model
,
SupportsScoreTemplate
)
@
runtime_checkable
@
runtime_checkable
...
@@ -409,11 +388,6 @@ class HasInnerState(Protocol):
...
@@ -409,11 +388,6 @@ class HasInnerState(Protocol):
"""
"""
@
runtime_checkable
class
_HasInnerStateType
(
Protocol
):
has_inner_state
:
ClassVar
[
Literal
[
True
]]
@
overload
@
overload
def
has_inner_state
(
model
:
object
)
->
TypeIs
[
HasInnerState
]:
def
has_inner_state
(
model
:
object
)
->
TypeIs
[
HasInnerState
]:
...
...
...
@@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
...
@@ -427,10 +401,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
def
has_inner_state
(
def
has_inner_state
(
model
:
Union
[
type
[
object
],
object
]
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
HasInnerState
]],
TypeIs
[
HasInnerState
]]:
)
->
Union
[
TypeIs
[
type
[
HasInnerState
]],
TypeIs
[
HasInnerState
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"has_inner_state"
,
False
)
return
isinstance
(
model
,
_HasInnerStateType
)
return
isinstance
(
model
,
HasInnerState
)
@
runtime_checkable
@
runtime_checkable
...
@@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
...
@@ -446,11 +417,6 @@ class IsAttentionFree(Protocol):
"""
"""
@
runtime_checkable
class
_IsAttentionFreeType
(
Protocol
):
is_attention_free
:
ClassVar
[
Literal
[
True
]]
@
overload
@
overload
def
is_attention_free
(
model
:
object
)
->
TypeIs
[
IsAttentionFree
]:
def
is_attention_free
(
model
:
object
)
->
TypeIs
[
IsAttentionFree
]:
...
...
...
@@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
...
@@ -464,10 +430,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
def
is_attention_free
(
def
is_attention_free
(
model
:
Union
[
type
[
object
],
object
]
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
IsAttentionFree
]],
TypeIs
[
IsAttentionFree
]]:
)
->
Union
[
TypeIs
[
type
[
IsAttentionFree
]],
TypeIs
[
IsAttentionFree
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"is_attention_free"
,
False
)
return
isinstance
(
model
,
_IsAttentionFreeType
)
return
isinstance
(
model
,
IsAttentionFree
)
@
runtime_checkable
@
runtime_checkable
...
@@ -502,11 +465,6 @@ class IsHybrid(Protocol):
...
@@ -502,11 +465,6 @@ class IsHybrid(Protocol):
...
...
@
runtime_checkable
class
_IsHybridType
(
Protocol
):
is_hybrid
:
ClassVar
[
Literal
[
True
]]
@
overload
@
overload
def
is_hybrid
(
model
:
object
)
->
TypeIs
[
IsHybrid
]:
def
is_hybrid
(
model
:
object
)
->
TypeIs
[
IsHybrid
]:
...
...
...
@@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
...
@@ -520,10 +478,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
def
is_hybrid
(
def
is_hybrid
(
model
:
Union
[
type
[
object
],
object
]
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
IsHybrid
]],
TypeIs
[
IsHybrid
]]:
)
->
Union
[
TypeIs
[
type
[
IsHybrid
]],
TypeIs
[
IsHybrid
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"is_hybrid"
,
False
)
return
isinstance
(
model
,
_IsHybridType
)
return
isinstance
(
model
,
IsHybrid
)
@
runtime_checkable
@
runtime_checkable
...
@@ -598,11 +553,6 @@ class HasNoOps(Protocol):
...
@@ -598,11 +553,6 @@ class HasNoOps(Protocol):
has_noops
:
ClassVar
[
Literal
[
True
]]
=
True
has_noops
:
ClassVar
[
Literal
[
True
]]
=
True
@
runtime_checkable
class
_HasNoOpsType
(
Protocol
):
has_noops
:
ClassVar
[
Literal
[
True
]]
@
overload
@
overload
def
has_noops
(
model
:
object
)
->
TypeIs
[
HasNoOps
]:
def
has_noops
(
model
:
object
)
->
TypeIs
[
HasNoOps
]:
...
...
...
@@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
...
@@ -616,10 +566,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
def
has_noops
(
def
has_noops
(
model
:
Union
[
type
[
object
],
object
]
model
:
Union
[
type
[
object
],
object
]
)
->
Union
[
TypeIs
[
type
[
HasNoOps
]],
TypeIs
[
HasNoOps
]]:
)
->
Union
[
TypeIs
[
type
[
HasNoOps
]],
TypeIs
[
HasNoOps
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"has_noops"
,
False
)
return
isinstance
(
model
,
_HasNoOpsType
)
return
isinstance
(
model
,
HasNoOps
)
@
runtime_checkable
@
runtime_checkable
...
@@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
...
@@ -643,11 +590,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def
_supports_cross_encoding
(
def
_supports_cross_encoding
(
model
:
Union
[
type
[
object
],
object
],
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
)
->
Union
[
TypeIs
[
type
[
SupportsCrossEncoding
]],
TypeIs
[
SupportsCrossEncoding
]]:
return
getattr
(
model
,
"supports_cross_encoding"
,
False
)
if
isinstance
(
model
,
type
):
return
isinstance
(
model
,
SupportsCrossEncoding
)
return
isinstance
(
model
,
SupportsCrossEncoding
)
def
supports_cross_encoding
(
def
supports_cross_encoding
(
...
@@ -658,8 +601,9 @@ def supports_cross_encoding(
...
@@ -658,8 +601,9 @@ def supports_cross_encoding(
def
has_step_pooler
(
model
:
Union
[
type
[
object
],
object
])
->
bool
:
def
has_step_pooler
(
model
:
Union
[
type
[
object
],
object
])
->
bool
:
"""Check if the model uses step pooler."""
"""Check if the model uses step pooler."""
return
is_pooling_model
(
model
)
and
any
(
from
vllm.model_executor.layers.pooler
import
StepPooler
type
(
module
).
__name__
==
"StepPooler"
for
module
in
model
.
modules
())
return
is_pooling_model
(
model
)
and
isinstance
(
model
.
pooler
,
StepPooler
)
class
SupportsQuant
:
class
SupportsQuant
:
...
@@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
...
@@ -770,10 +714,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def
supports_transcription
(
def
supports_transcription
(
model
:
Union
[
type
[
object
],
object
],
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsTranscription
]],
TypeIs
[
SupportsTranscription
]]:
)
->
Union
[
TypeIs
[
type
[
SupportsTranscription
]],
TypeIs
[
SupportsTranscription
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"supports_transcription"
,
False
)
return
isinstance
(
model
,
SupportsTranscription
)
return
isinstance
(
model
,
SupportsTranscription
)
@
runtime_checkable
@
runtime_checkable
...
@@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
...
@@ -796,7 +737,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def
supports_v0_only
(
def
supports_v0_only
(
model
:
Union
[
type
[
object
],
object
],
model
:
Union
[
type
[
object
],
object
],
)
->
Union
[
TypeIs
[
type
[
SupportsV0Only
]],
TypeIs
[
SupportsV0Only
]]:
)
->
Union
[
TypeIs
[
type
[
SupportsV0Only
]],
TypeIs
[
SupportsV0Only
]]:
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"supports_v0_only"
,
False
)
return
isinstance
(
model
,
SupportsV0Only
)
return
isinstance
(
model
,
SupportsV0Only
)
vllm/model_executor/models/interfaces_base.py
View file @
90bd2ab6
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
from
typing
import
(
TYPE_CHECKING
,
Optional
,
Protocol
,
Union
,
overload
,
Union
,
overload
,
runtime_checkable
)
runtime_checkable
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
...
@@ -13,8 +12,7 @@ from vllm.utils import supports_kw
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
PoolerOutput
from
vllm.model_executor.layers.pooler
import
Pooler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -130,16 +128,20 @@ def is_text_generation_model(
...
@@ -130,16 +128,20 @@ def is_text_generation_model(
@
runtime_checkable
@
runtime_checkable
class
VllmModelForPooling
(
VllmModel
[
T
],
Protocol
[
T
]):
class
VllmModelForPooling
(
VllmModel
[
T
_co
],
Protocol
[
T
_co
]):
"""The interface required for all pooling models in vLLM."""
"""The interface required for all pooling models in vLLM."""
def
pooler
(
is_pooling_model
:
ClassVar
[
Literal
[
True
]]
=
True
self
,
"""
hidden_states
:
T
,
A flag that indicates this model supports pooling.
pooling_metadata
:
"PoolingMetadata"
,
)
->
"PoolerOutput"
:
Note:
"""Only called on TP rank 0."""
There is no need to redefine this flag if this class is in the
...
MRO of your model class.
"""
pooler
:
"Pooler"
"""The pooler is only called on TP rank 0."""
@
overload
@
overload
...
@@ -158,7 +160,4 @@ def is_pooling_model(
...
@@ -158,7 +160,4 @@ def is_pooling_model(
if
not
is_vllm_model
(
model
):
if
not
is_vllm_model
(
model
):
return
False
return
False
if
isinstance
(
model
,
type
):
return
getattr
(
model
,
"is_pooling_model"
,
False
)
return
isinstance
(
model
,
VllmModelForPooling
)
return
isinstance
(
model
,
VllmModelForPooling
)
vllm/model_executor/models/internlm2.py
View file @
90bd2ab6
...
@@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -28,9 +28,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
is_pp_missing_parameter
,
...
@@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
...
@@ -404,6 +403,8 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class
InternLM2ForRewardModel
(
InternLM2ForCausalLM
):
class
InternLM2ForRewardModel
(
InternLM2ForCausalLM
):
is_pooling_model
=
True
def
__init__
(
def
__init__
(
self
,
self
,
*
,
*
,
...
@@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -428,7 +429,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
)
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
ALL
,
pooling_type
=
PoolingType
.
ALL
,
normalize
=
False
,
normalize
=
False
,
...
@@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -446,10 +447,3 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
inputs_embeds
)
inputs_embeds
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
logits
,
_
=
self
.
v_head
(
hidden_states
)
return
logits
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/jamba.py
View file @
90bd2ab6
...
@@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -27,9 +27,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
MambaCacheParams
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
...
@@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
...
@@ -563,6 +562,8 @@ def _is_moe_layer(name: str):
class
JambaForSequenceClassification
(
JambaForCausalLM
):
class
JambaForSequenceClassification
(
JambaForCausalLM
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
...
@@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
...
@@ -590,16 +591,9 @@ class JambaForSequenceClassification(JambaForCausalLM):
softmax
=
False
,
softmax
=
False
,
)
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
score
,
classifier
=
self
.
score
,
act_fn
=
pooler
.
head
.
activation
,
act_fn
=
pooler
.
head
.
activation
,
)
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/jina_vl.py
View file @
90bd2ab6
...
@@ -13,9 +13,8 @@ from vllm.logger import init_logger
...
@@ -13,9 +13,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsMultiModal
,
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsMultiModal
,
SupportsScoreTemplate
)
SupportsScoreTemplate
)
...
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
...
@@ -72,6 +71,8 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
SupportsCrossEncoding
,
SupportsCrossEncoding
,
SupportsMultiModal
,
SupportsMultiModal
,
SupportsScoreTemplate
):
SupportsScoreTemplate
):
is_pooling_model
=
True
weight_mapper
=
WeightsMapper
(
weight_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
orig_to_new_prefix
=
{
"score.0."
:
"score.dense."
,
"score.0."
:
"score.dense."
,
...
@@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
...
@@ -95,7 +96,7 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
self
.
score
=
JinaVLScorer
(
config
)
self
.
score
=
JinaVLScorer
(
config
)
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
normalize
=
False
,
...
@@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
...
@@ -137,14 +138,6 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration,
logits
=
self
.
score
(
hidden_states
)
-
self
.
LOGIT_BIAS
logits
=
self
.
score
(
hidden_states
)
-
self
.
LOGIT_BIAS
return
logits
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
weight_mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
weight_mapper
)
vllm/model_executor/models/modernbert.py
View file @
90bd2ab6
...
@@ -13,14 +13,16 @@ from vllm.config import VllmConfig
...
@@ -13,14 +13,16 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
BasePooler
,
ClassifierPooler
,
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingMethod
,
PoolingType
)
PoolingMethod
,
PoolingTask
,
PoolingType
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.utils
import
WeightsMapper
,
maybe_prefix
from
.utils
import
WeightsMapper
,
maybe_prefix
...
@@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
...
@@ -253,7 +255,7 @@ class ModernBertModel(nn.Module):
return
norm_outputs
return
norm_outputs
class
ModernBertPooler
(
Base
Pooler
):
class
ModernBertPooler
(
Pooler
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
super
().
__init__
()
...
@@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
...
@@ -268,6 +270,9 @@ class ModernBertPooler(BasePooler):
eps
=
config
.
norm_eps
,
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
bias
=
config
.
norm_bias
)
def
get_pooling_params
(
self
,
task
:
PoolingTask
)
->
Optional
[
PoolingParams
]:
return
self
.
pooling
.
get_pooling_params
(
task
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
...
@@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
...
@@ -281,6 +286,8 @@ class ModernBertPooler(BasePooler):
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
SupportsCrossEncoding
):
SupportsCrossEncoding
):
is_pooling_model
=
True
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
@@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -288,7 +295,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
vllm_config
.
model_config
,
pooling
=
ModernBertPooler
(
config
),
pooling
=
ModernBertPooler
(
config
),
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
...
@@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
...
@@ -321,13 +328,6 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
Optional
[
torch
.
LongTensor
],
input_ids
:
Optional
[
torch
.
LongTensor
],
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
90bd2ab6
...
@@ -24,12 +24,13 @@ import torch.nn as nn
...
@@ -24,12 +24,13 @@ import torch.nn as nn
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
(
AllPool
,
PoolerHead
,
PoolerIdentity
,
SimplePooler
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
SupportsMultiModal
,
SupportsMultiModal
,
SupportsV0Only
)
SupportsV0Only
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
)
MultiModalInputs
,
MultiModalKwargs
)
...
@@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
...
@@ -37,8 +38,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptUpdate
)
BaseProcessingInfo
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
from
vllm.sequence
import
IntermediateTensors
PoolingSequenceGroupOutput
)
class
PrithviGeoSpatialMAEProcessingInfo
(
BaseProcessingInfo
):
class
PrithviGeoSpatialMAEProcessingInfo
(
BaseProcessingInfo
):
...
@@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
...
@@ -116,7 +116,9 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
):
SupportsV0Only
):
""" Prithvi Masked Autoencoder"""
"""Prithvi Masked Autoencoder"""
is_pooling_model
=
True
@
classmethod
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
Optional
[
str
]:
...
@@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -162,6 +164,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"Only SemanticSegmentationTask is supported for now "
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE."
)
"by PrithviGeospatialMAE."
)
self
.
pooler
=
SimplePooler
(
AllPool
(),
PoolerHead
(
PoolerIdentity
()))
def
_parse_and_validate_multimodal_data
(
def
_parse_and_validate_multimodal_data
(
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
self
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
...
@@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -189,7 +193,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
:
object
,
**
kwargs
:
object
,
):
):
pixel_values
,
location_coords
=
(
pixel_values
,
location_coords
=
(
self
.
_parse_and_validate_multimodal_data
(
**
kwargs
))
self
.
_parse_and_validate_multimodal_data
(
**
kwargs
))
model_output
=
self
.
model
(
pixel_values
,
model_output
=
self
.
model
(
pixel_values
,
...
@@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
...
@@ -197,13 +200,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return
model_output
.
output
return
model_output
.
output
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
PoolerOutput
([
PoolingSequenceGroupOutput
(
hidden_states
)])
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
params_list
=
[]
params_list
=
[]
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
90bd2ab6
...
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
...
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
,
SimplePooler
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
,
SimplePooler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.qwen2
import
Qwen2Model
from
.qwen2
import
Qwen2Model
...
@@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
...
@@ -25,6 +24,10 @@ from .utils import AutoWeightsLoader, maybe_prefix
class
Qwen2RewardBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
Qwen2RewardBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
is_pooling_model
=
True
pooler
:
SimplePooler
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -61,7 +64,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
return_bias
=
False
),
return_bias
=
False
),
)
)
self
.
_pooler
:
SimplePooler
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
...
@@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -80,13 +82,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
logits
=
self
.
score
(
hidden_states
)
logits
=
self
.
score
(
hidden_states
)
return
logits
return
logits
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
loader
=
AutoWeightsLoader
(
self
,
...
@@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -96,11 +91,11 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
class
Qwen2ForRewardModel
(
Qwen2RewardBaseModel
):
class
Qwen2ForRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
,
prefix
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
1
vllm_config
.
model_config
.
hf_config
.
num_labels
=
1
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
ALL
,
pooling_type
=
PoolingType
.
ALL
,
normalize
=
False
,
normalize
=
False
,
...
@@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
...
@@ -109,11 +104,11 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
class
Qwen2ForProcessRewardModel
(
Qwen2RewardBaseModel
):
class
Qwen2ForProcessRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
,
prefix
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
2
vllm_config
.
model_config
.
hf_config
.
num_labels
=
2
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_
pooler
=
Pooler
.
from_config_with_defaults
(
self
.
pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooler_config
,
pooling_type
=
PoolingType
.
STEP
,
pooling_type
=
PoolingType
.
STEP
,
normalize
=
False
,
normalize
=
False
,
...
...
vllm/model_executor/models/roberta.py
View file @
90bd2ab6
...
@@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -15,8 +15,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
)
maybe_prefix
)
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
...
@@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -165,6 +164,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
_pooler: An instance of Pooler used for pooling operations.
_pooler: An instance of Pooler used for pooling operations.
"""
"""
is_pooling_model
=
True
jina_to_vllm_mapper
=
WeightsMapper
(
jina_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
orig_to_new_substr
=
{
'emb_ln'
:
"embeddings.LayerNorm"
,
'emb_ln'
:
"embeddings.LayerNorm"
,
...
@@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -188,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer
=
False
)
add_pooling_layer
=
False
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
_
pooler
=
ClassifierPooler
(
self
.
pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
vllm_config
.
model_config
,
pooling
=
CLSPool
(),
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
...
@@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
...
@@ -198,13 +198,6 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
loader
=
AutoWeightsLoader
(
self
)
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
jina_to_vllm_mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
jina_to_vllm_mapper
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
input_ids
:
Optional
[
torch
.
Tensor
],
...
...
vllm/pooling_params.py
View file @
90bd2ab6
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
import
msgspec
import
msgspec
...
@@ -15,24 +15,31 @@ class PoolingParams(
...
@@ -15,24 +15,31 @@ class PoolingParams(
msgspec
.
Struct
,
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""API parameters for pooling models. This
is currently a placeholder.
"""API parameters for pooling models. This
Attributes:
Attributes:
dimensions: Reduce the dimensions of embeddings
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
"""
dimensions
:
Optional
[
int
]
=
None
dimensions
:
Optional
[
int
]
=
None
use_cross_encoder
:
bool
=
False
use_cross_encoder
:
bool
=
False
additional_data
:
Optional
[
Any
]
=
None
"""Internal use only."""
logits_processing_needs_token_ids
:
bool
=
False
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
def
clone
(
self
)
->
"PoolingParams"
:
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
"""Returns a deep copy of the PoolingParams instance."""
return
PoolingParams
(
dimensions
=
self
.
dimensions
,
return
PoolingParams
(
use_cross_encoder
=
self
.
use_cross_encoder
,
dimensions
=
self
.
dimensions
,
additional_data
=
self
.
additional_data
)
use_cross_encoder
=
self
.
use_cross_encoder
,
logits_processing_needs_token_ids
=
self
.
logits_processing_needs_token_ids
,
)
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
def
verify
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
self
.
dimensions
is
not
None
:
if
self
.
dimensions
is
not
None
:
...
@@ -54,10 +61,12 @@ class PoolingParams(
...
@@ -54,10 +61,12 @@ class PoolingParams(
raise
ValueError
(
"Dimensions must be greater than 0"
)
raise
ValueError
(
"Dimensions must be greater than 0"
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"PoolingParams("
return
(
f
"dimensions=
{
self
.
dimensions
}
, "
f
"PoolingParams("
f
"use_cross_encoder=
{
self
.
use_cross_encoder
}
, "
f
"dimensions=
{
self
.
dimensions
}
, "
f
"additional_metadata=
{
self
.
additional_data
}
)"
)
f
"use_cross_encoder=
{
self
.
use_cross_encoder
}
, "
f
"logits_processing_needs_token_ids=
{
self
.
logits_processing_needs_token_ids
}
)"
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
assert
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
,
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment