Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c3198b6
Unverified
Commit
1c3198b6
authored
Jul 16, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 16, 2025
Browse files
[Model] Consolidate pooler implementations (#20927)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
260127ea
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
553 additions
and
367 deletions
+553
-367
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+434
-247
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+48
-51
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+16
-9
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+2
-2
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+1
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+26
-13
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+18
-15
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+8
-5
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+0
-24
No files found.
vllm/model_executor/layers/pooler.py
View file @
1c3198b6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
typing
import
Optional
,
Union
from
typing
import
Callable
,
Optional
,
TypeVar
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
t
yping_extensions
import
assert_never
from
t
ransformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.model_executor.pooling_metadata
import
(
# noqa: E501
PoolingMetadata
as
V0PoolingMetadata
)
from
vllm.model_executor.pooling_metadata
import
PoolingTensors
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.transformers_utils.config
import
(
get_classification_activation_function
,
get_cross_encoder_activation_function
)
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
...
...
@@ -31,140 +30,202 @@ class PoolingType(IntEnum):
MEAN
=
4
class
SimplePooler
(
nn
.
Module
):
"""A layer that pools specific information from hidden states.
@
dataclass
(
frozen
=
True
)
class
ResolvedPoolingConfig
:
pooling_type
:
PoolingType
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
normalize
:
bool
softmax
:
bool
step_tag_id
:
Optional
[
int
]
returned_token_ids
:
Optional
[
list
[
int
]]
@
staticmethod
def
from_pooling_type
(
@
classmethod
def
from_config_with_defaults
(
cls
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
*
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
"SimplePooler"
:
if
pooling_type
==
PoolingType
.
LAST
:
assert
step_tag_id
is
None
and
returned_token_ids
is
None
return
LastPool
(
normalize
=
normalize
,
softmax
=
softmax
)
if
pooling_type
==
PoolingType
.
ALL
:
assert
step_tag_id
is
None
and
returned_token_ids
is
None
return
AllPool
(
normalize
=
normalize
,
softmax
=
softmax
)
if
pooling_type
==
PoolingType
.
CLS
:
assert
step_tag_id
is
None
and
returned_token_ids
is
None
return
CLSPool
(
normalize
=
normalize
,
softmax
=
softmax
)
if
pooling_type
==
PoolingType
.
MEAN
:
assert
step_tag_id
is
None
and
returned_token_ids
is
None
return
MeanPool
(
normalize
=
normalize
,
softmax
=
softmax
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPool
(
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
)
)
->
"ResolvedPoolingConfig"
:
return
cls
(
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
,
normalize
=
pooler_config
.
normalize
if
pooler_config
.
normalize
is
not
None
else
normalize
,
softmax
=
pooler_config
.
softmax
if
pooler_config
.
softmax
is
not
None
else
softmax
,
step_tag_id
=
pooler_config
.
step_tag_id
if
pooler_config
.
step_tag_id
is
not
None
else
step_tag_id
,
returned_token_ids
=
pooler_config
.
returned_token_ids
if
pooler_config
.
returned_token_ids
is
not
None
else
returned_token_ids
,
)
assert_never
(
pooling_type
)
def
__init__
(
self
,
*
,
normalize
:
bool
,
softmax
:
bool
)
->
None
:
super
().
__init__
()
def
get_prompt_lens
(
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
pooling_metadata
.
prompt_lens
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
return
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
return
PoolerClassify
()
def
get_cross_encoder_activation_function
(
config
:
PretrainedConfig
):
function_name
:
Optional
[
str
]
=
None
if
(
hasattr
(
config
,
"sentence_transformers"
)
and
"activation_fn"
in
config
.
sentence_transformers
):
function_name
=
config
.
sentence_transformers
[
"activation_fn"
]
elif
(
hasattr
(
config
,
"sbert_ce_default_activation_function"
)
and
config
.
sbert_ce_default_activation_function
is
not
None
):
function_name
=
config
.
sbert_ce_default_activation_function
if
function_name
is
not
None
:
assert
function_name
.
startswith
(
"torch.nn.modules."
),
(
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
fn
=
resolve_obj_by_qualname
(
function_name
)()
return
PoolerActivation
.
wraps
(
fn
)
self
.
head
=
PoolerHead
(
normalize
=
normalize
,
softmax
=
softmax
)
return
PoolerScore
(
)
def
get_prompt_lens
(
def
build_output
(
all_data
:
torch
.
Tensor
)
->
PoolerOutput
:
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
return
PoolerOutput
(
outputs
=
all_outputs
)
class
BasePooler
(
nn
.
Module
):
@
abstractmethod
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
def
from_pooling_type
(
pooling_type
:
PoolingType
)
->
"PoolingMethod"
:
if
pooling_type
==
PoolingType
.
LAST
:
return
LastPool
()
if
pooling_type
==
PoolingType
.
ALL
:
return
AllPool
()
if
pooling_type
==
PoolingType
.
CLS
:
return
CLSPool
()
if
pooling_type
==
PoolingType
.
MEAN
:
return
MeanPool
()
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
@
abstractmethod
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
pooling_metadata
.
prompt_lens
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
return
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
"""
Note:
`prompt_len=None` means `prompt_len=len(hidden_states)`.
"""
raise
NotImplementedError
def
extract_states
(
@
abstractmethod
def
forward_all
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
p
ooling_metadata
:
PoolingMetadata
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
raise
NotImplementedError
def
build_output
(
self
,
data
:
torch
.
Tensor
)
->
PoolingSequenceGroupOutput
:
return
PoolingSequenceGroupOutput
(
data
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_outputs
=
[
self
.
build_output
(
data
)
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
hidden_states
,
list
):
return
[
self
.
forward_one
(
h
,
prompt_len
)
for
h
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
)
]
return
self
.
forward_all
(
hidden_states
,
prompt_lens
)
class
CLSPool
(
SimplePooler
):
def
extract_states
(
class
CLSPool
(
PoolingMethod
):
def
forward_one
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with CLS pooling"
if
isinstance
(
hidden_states
,
list
):
result
=
[]
for
req_state
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
):
assert
prompt_len
==
req_state
.
shape
[
0
],
\
"partial prefill not supported with CLS pooling"
result
.
append
(
req_state
[
0
])
return
result
return
hidden_states
[
0
]
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
first_token_flat_indices
=
torch
.
zeros_like
(
prompt_lens
)
first_token_flat_indices
[
1
:]
+=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)[:
-
1
]
return
hidden_states
[
first_token_flat_indices
]
class
LastPool
(
SimplePooler
):
class
LastPool
(
PoolingMethod
):
def
extract_states
(
def
forward_one
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
if
isinstance
(
hidden_states
,
list
):
return
[
h
[
-
1
]
for
h
in
hidden_states
]
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
hidden_states
[
-
1
]
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
return
hidden_states
[
last_token_flat_indices
]
class
AllPool
(
SimplePooler
):
class
AllPool
(
PoolingMethod
):
def
extract_states
(
def
forward_one
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with ALL pooling"
if
isinstance
(
hidden_states
,
list
):
for
req_state
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
):
assert
prompt_len
==
req_state
.
shape
[
0
],
\
"partial prefill not supported with ALL pooling"
return
hidden_states
return
hidden_states
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
offset
=
0
pooled_data
=
list
[
torch
.
Tensor
]()
for
prompt_len
in
prompt_lens
:
pooled_data
.
append
(
hidden_states
[
offset
:
offset
+
prompt_len
])
offset
+=
prompt_len
...
...
@@ -172,24 +233,23 @@ class AllPool(SimplePooler):
return
pooled_data
class
MeanPool
(
SimplePooler
):
class
MeanPool
(
PoolingMethod
):
def
extract_states
(
def
forward_one
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with MEAN pooling"
if
isinstance
(
hidden_states
,
list
):
result
=
[]
for
req_state
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
):
assert
prompt_len
==
req_state
.
shape
[
0
],
\
"partial prefill not supported with mean pooling"
result
.
append
(
torch
.
mean
(
req_state
,
dim
=
0
,
dtype
=
torch
.
float32
))
return
result
return
hidden_states
.
mean
(
dim
=
0
,
dtype
=
torch
.
float32
)
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
...
...
@@ -203,78 +263,127 @@ class MeanPool(SimplePooler):
hidden_states
[
start_indices
])
/
prompt_lens
.
unsqueeze
(
1
)
class
StepPool
(
SimplePooler
):
_T
=
TypeVar
(
"_T"
,
torch
.
Tensor
,
list
[
torch
.
Tensor
])
def
__init__
(
self
,
*
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
):
super
().
__init__
(
normalize
=
normalize
,
softmax
=
softmax
)
self
.
step_tag_id
=
step_tag_id
self
.
returned_token_ids
=
returned_token_ids
class
BasePoolerActivation
(
nn
.
Module
,
ABC
):
def
get_prompt_token_ids
(
self
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
[
pooling_metadata
.
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
return
[
torch
.
tensor
(
seq_data_i
.
prompt_token_ids
)
for
seq_data_i
in
pooling_metadata
.
seq_data
.
values
()
]
@
abstractmethod
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
raise
NotImplementedError
def
extract_states
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
self
.
get_prompt_token_ids
(
pooling_metadata
)
pooled_data_lst
=
list
[
torch
.
Tensor
]()
if
isinstance
(
hidden_states
,
list
):
for
req_state
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
):
assert
prompt_len
==
req_state
.
shape
[
0
],
\
"partial prefill not supported with step pooling"
pooled_data_lst
=
hidden_states
else
:
offset
=
0
for
prompt_len
in
prompt_lens
:
pooled_data_i
=
hidden_states
[
offset
:
offset
+
prompt_len
]
offset
+=
prompt_len
pooled_data_lst
.
append
(
pooled_data_i
)
class
PoolerActivation
(
BasePoolerActivation
):
pooled_data
=
list
[
torch
.
Tensor
]()
returned_token_ids
=
self
.
returned_token_ids
step_tag_id
=
self
.
step_tag_id
@
staticmethod
def
wraps
(
module
:
nn
.
Module
):
if
isinstance
(
module
,
nn
.
Identity
):
return
PoolerIdentity
()
if
isinstance
(
module
,
(
nn
.
Sigmoid
,
nn
.
Softmax
)):
return
PoolerClassify
()
for
data
,
token_id
in
zip
(
pooled_data_lst
,
prompt_token_ids
):
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
return
LambdaPoolerActivation
(
module
)
@
abstractmethod
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
if
isinstance
(
pooled_data
,
list
):
return
[
self
.
forward_chunk
(
data
)
for
data
in
pooled_data
]
return
self
.
forward_chunk
(
pooled_data
)
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
class
PoolerIdentity
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
pooled_data
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
F
.
normalize
(
pooled_data
.
float
(),
p
=
2
,
dim
=-
1
)
return
x
.
to
(
pooled_data
.
dtype
)
class
PoolerClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
pooled_data
.
shape
[
-
1
]
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
F
.
softmax
(
pooled_data
.
float
(),
dim
=-
1
).
to
(
pooled_data
.
dtype
)
class
PoolerScore
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
pooled_data
.
shape
[
-
1
]
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
.
float
()).
to
(
pooled_data
.
dtype
)
return
pooled_data
class
LambdaPoolerActivation
(
PoolerActivation
):
def
__init__
(
self
,
fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]):
super
().
__init__
()
self
.
fn
=
fn
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
fn
(
pooled_data
)
class
PoolerHead
(
nn
.
Module
):
def
__init__
(
self
,
*
,
normalize
:
bool
,
softmax
:
bool
)
->
None
:
@
classmethod
def
from_config_with_defaults
(
cls
,
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
)
->
"PoolerHead"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
None
,
returned_token_ids
=
None
,
)
return
cls
.
from_config
(
resolved_config
)
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"PoolerHead"
:
if
pooler_config
.
normalize
and
pooler_config
.
softmax
:
raise
ValueError
(
"`normalize=True` and `softmax=True` should not "
"be set together"
)
activation
:
PoolerActivation
if
pooler_config
.
normalize
:
activation
=
PoolerNormalize
()
elif
pooler_config
.
softmax
:
activation
=
PoolerClassify
()
else
:
activation
=
PoolerIdentity
()
return
cls
(
activation
)
def
__init__
(
self
,
activation
:
PoolerActivation
)
->
None
:
super
().
__init__
()
self
.
normalize
=
normalize
self
.
softmax
=
softmax
self
.
activation
=
activation
def
forward
(
self
,
pooled_data
:
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
):
...
...
@@ -312,35 +421,21 @@ class PoolerHead(nn.Module):
for
vecs
,
d
in
zip
(
pooled_data
,
dimensions_list
)
]
if
self
.
normalize
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
F
.
normalize
(
data
,
p
=
2
,
dim
=-
1
)
for
data
in
pooled_data
]
else
:
pooled_data
=
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
return
self
.
activation
(
pooled_data
)
if
self
.
softmax
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
F
.
softmax
(
data
,
dim
=-
1
)
if
data
.
shape
[
-
1
]
>=
2
else
F
.
sigmoid
(
data
)
for
data
in
pooled_data
]
else
:
if
pooled_data
.
shape
[
-
1
]
>=
2
:
pooled_data
=
F
.
softmax
(
pooled_data
,
dim
=-
1
)
else
:
pooled_data
=
F
.
sigmoid
(
pooled_data
)
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
return
pooled_data
class
SimplePooler
(
BasePooler
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
class
Pooler
(
nn
.
Module
):
Attributes:
pooling_type: The type of pooling to use.
normalize: Whether to normalize the pooled data.
"""
@
classmethod
def
from_config_with_defaults
(
...
...
@@ -349,23 +444,146 @@ class Pooler(nn.Module):
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
)
->
"SimplePooler"
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
)
assert
resolved_config
.
pooling_type
!=
PoolingType
.
STEP
return
cls
.
from_config
(
resolved_config
)
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
,
)
->
"SimplePooler"
:
pooling
=
PoolingMethod
.
from_pooling_type
(
pooler_config
.
pooling_type
)
head
=
PoolerHead
.
from_config
(
pooler_config
)
return
cls
(
pooling
,
head
)
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
self
.
head
=
head
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
class
StepPooler
(
BasePooler
):
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
)
->
"StepPooler"
:
assert
pooler_config
.
pooling_type
==
PoolingType
.
STEP
return
cls
(
PoolerHead
.
from_config
(
pooler_config
),
step_tag_id
=
pooler_config
.
step_tag_id
,
returned_token_ids
=
pooler_config
.
returned_token_ids
,
)
def
__init__
(
self
,
head
:
PoolerHead
,
*
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
SimplePooler
:
return
SimplePooler
.
from_pooling_type
(
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
if
pooler_config
.
pooling_type
is
not
None
else
pooling_type
,
normalize
=
pooler_config
.
normalize
if
pooler_config
.
normalize
is
not
None
else
normalize
,
softmax
=
pooler_config
.
softmax
if
pooler_config
.
softmax
is
not
None
else
softmax
,
step_tag_id
=
pooler_config
.
step_tag_id
if
pooler_config
.
step_tag_id
is
not
None
else
step_tag_id
,
returned_token_ids
=
pooler_config
.
returned_token_ids
if
pooler_config
.
returned_token_ids
is
not
None
else
returned_token_ids
,
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
head
self
.
step_tag_id
=
step_tag_id
self
.
returned_token_ids
=
returned_token_ids
def
get_prompt_token_ids
(
self
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
[
pooling_metadata
.
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
return
[
torch
.
tensor
(
seq_data_i
.
prompt_token_ids
)
for
seq_data_i
in
pooling_metadata
.
seq_data
.
values
()
]
def
extract_states
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
self
.
get_prompt_token_ids
(
pooling_metadata
)
pooled_data
=
list
[
torch
.
Tensor
]()
returned_token_ids
=
self
.
returned_token_ids
step_tag_id
=
self
.
step_tag_id
for
data
,
token_id
in
zip
(
pooled_data_lst
,
prompt_token_ids
):
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
return
pooled_data
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
class
Pooler
(
nn
.
Module
):
@
staticmethod
def
from_config_with_defaults
(
pooler_config
:
PoolerConfig
,
pooling_type
:
PoolingType
,
normalize
:
bool
,
softmax
:
bool
,
step_tag_id
:
Optional
[
int
]
=
None
,
returned_token_ids
:
Optional
[
list
[
int
]]
=
None
,
)
->
BasePooler
:
resolved_config
=
ResolvedPoolingConfig
.
from_config_with_defaults
(
pooler_config
=
pooler_config
,
pooling_type
=
pooling_type
,
normalize
=
normalize
,
softmax
=
softmax
,
step_tag_id
=
step_tag_id
,
returned_token_ids
=
returned_token_ids
,
)
if
pooling_type
==
PoolingType
.
STEP
:
return
StepPooler
.
from_config
(
resolved_config
)
return
SimplePooler
.
from_config
(
resolved_config
)
PoolingFn
=
Callable
[
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
PoolingMetadata
],
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
class
ClassifierPooler
(
nn
.
Module
):
"""A pooling layer for classification tasks.
...
...
@@ -382,69 +600,39 @@ class ClassifierPooler(nn.Module):
def
__init__
(
self
,
config
:
ModelConfig
,
classifier
:
nn
.
Module
,
pooler
:
Optional
[
nn
.
Module
]
=
None
,
):
pooling
:
PoolingFn
,
classifier
:
ClassifierFn
,
act_fn
:
Optional
[
PoolerActivation
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
self
.
classifier
=
classifier
self
.
pooler
=
pooler
self
.
classification_act_fn
=
get_classification_activation_function
(
config
.
hf_config
)
config
.
hf_config
)
if
act_fn
is
None
else
act_fn
self
.
cross_encoder_act_fn
=
get_cross_encoder_activation_function
(
config
.
hf_config
)
config
.
hf_config
)
if
act_fn
is
None
else
act_fn
def
_get_act_fn
(
self
,
use_cross_encoder
:
bool
):
return
(
self
.
cross_encoder_act_fn
if
use_cross_encoder
else
self
.
classification_act_fn
)
def
get_prompt_lens
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
if
isinstance
(
pooling_metadata
,
V1PoolingMetadata
):
return
pooling_metadata
.
prompt_lens
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
return
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""Pools sentence pair scores from the hidden_states."""
p
rompt_lens
=
self
.
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
p
ooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
list
[
torch
.
Tensor
]()
if
isinstance
(
hidden_states
,
list
):
for
req_state
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
):
assert
prompt_len
==
req_state
.
shape
[
0
],
\
"partial prefill not supported with classifier"
pooled_data
=
hidden_states
# apply classifier once on the full batch if possible
if
isinstance
(
pooled_data
,
torch
.
Tensor
):
pooled_output
=
self
.
classifier
(
pooled_data
)
elif
len
({
data
.
shape
for
data
in
pooled_data
})
<=
1
:
pooled_output
=
self
.
classifier
(
torch
.
stack
(
pooled_data
))
else
:
offset
=
0
for
prompt_len
in
prompt_lens
:
pooled_data_i
=
hidden_states
[
offset
:
offset
+
prompt_len
]
offset
+=
prompt_len
pooled_data
.
append
(
pooled_data_i
)
pooled_data_lst
=
[]
for
pooled_data_i
in
pooled_data
:
if
self
.
pooler
is
not
None
:
final_shape_tensor
=
self
.
pooler
(
pooled_data_i
)
else
:
final_shape_tensor
=
self
.
classifier
(
pooled_data_i
)
pooled_data_lst
.
append
(
final_shape_tensor
)
pooled_output
=
torch
.
stack
(
pooled_data_lst
)
if
self
.
pooler
is
not
None
:
# apply classifier once on the full batch if possible
pooled_output
=
self
.
classifier
(
pooled_output
)
pooled_output
=
[
self
.
classifier
(
data
)
for
data
in
pooled_data
]
if
isinstance
(
pooling_metadata
,
V0PoolingMetadata
):
use_cross_encoder_list
=
[
...
...
@@ -469,5 +657,4 @@ class ClassifierPooler(nn.Module):
pooled_output
)
])
pooled_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
scores
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
return
build_output
(
scores
)
vllm/model_executor/models/adapters.py
View file @
1c3198b6
...
...
@@ -58,22 +58,27 @@ def _create_pooling_model_cls(
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
self
.
vllm_config
=
vllm_config
# These are not used in pooling models
for
attr
in
(
"lm_head"
,
"logits_processor"
):
if
hasattr
(
self
,
attr
):
delattr
(
self
,
attr
)
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"_pooler"
,
None
):
self
.
_init_pooler
(
vllm_config
,
prefix
=
prefix
)
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
# If the model already defines a pooler instance, don't overwrite it
if
not
getattr
(
self
,
"_pooler"
,
None
):
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
default_pooling_type
,
normalize
=
default_normalize
,
softmax
=
default_softmax
,
)
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
default_pooling_type
,
normalize
=
default_normalize
,
softmax
=
default_softmax
,
)
def
pooler
(
self
,
...
...
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from
vllm.model_executor.layers.linear
import
RowParallelLinear
from
vllm.model_executor.layers.pooler
import
PoolerOutput
,
PoolingType
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
PoolerOutput
,
PoolingType
,
SimplePooler
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class
ModelForSequenceClassification
(
ModelForPooling
,
SupportsCrossEncoding
):
def
__init__
(
self
,
*
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
,
**
kwargs
:
Any
,
)
->
None
:
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
vllm_config
=
vllm_config
self
.
task
=
vllm_config
.
model_config
.
task
self
.
pooling_type
=
(
vllm_config
.
model_config
.
pooler_config
.
pooling_type
)
self
.
score
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
quant_config
=
quant_config
,
input_is_parallel
=
False
,
bias
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"score"
))
self
.
score
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
input_is_parallel
=
False
,
bias
=
False
,
params_dtype
=
torch
.
float32
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"score"
),
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
pooler
=
SimplePooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
softmax
=
True
,
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
_classifier
,
act_fn
=
pooler
.
head
.
activation
,
)
def
_classifier
(
self
,
x
:
torch
.
Tensor
):
x
,
_
=
self
.
score
(
x
.
float
())
return
x
def
forward
(
self
,
...
...
@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
def
get_logits
(
hidden_states
):
if
isinstance
(
hidden_states
,
list
):
logits
=
[
self
.
score
(
state
)[
0
]
for
state
in
hidden_states
]
else
:
logits
,
_
=
self
.
score
(
hidden_states
)
return
logits
if
self
.
pooling_type
==
PoolingType
.
ALL
:
logits
=
get_logits
(
hidden_states
)
return
self
.
_pooler
(
logits
,
pooling_metadata
)
else
:
hidden_states
=
self
.
_pooler
.
extract_states
(
hidden_states
,
pooling_metadata
)
logits
=
get_logits
(
hidden_states
)
pooled_data
=
self
.
_pooler
.
head
(
logits
,
pooling_metadata
)
pooled_outputs
=
[
self
.
_pooler
.
build_output
(
data
)
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
tokens
=
getattr
(
self
.
config
,
"classifier_from_token"
,
None
)
...
...
vllm/model_executor/models/bert.py
View file @
1c3198b6
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
torch
from
torch
import
nn
...
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
Pooler
,
PoolingType
)
PoolingMethod
,
PoolingType
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
PoolingType
.
CLS
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor
=
hidden_states
[
0
,
:]
pooled_output
=
self
.
dense
(
first_token_tensor
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_output
=
self
.
dense
(
pooled_output
)
pooled_output
=
self
.
activation
(
pooled_output
)
return
pooled_output
...
...
@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class
=
BertEmbedding
,
add_pooling_layer
=
True
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
self
.
classifier
,
self
.
bert
.
pooler
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/model_executor/models/gritlm.py
View file @
1c3198b6
...
...
@@ -9,7 +9,7 @@ import torch.nn as nn
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
PoolerHead
from
vllm.model_executor.layers.pooler
import
PoolerHead
,
PoolerNormalize
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingTensors
)
...
...
@@ -49,7 +49,7 @@ class GritLMPooler(nn.Module):
self
.
embed_pattern_ids
=
tokens_to_ids
(
[
"▁<"
,
"|"
,
"embed"
,
"|"
,
">"
,
"<0x0A>"
])
self
.
head
=
PoolerHead
(
normalize
=
True
,
softmax
=
False
)
self
.
head
=
PoolerHead
(
PoolerNormalize
()
)
def
_find_array
(
self
,
arr
:
array
,
target
:
array
,
start_idx
:
int
)
->
int
:
"""
...
...
vllm/model_executor/models/interfaces.py
View file @
1c3198b6
...
...
@@ -659,7 +659,7 @@ def supports_cross_encoding(
def
has_step_pooler
(
model
:
Union
[
type
[
object
],
object
])
->
bool
:
"""Check if the model uses step pooler."""
return
is_pooling_model
(
model
)
and
any
(
type
(
module
).
__name__
==
"StepPool"
for
module
in
model
.
modules
())
type
(
module
).
__name__
==
"StepPool
er
"
for
module
in
model
.
modules
())
class
SupportsQuant
:
...
...
vllm/model_executor/models/jamba.py
View file @
1c3198b6
...
...
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
PoolingType
,
SimplePooler
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -564,29 +565,41 @@ class JambaForSequenceClassification(JambaForCausalLM):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
config
=
vllm_config
.
model_config
.
hf_config
num_labels
:
int
=
config
.
num_labels
score_bias
:
bool
=
getattr
(
config
,
'score_bias'
,
False
)
self
.
score
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
,
bias
=
score_bias
)
# TODO: The original reward weights have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
# Currently weight_loader passes the weight which is already in bf16
self
.
score
=
nn
.
Linear
(
config
.
hidden_size
,
num_labels
,
bias
=
score_bias
,
dtype
=
torch
.
float32
,
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
assert
pooler_config
is
not
None
pooler
=
SimplePooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
normalize
=
False
,
softmax
=
False
)
softmax
=
False
,
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
pooler
.
pooling
,
classifier
=
self
.
score
,
act_fn
=
pooler
.
head
.
activation
,
)
def
pooler
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Optional
[
PoolerOutput
]:
hidden_states
=
hidden_states
.
float
()
logits
=
self
.
score
(
hidden_states
)
return
self
.
_pooler
(
logits
,
pooling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super
().
load_weights
(
weights
)
self
.
score
=
self
.
score
.
float
()
return
self
.
_pooler
(
hidden_states
,
pooling_metadata
)
vllm/model_executor/models/modernbert.py
View file @
1c3198b6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
torch
from
torch
import
nn
...
...
@@ -13,7 +13,8 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
ClassifierPooler
from
vllm.model_executor.layers.pooler
import
(
BasePooler
,
ClassifierPooler
,
PoolingMethod
,
PoolingType
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
...
...
@@ -252,10 +253,13 @@ class ModernBertModel(nn.Module):
return
norm_outputs
class
ModernBertPooler
(
nn
.
Modu
le
):
class
ModernBertPooler
(
BasePoo
le
r
):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
pooling_type
=
PoolingType
[
config
.
classifier_pooling
.
upper
()]
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
)
self
.
pooling_type
=
config
.
classifier_pooling
...
...
@@ -264,15 +268,12 @@ class ModernBertPooler(nn.Module):
eps
=
config
.
norm_eps
,
bias
=
config
.
norm_bias
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
pooled_output
=
hidden_states
if
self
.
pooling_type
==
"mean"
:
pooled_output
=
pooled_output
.
mean
(
dim
=
0
,
keepdim
=
False
)
elif
self
.
pooling_type
==
"cls"
:
pooled_output
=
pooled_output
[
0
,
:]
else
:
raise
ValueError
(
"Pooling type should be either `cls` or `mean`, "
f
"but got
{
self
.
pooling_type
}
"
)
def
forward
(
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_output
=
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_output
)))
return
pooled_output
...
...
@@ -287,9 +288,11 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
self
.
model
=
ModernBertModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"modernbert"
))
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
self
.
classifier
,
ModernBertPooler
(
config
))
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
ModernBertPooler
(
config
),
classifier
=
self
.
classifier
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/roberta.py
View file @
1c3198b6
...
...
@@ -9,7 +9,7 @@ from torch import nn
from
transformers
import
RobertaConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
ClassifierPooler
from
vllm.model_executor.layers.pooler
import
ClassifierPooler
,
CLSPool
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
,
BertModel
...
...
@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
out_proj
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
def
forward
(
self
,
features
,
**
kwargs
)
:
x
=
features
[
0
,
:]
# take <s> token (equiv. to [CLS])
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# CLSPool has already been applied in `pooling`
x
=
self
.
dense
(
x
)
x
=
torch
.
tanh
(
x
)
x
=
self
.
out_proj
(
x
)
...
...
@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer
=
False
)
self
.
classifier
=
RobertaClassificationHead
(
config
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
self
.
classifier
)
self
.
_pooler
=
ClassifierPooler
(
vllm_config
.
model_config
,
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
loader
=
AutoWeightsLoader
(
self
)
...
...
vllm/transformers_utils/config.py
View file @
1c3198b6
...
...
@@ -17,7 +17,6 @@ from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError
,
LocalEntryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
)
from
torch
import
nn
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.image_processing_auto
import
(
get_image_processor_config
)
...
...
@@ -44,7 +43,6 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
# yapf: enable
from
vllm.transformers_utils.configs.mistral
import
adapt_config_dict
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
if
envs
.
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -775,28 +773,6 @@ def try_get_generation_config(
return
None
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
return
nn
.
Sigmoid
()
if
config
.
num_labels
==
1
else
nn
.
Softmax
()
def
get_cross_encoder_activation_function
(
config
:
PretrainedConfig
):
function_name
:
Optional
[
str
]
=
None
if
(
hasattr
(
config
,
"sentence_transformers"
)
and
"activation_fn"
in
config
.
sentence_transformers
):
function_name
=
config
.
sentence_transformers
[
"activation_fn"
]
elif
(
hasattr
(
config
,
"sbert_ce_default_activation_function"
)
and
config
.
sbert_ce_default_activation_function
is
not
None
):
function_name
=
config
.
sbert_ce_default_activation_function
if
function_name
is
not
None
:
assert
function_name
.
startswith
(
"torch.nn.modules."
),
(
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
return
resolve_obj_by_qualname
(
function_name
)()
return
nn
.
Sigmoid
()
if
config
.
num_labels
==
1
else
nn
.
Identity
()
def
try_get_safetensors_metadata
(
model
:
str
,
*
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment