Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d70a1662
Unverified
Commit
d70a1662
authored
Aug 21, 2025
by
wang.yuqi
Committed by
GitHub
Aug 21, 2025
Browse files
[Performance] V1 Pooling Models E2E Performance Optimization (#23162)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
5cc54f7c
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
161 additions
and
167 deletions
+161
-167
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+48
-83
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+3
-3
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+10
-55
vllm/model_executor/pooling_metadata.py
vllm/model_executor/pooling_metadata.py
+17
-6
vllm/v1/pool/metadata.py
vllm/v1/pool/metadata.py
+54
-2
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+20
-16
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+8
-1
No files found.
vllm/model_executor/layers/pooler.py
View file @
d70a1662
...
@@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors
...
@@ -19,7 +19,8 @@ from vllm.model_executor.pooling_metadata import PoolingTensors
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
current_stream
,
resolve_obj_by_qualname
from
vllm.v1.pool.metadata
import
PoolingCursor
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
as
V1PoolingMetadata
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
PoolingMetadata
=
Union
[
V0PoolingMetadata
,
V1PoolingMetadata
]
...
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
...
@@ -205,6 +206,13 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
def
build_output
(
def
build_output
(
all_data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
PoolerOutput
:
all_data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
PoolerOutput
:
# Pooling models D2H & synchronize occurs here
if
isinstance
(
all_data
,
list
):
all_data
=
[
d
.
to
(
"cpu"
,
non_blocking
=
True
)
for
d
in
all_data
]
else
:
all_data
=
all_data
.
to
(
"cpu"
,
non_blocking
=
True
)
current_stream
().
synchronize
()
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
return
PoolerOutput
(
outputs
=
all_outputs
)
return
PoolerOutput
(
outputs
=
all_outputs
)
...
@@ -231,40 +239,21 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -231,40 +239,21 @@ class PoolingMethod(nn.Module, ABC):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Note:
`prompt_len=None` means `prompt_len=len(hidden_states)`.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
pooling_cursor
=
pooling_metadata
.
pooling_cursor
return
self
.
forward_all
(
hidden_states
,
pooling_cursor
)
if
isinstance
(
hidden_states
,
list
):
return
[
self
.
forward_one
(
h
,
prompt_len
)
for
h
,
prompt_len
in
zip
(
hidden_states
,
prompt_lens
)
]
return
self
.
forward_all
(
hidden_states
,
prompt_lens
)
class
CLSPool
(
PoolingMethod
):
class
CLSPool
(
PoolingMethod
):
...
@@ -272,24 +261,15 @@ class CLSPool(PoolingMethod):
...
@@ -272,24 +261,15 @@ class CLSPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with CLS pooling"
return
hidden_states
[
0
]
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
first_token_flat_indices
=
torch
.
zeros_like
(
prompt_lens
)
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
first_token_flat_indices
[
1
:]
+=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)[:
-
1
]
"partial prefill not supported with CLS pooling"
return
hidden_states
[
first_token_flat_indices
]
return
hidden_states
[
pooling_cursor
.
first_token_indices_gpu
]
class
LastPool
(
PoolingMethod
):
class
LastPool
(
PoolingMethod
):
...
@@ -297,20 +277,12 @@ class LastPool(PoolingMethod):
...
@@ -297,20 +277,12 @@ class LastPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
hidden_states
[
-
1
]
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
return
hidden_states
[
last_token_flat_indices
]
class
AllPool
(
PoolingMethod
):
class
AllPool
(
PoolingMethod
):
...
@@ -318,22 +290,19 @@ class AllPool(PoolingMethod):
...
@@ -318,22 +290,19 @@ class AllPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
}
return
{
"encode"
}
def
forward_one
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
"partial prefill not supported with ALL pooling"
return
hidden_states
def
forward_all
(
def
forward_all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
p
rompt_lens
:
torch
.
Ten
sor
,
p
ooling_cursor
:
PoolingCur
sor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
return
list
(
hidden_states
.
split_with_sizes
(
prompt_lens
.
tolist
()))
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
"partial prefill not supported with ALL pooling"
hidden_states_lst
=
list
(
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
()))
return
[
hidden_states_lst
[
i
]
for
i
in
pooling_cursor
.
index
]
class
MeanPool
(
PoolingMethod
):
class
MeanPool
(
PoolingMethod
):
...
@@ -341,31 +310,25 @@ class MeanPool(PoolingMethod):
...
@@ -341,31 +310,25 @@ class MeanPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
return
{
"encode"
,
"embed"
,
"classify"
,
"score"
}
def
forward_
one
(
def
forward_
all
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
prompt_len
:
Optional
[
torch
.
Tensor
]
=
None
,
pooling_cursor
:
PoolingCursor
,
)
->
torch
.
Tensor
:
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
assert
prompt_len
is
None
or
prompt_len
==
hidden_states
.
shape
[
0
],
\
assert
not
pooling_cursor
.
is_partial_prefill
(),
\
"partial prefill not supported with MEAN pooling"
"partial prefill not supported with MEAN pooling"
return
hidden_states
.
mean
(
dim
=
0
,
dtype
=
torch
.
float32
)
prompt_lens
=
pooling_cursor
.
prompt_lens_cpu
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
prompt_lens
:
torch
.
Tensor
,
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]:
# Use float32 for torch.cumsum in MeanPool,
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
start_indices
=
torch
.
cat
([
start_indices
=
pooling_cursor
.
first_token_indices_gpu
torch
.
tensor
([
0
],
device
=
hidden_states
.
device
),
end_indices
=
pooling_cursor
.
last_token_indices_gpu
torch
.
cumsum
(
prompt_lens
[:
-
1
],
dim
=
0
)
return
(
cumsum
[
end_indices
]
-
cumsum
[
start_indices
]
+
])
end_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
return
(
cumsum
[
end_indices
-
1
]
-
cumsum
[
start_indices
]
+
hidden_states
[
start_indices
])
/
prompt_lens
.
unsqueeze
(
1
)
hidden_states
[
start_indices
])
/
prompt_lens
.
unsqueeze
(
1
)
...
@@ -477,6 +440,10 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -477,6 +440,10 @@ class EmbeddingPoolerHead(PoolerHead):
pooling_params
=
get_pooling_params
(
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
# for matryoshka representation
# for matryoshka representation
dimensions_list
=
[
dimensions_list
=
[
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
...
@@ -667,6 +634,10 @@ class ClassifierPooler(Pooler):
...
@@ -667,6 +634,10 @@ class ClassifierPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
# apply classifier once on the full batch if possible
# apply classifier once on the full batch if possible
if
isinstance
(
pooled_data
,
torch
.
Tensor
):
if
isinstance
(
pooled_data
,
torch
.
Tensor
):
...
@@ -717,12 +688,6 @@ class DispatchPooler(Pooler):
...
@@ -717,12 +688,6 @@ class DispatchPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
poolers_by_task
=
self
.
poolers_by_task
if
isinstance
(
hidden_states
,
list
):
hidden_states_lst
=
hidden_states
else
:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
hidden_states_lst
=
list
(
hidden_states
.
split
(
prompt_lens
.
tolist
()))
outputs
=
list
[
PoolingSequenceGroupOutput
]()
outputs
=
list
[
PoolingSequenceGroupOutput
]()
offset
=
0
offset
=
0
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
...
@@ -733,7 +698,7 @@ class DispatchPooler(Pooler):
...
@@ -733,7 +698,7 @@ class DispatchPooler(Pooler):
num_items
=
len
(
list
(
group
))
num_items
=
len
(
list
(
group
))
group_output
:
PoolerOutput
=
pooler
(
group_output
:
PoolerOutput
=
pooler
(
hidden_states
_lst
[
offset
:
offset
+
num_items
]
,
hidden_states
,
pooling_metadata
[
offset
:
offset
+
num_items
],
pooling_metadata
[
offset
:
offset
+
num_items
],
)
)
...
...
vllm/model_executor/models/bert.py
View file @
d70a1662
...
@@ -528,7 +528,7 @@ def _encode_token_type_ids(input_ids: torch.Tensor,
...
@@ -528,7 +528,7 @@ def _encode_token_type_ids(input_ids: torch.Tensor,
def
_decode_token_type_ids
(
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_decode_token_type_ids
(
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ids_mask
=
torch
.
ones
(
input_ids
.
shape
,
ids_mask
=
torch
.
ones
_like
(
input_ids
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
input_ids
.
device
)
<<
TOKEN_TYPE_SHIFT
device
=
input_ids
.
device
)
<<
TOKEN_TYPE_SHIFT
tokens_mask
=
ids_mask
.
bitwise_not
()
tokens_mask
=
ids_mask
.
bitwise_not
()
...
...
vllm/model_executor/models/roberta.py
View file @
d70a1662
...
@@ -9,7 +9,6 @@ from torch import nn
...
@@ -9,7 +9,6 @@ from torch import nn
from
transformers
import
RobertaConfig
from
transformers
import
RobertaConfig
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
CLSPool
,
DispatchPooler
,
Pooler
)
DispatchPooler
,
Pooler
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
...
@@ -100,7 +99,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
padding_idx
=
vllm_config
.
model_config
.
hf_config
.
pad_token_id
self
.
padding_idx
:
int
=
vllm_config
.
model_config
.
hf_config
.
pad_token_id
def
forward
(
def
forward
(
self
,
self
,
...
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -178,7 +177,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
self
.
padding_idx
=
vllm_config
.
model_config
.
hf_config
.
pad_token_id
self
.
padding_idx
:
int
=
vllm_config
.
model_config
.
hf_config
.
pad_token_id
self
.
num_labels
=
config
.
num_labels
self
.
num_labels
=
config
.
num_labels
self
.
roberta
=
BertModel
(
vllm_config
=
vllm_config
,
self
.
roberta
=
BertModel
(
vllm_config
=
vllm_config
,
...
@@ -233,58 +232,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -233,58 +232,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
intermediate_tensors
=
intermediate_tensors
)
intermediate_tensors
=
intermediate_tensors
)
# Adapted from transformers
def
create_position_ids_from_input_ids
(
input_ids
,
padding_idx
,
past_key_values_length
=
0
):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
x: torch.Tensor x:
Returns: torch.Tensor
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
mask
=
input_ids
.
ne
(
padding_idx
).
int
()
incremental_indices
=
(
torch
.
cumsum
(
mask
,
dim
=
0
).
type_as
(
mask
)
+
past_key_values_length
)
*
mask
return
incremental_indices
.
long
()
+
padding_idx
def
replace_roberta_positions
(
input_ids
:
torch
.
Tensor
,
def
replace_roberta_positions
(
input_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
padding_idx
:
int
)
->
None
:
padding_idx
:
int
)
->
None
:
seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
is
not
None
:
# can be None during warmup
if
isinstance
(
attn_metadata
,
dict
):
attn_metadata
=
next
(
iter
(
attn_metadata
.
values
()))
# TODO: remove "seq_lens_tensor" after V0 is removed
seq_lens
=
getattr
(
attn_metadata
,
"seq_lens_tensor"
,
getattr
(
attn_metadata
,
"seq_lens"
,
None
))
if
seq_lens
is
not
None
:
assert
isinstance
(
seq_lens
,
torch
.
Tensor
)
# Replace position ids because in RoBERTa models
# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
# existing padding tokens
# References:
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
token_list
=
torch
.
split
(
input_ids
[:
torch
.
sum
(
seq_lens
)],
# vllm does not use padding tokens, let's make things simpler
seq_lens
.
tolist
())
position_ids
+=
padding_idx
+
1
offset
=
0
for
tokens
in
token_list
:
length
=
tokens
.
shape
[
0
]
position_ids
[
offset
:
offset
+
length
]
=
\
create_position_ids_from_input_ids
(
tokens
,
padding_idx
)
offset
=
offset
+
length
vllm/model_executor/pooling_metadata.py
View file @
d70a1662
...
@@ -2,12 +2,13 @@
...
@@ -2,12 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
from
typing
import
Any
,
Optional
import
torch
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
from
vllm.v1.pool.metadata
import
PoolingCursor
,
build_pooling_cursor
class
PoolingMetadata
:
class
PoolingMetadata
:
...
@@ -27,10 +28,11 @@ class PoolingMetadata:
...
@@ -27,10 +28,11 @@ class PoolingMetadata:
seq_groups
:
list
[
tuple
[
list
[
int
],
PoolingParams
]],
seq_groups
:
list
[
tuple
[
list
[
int
],
PoolingParams
]],
seq_data
:
dict
[
int
,
Any
],
# Specific data related to sequences
seq_data
:
dict
[
int
,
Any
],
# Specific data related to sequences
prompt_lens
:
list
[
int
],
prompt_lens
:
list
[
int
],
)
->
None
:
pooling_cursor
:
Optional
[
PoolingCursor
]
=
None
)
->
None
:
self
.
seq_groups
=
seq_groups
self
.
seq_groups
=
seq_groups
self
.
seq_data
=
seq_data
self
.
seq_data
=
seq_data
self
.
prompt_lens
=
prompt_lens
self
.
prompt_lens
=
prompt_lens
self
.
pooling_cursor
:
Optional
[
PoolingCursor
]
=
pooling_cursor
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
"PoolingMetadata("
return
(
"PoolingMetadata("
...
@@ -43,8 +45,17 @@ class PoolingMetadata:
...
@@ -43,8 +45,17 @@ class PoolingMetadata:
seq_groups
=
self
.
seq_groups
[
indices
],
seq_groups
=
self
.
seq_groups
[
indices
],
seq_data
=
dict
(
list
(
self
.
seq_data
.
items
())[
indices
]),
seq_data
=
dict
(
list
(
self
.
seq_data
.
items
())[
indices
]),
prompt_lens
=
self
.
prompt_lens
[
indices
],
prompt_lens
=
self
.
prompt_lens
[
indices
],
pooling_cursor
=
None
if
self
.
pooling_cursor
is
None
else
self
.
pooling_cursor
[
indices
],
)
)
def
build_pooling_cursor
(
self
,
num_scheduled_tokens
:
list
[
int
],
device
:
torch
.
device
):
prompt_lens
=
torch
.
tensor
(
self
.
prompt_lens
,
device
=
"cpu"
)
self
.
pooling_cursor
=
build_pooling_cursor
(
num_scheduled_tokens
,
prompt_lens
,
device
=
device
)
@
dataclass
@
dataclass
class
PoolingTensors
:
class
PoolingTensors
:
...
...
vllm/v1/pool/metadata.py
View file @
d70a1662
...
@@ -6,15 +6,40 @@ from typing import Optional
...
@@ -6,15 +6,40 @@ from typing import Optional
import
torch
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils
import
is_pin_memory_available
pin_memory
=
is_pin_memory_available
()
@
dataclass
class
PoolingCursor
:
index
:
list
[
int
]
first_token_indices_gpu
:
torch
.
Tensor
last_token_indices_gpu
:
torch
.
Tensor
prompt_lens_cpu
:
torch
.
Tensor
num_scheduled_tokens_cpu
:
torch
.
Tensor
def
__getitem__
(
self
,
indices
:
slice
):
return
PoolingCursor
(
index
=
self
.
index
[
indices
],
first_token_indices_gpu
=
self
.
first_token_indices_gpu
[
indices
],
last_token_indices_gpu
=
self
.
last_token_indices_gpu
[
indices
],
prompt_lens_cpu
=
self
.
prompt_lens_cpu
[
indices
],
num_scheduled_tokens_cpu
=
self
.
num_scheduled_tokens_cpu
[
indices
],
)
def
is_partial_prefill
(
self
):
return
not
torch
.
all
(
self
.
prompt_lens_cpu
==
self
.
num_scheduled_tokens_cpu
)
@
dataclass
@
dataclass
class
PoolingMetadata
:
class
PoolingMetadata
:
"""Tensors for pooling."""
"""Tensors for pooling."""
prompt_lens
:
torch
.
Tensor
# CPU Tensor
prompt_lens
:
torch
.
Tensor
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
pooling_params
:
list
[
PoolingParams
]
pooling_params
:
list
[
PoolingParams
]
pooling_cursor
:
Optional
[
PoolingCursor
]
=
None
def
__getitem__
(
self
,
indices
:
slice
):
def
__getitem__
(
self
,
indices
:
slice
):
return
PoolingMetadata
(
return
PoolingMetadata
(
...
@@ -22,4 +47,31 @@ class PoolingMetadata:
...
@@ -22,4 +47,31 @@ class PoolingMetadata:
prompt_token_ids
=
None
if
self
.
prompt_token_ids
is
None
else
prompt_token_ids
=
None
if
self
.
prompt_token_ids
is
None
else
self
.
prompt_token_ids
[
indices
],
self
.
prompt_token_ids
[
indices
],
pooling_params
=
self
.
pooling_params
[
indices
],
pooling_params
=
self
.
pooling_params
[
indices
],
pooling_cursor
=
None
if
self
.
pooling_cursor
is
None
else
self
.
pooling_cursor
[
indices
],
)
)
def
build_pooling_cursor
(
self
,
num_scheduled_tokens
:
list
[
int
],
device
:
torch
.
device
):
self
.
pooling_cursor
=
build_pooling_cursor
(
num_scheduled_tokens
,
self
.
prompt_lens
,
device
)
def
build_pooling_cursor
(
num_scheduled_tokens
:
list
[
int
],
prompt_lens
:
torch
.
Tensor
,
device
:
torch
.
device
):
assert
len
(
prompt_lens
)
==
len
(
num_scheduled_tokens
)
n_seq
=
len
(
num_scheduled_tokens
)
index
=
list
(
range
(
n_seq
))
num_scheduled_tokens
=
torch
.
tensor
(
num_scheduled_tokens
,
device
=
"cpu"
)
cumsum
=
torch
.
zeros
(
n_seq
+
1
,
dtype
=
torch
.
int64
,
pin_memory
=
pin_memory
,
device
=
"cpu"
)
torch
.
cumsum
(
num_scheduled_tokens
,
dim
=
0
,
out
=
cumsum
[
1
:])
cumsum
=
cumsum
.
to
(
device
,
non_blocking
=
True
)
return
PoolingCursor
(
index
=
index
,
first_token_indices_gpu
=
cumsum
[:
n_seq
],
last_token_indices_gpu
=
cumsum
[
1
:]
-
1
,
prompt_lens_cpu
=
prompt_lens
,
num_scheduled_tokens_cpu
=
num_scheduled_tokens
)
vllm/v1/worker/gpu_input_batch.py
View file @
d70a1662
...
@@ -713,7 +713,7 @@ class InputBatch:
...
@@ -713,7 +713,7 @@ class InputBatch:
return
PoolingMetadata
(
return
PoolingMetadata
(
prompt_lens
=
torch
.
from_numpy
(
prompt_lens
=
torch
.
from_numpy
(
self
.
num_prompt_tokens
[:
self
.
num_reqs
])
.
to
(
self
.
device
)
,
self
.
num_prompt_tokens
[:
self
.
num_reqs
]),
prompt_token_ids
=
self
.
sampling_metadata
.
prompt_token_ids
,
prompt_token_ids
=
self
.
sampling_metadata
.
prompt_token_ids
,
pooling_params
=
pooling_params
,
pooling_params
=
pooling_params
,
)
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
d70a1662
...
@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"Either all or none of the requests in"
\
"Either all or none of the requests in"
\
" a batch must be pooling request"
" a batch must be pooling request"
extracted_hidden_states
=
list
(
hidden_states
=
hidden_states
[:
num_scheduled_tokens
]
torch
.
split
(
hidden_states
[:
num_scheduled_tokens
],
num_scheduled_tokens_np
.
tolist
()))
pooling_metadata
=
self
.
input_batch
.
pooling_metadata
pooling_metadata
=
self
.
input_batch
.
pooling_metadata
pooling_metadata
.
build_pooling_cursor
(
num_scheduled_tokens_np
.
tolist
(),
device
=
hidden_states
.
device
)
seq_lens_cpu
=
self
.
seq_lens_cpu
[:
self
.
input_batch
.
num_reqs
]
# Pooling models D2H & synchronize occurs in pooler.py:build_output
raw_pooler_output
=
self
.
model
.
pooler
(
raw_pooler_output
=
self
.
model
.
pooler
(
hidden_states
=
extracted_hidden_states
,
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
)
pooling_metadata
=
pooling_metadata
)
pooler_output
:
list
[
Optional
[
torch
.
Tensor
]]
=
[]
pooler_output
:
list
[
Optional
[
torch
.
Tensor
]]
=
[]
seq_lens
=
self
.
seq_lens
[:
self
.
input_batch
.
num_reqs
]
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
raw_pooler_output
,
seq_lens
,
pooling_metadata
.
prompt_lens
):
raw_pooler_output
,
seq_lens
_cpu
,
pooling_metadata
.
prompt_lens
):
if
seq_len
==
prompt_len
:
if
seq_len
==
prompt_len
:
pooler_output
.
append
(
raw_output
.
data
.
cpu
()
)
pooler_output
.
append
(
raw_output
.
data
)
else
:
else
:
pooler_output
.
append
(
None
)
pooler_output
.
append
(
None
)
...
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
len
(
num_scheduled_tokens_list
)
==
num_reqs
assert
len
(
num_scheduled_tokens_list
)
==
num_reqs
hidden_states_list
=
list
(
torch
.
split
(
hidden_states
,
num_scheduled_tokens_list
))
req_num_tokens
=
num_tokens
//
num_reqs
req_num_tokens
=
num_tokens
//
num_reqs
dummy_prompt_lens
=
torch
.
tensor
(
dummy_prompt_lens
=
torch
.
tensor
(
[
h
.
shape
[
0
]
for
h
in
hidden_state
s_list
]
,
num_scheduled_token
s_list
,
device
=
self
.
device
,
device
=
"cpu"
,
)
)
dummy_token_ids
=
torch
.
zeros
((
num_reqs
,
req_num_tokens
),
dummy_token_ids
=
torch
.
zeros
((
num_reqs
,
req_num_tokens
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooling_params
=
[
dummy_pooling_params
]
*
num_reqs
,
pooling_params
=
[
dummy_pooling_params
]
*
num_reqs
,
)
)
dummy_metadata
.
build_pooling_cursor
(
num_scheduled_tokens_list
,
device
=
hidden_states
.
device
)
try
:
try
:
return
model
.
pooler
(
hidden_states
=
hidden_states
_list
,
return
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
dummy_metadata
)
pooling_metadata
=
dummy_metadata
)
except
RuntimeError
as
e
:
except
RuntimeError
as
e
:
if
'out of memory'
in
str
(
e
):
if
'out of memory'
in
str
(
e
):
...
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_block_table
=
torch
.
zeros
((
num_reqs
,
1
),
dummy_block_table
=
torch
.
zeros
((
num_reqs
,
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
pin_memory
=
self
.
pin_memory
,
device
=
"cpu"
).
to
(
self
.
device
,
non_blocking
=
True
)
dummy_slot_mapping
=
torch
.
zeros
((
total_num_scheduled_tokens
,
),
dummy_slot_mapping
=
torch
.
zeros
((
total_num_scheduled_tokens
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
pin_memory
=
self
.
pin_memory
,
device
=
"cpu"
).
to
(
self
.
device
,
non_blocking
=
True
)
group_metadata
=
dict
[
str
,
tuple
[
CommonAttentionMetadata
,
Any
]]()
group_metadata
=
dict
[
str
,
tuple
[
CommonAttentionMetadata
,
Any
]]()
...
...
vllm/worker/pooling_model_runner.py
View file @
d70a1662
...
@@ -149,9 +149,16 @@ class PoolingModelRunner(
...
@@ -149,9 +149,16 @@ class PoolingModelRunner(
if
not
self
.
is_driver_worker
:
if
not
self
.
is_driver_worker
:
return
[]
return
[]
pooling_metadata
=
model_input
.
pooling_metadata
assert
pooling_metadata
is
not
None
pooling_metadata
.
build_pooling_cursor
(
num_scheduled_tokens
=
pooling_metadata
.
prompt_lens
,
device
=
hidden_or_intermediate_states
.
device
)
return
[
return
[
self
.
model
.
pooler
(
hidden_states
=
hidden_or_intermediate_states
,
self
.
model
.
pooler
(
hidden_states
=
hidden_or_intermediate_states
,
pooling_metadata
=
model_input
.
pooling_metadata
)
pooling_metadata
=
pooling_metadata
)
]
]
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment