Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
68eb5c8d
Unverified
Commit
68eb5c8d
authored
Dec 04, 2025
by
Cyrus Leung
Committed by
GitHub
Dec 04, 2025
Browse files
[Misc] Move functions into `PoolingMetadata` (#30027)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
5430e110
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
47 deletions
+30
-47
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+7
-43
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+2
-4
vllm/v1/pool/metadata.py
vllm/v1/pool/metadata.py
+21
-0
No files found.
vllm/model_executor/layers/pooler.py
View file @
68eb5c8d
...
...
@@ -64,42 +64,6 @@ class PoolingParamsUpdate:
params
.
requires_token_ids
=
self
.
requires_token_ids
def
get_prompt_lens
(
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
return
pooling_metadata
.
prompt_lens
def
get_prompt_token_ids
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
torch
.
Tensor
]:
assert
pooling_metadata
.
prompt_token_ids
is
not
None
,
(
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return
[
pooling_metadata
.
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
pooling_metadata
.
prompt_lens
)
]
def
get_pooling_params
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
PoolingParams
]:
pooling_params
=
pooling_metadata
.
pooling_params
return
pooling_params
def
get_tasks
(
pooling_metadata
:
PoolingMetadata
)
->
list
[
PoolingTask
]:
pooling_params
=
get_pooling_params
(
pooling_metadata
)
tasks
:
list
[
PoolingTask
]
=
[
task
for
pooling_param
in
pooling_params
if
(
task
:
=
pooling_param
.
task
)
is
not
None
]
assert
len
(
pooling_params
)
==
len
(
tasks
)
return
tasks
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
...
...
@@ -466,7 +430,7 @@ class EmbeddingPoolerHead(PoolerHead):
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
get_
pooling_
params
(
pooling_metadata
)
pooling_params
=
pooling_
metadata
.
pooling_params
# for matryoshka representation
dimensions_list
=
[
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
]
...
...
@@ -606,7 +570,7 @@ class ClassifierPooler(Pooler):
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
pooling_params
=
get_
pooling_
params
(
pooling_metadata
)
pooling_params
=
pooling_
metadata
.
pooling_params
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
...
...
@@ -704,7 +668,7 @@ class AllPooler(Pooler):
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
get_
pooling_
params
(
pooling_metadata
)
pooling_params
=
pooling_
metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
...
...
@@ -724,11 +688,11 @@ class StepPooler(Pooler):
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
get_prompt_token_ids
(
pooling_metadata
)
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
pooled_data
=
list
[
torch
.
Tensor
]()
pooling_params
=
get_
pooling_
params
(
pooling_metadata
)
pooling_params
=
pooling_
metadata
.
pooling_params
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
...
...
@@ -757,7 +721,7 @@ class StepPooler(Pooler):
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
get_
pooling_
params
(
pooling_metadata
)
pooling_params
=
pooling_
metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
...
...
@@ -794,7 +758,7 @@ class DispatchPooler(Pooler):
outputs
=
list
[
torch
.
Tensor
]()
offset
=
0
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)
):
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
raise
ValueError
(
f
"Unsupported task:
{
task
}
"
...
...
vllm/model_executor/models/gritlm.py
View file @
68eb5c8d
...
...
@@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import (
PoolerHead
,
PoolerNormalize
,
PoolingParamsUpdate
,
get_prompt_lens
,
get_prompt_token_ids
,
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.tasks
import
PoolingTask
...
...
@@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module):
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
prompt_lens
=
get_prompt_lens
(
hidden_states
,
pooling_metadata
)
prompt_lens
=
pooling_metadata
.
prompt_lens
instr_lens
=
torch
.
tensor
(
[
self
.
_get_instruction_len
(
token_ids
.
cpu
().
numpy
())
for
token_ids
in
get_prompt_token_ids
(
pooling_metadata
)
for
token_ids
in
pooling_metadata
.
get_prompt_token_ids
()
],
device
=
"cpu"
,
)
...
...
vllm/v1/pool/metadata.py
View file @
68eb5c8d
...
...
@@ -5,6 +5,7 @@ from dataclasses import dataclass
import
torch
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.utils.platform_utils
import
is_pin_memory_available
pin_memory
=
is_pin_memory_available
()
...
...
@@ -40,6 +41,18 @@ class PoolingMetadata:
pooling_params
:
list
[
PoolingParams
]
pooling_cursor
:
PoolingCursor
|
None
=
None
def
__post_init__
(
self
)
->
None
:
pooling_params
=
self
.
pooling_params
tasks
:
list
[
PoolingTask
]
=
[
task
for
pooling_param
in
pooling_params
if
(
task
:
=
pooling_param
.
task
)
is
not
None
]
assert
len
(
pooling_params
)
==
len
(
tasks
)
self
.
tasks
=
tasks
def
__getitem__
(
self
,
indices
:
slice
):
return
PoolingMetadata
(
prompt_lens
=
self
.
prompt_lens
[
indices
],
...
...
@@ -52,6 +65,14 @@ class PoolingMetadata:
else
self
.
pooling_cursor
[
indices
],
)
def
get_prompt_token_ids
(
self
)
->
list
[
torch
.
Tensor
]:
prompt_token_ids
=
self
.
prompt_token_ids
assert
prompt_token_ids
is
not
None
,
(
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return
[
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
self
.
prompt_lens
)]
def
build_pooling_cursor
(
self
,
num_scheduled_tokens
:
list
[
int
],
device
:
torch
.
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment