Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
755ed7b0
Unverified
Commit
755ed7b0
authored
Sep 25, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 25, 2025
Browse files
[Misc] Simplify PoolerOutput and move to `v1/outputs` (#25629)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
a676e668
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
82 deletions
+34
-82
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-2
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+8
-21
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+3
-3
vllm/sequence.py
vllm/sequence.py
+0
-48
vllm/v1/outputs.py
vllm/v1/outputs.py
+6
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+15
-7
No files found.
vllm/executor/executor_base.py
View file @
755ed7b0
...
@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
...
@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
from
vllm.v1.outputs
import
SamplerOutput
from
vllm.v1.outputs
import
PoolerOutput
,
SamplerOutput
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/layers/pooler.py
View file @
755ed7b0
...
@@ -16,9 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
...
@@ -16,9 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.utils
import
current_stream
,
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingCursor
,
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingCursor
,
PoolingMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
...
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return
PoolerClassify
()
return
PoolerClassify
()
def
build_output
(
all_data
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
)
->
PoolerOutput
:
# Pooling models D2H & synchronize occurs here
if
isinstance
(
all_data
,
list
):
all_data
=
[
d
.
to
(
"cpu"
,
non_blocking
=
True
)
for
d
in
all_data
]
else
:
all_data
=
all_data
.
to
(
"cpu"
,
non_blocking
=
True
)
current_stream
().
synchronize
()
all_outputs
=
[
PoolingSequenceGroupOutput
(
data
)
for
data
in
all_data
]
return
PoolerOutput
(
outputs
=
all_outputs
)
class
PoolingMethod
(
nn
.
Module
,
ABC
):
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
@
staticmethod
...
@@ -556,7 +543,7 @@ class SimplePooler(Pooler):
...
@@ -556,7 +543,7 @@ class SimplePooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
return
pooled_data
class
StepPooler
(
Pooler
):
class
StepPooler
(
Pooler
):
...
@@ -607,7 +594,7 @@ class StepPooler(Pooler):
...
@@ -607,7 +594,7 @@ class StepPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
return
pooled_data
class
ClassifierPooler
(
Pooler
):
class
ClassifierPooler
(
Pooler
):
...
@@ -678,7 +665,7 @@ class ClassifierPooler(Pooler):
...
@@ -678,7 +665,7 @@ class ClassifierPooler(Pooler):
]
]
# scores shape: [batchsize, num_labels]
# scores shape: [batchsize, num_labels]
return
build_output
(
scores
)
return
scores
class
DispatchPooler
(
Pooler
):
class
DispatchPooler
(
Pooler
):
...
@@ -708,7 +695,7 @@ class DispatchPooler(Pooler):
...
@@ -708,7 +695,7 @@ class DispatchPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
poolers_by_task
=
self
.
poolers_by_task
outputs
=
list
[
PoolingSequenceGroupOutput
]()
outputs
=
list
[
torch
.
Tensor
]()
offset
=
0
offset
=
0
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
for
task
,
group
in
groupby
(
get_tasks
(
pooling_metadata
)):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
...
@@ -722,10 +709,10 @@ class DispatchPooler(Pooler):
...
@@ -722,10 +709,10 @@ class DispatchPooler(Pooler):
pooling_metadata
[
offset
:
offset
+
num_items
],
pooling_metadata
[
offset
:
offset
+
num_items
],
)
)
outputs
.
extend
(
group_output
.
outputs
)
outputs
.
extend
(
group_output
)
offset
+=
num_items
offset
+=
num_items
return
PoolerOutput
(
outputs
)
return
outputs
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"supported_task=
{
self
.
get_supported_tasks
()
}
"
s
=
f
"supported_task=
{
self
.
get_supported_tasks
()
}
"
...
...
vllm/model_executor/models/gritlm.py
View file @
755ed7b0
...
@@ -12,12 +12,12 @@ from vllm.logger import init_logger
...
@@ -12,12 +12,12 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
PoolerHead
,
PoolerNormalize
,
PoolerHead
,
PoolerNormalize
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
build_output
,
get_prompt_lens
,
get_prompt_lens
,
get_prompt_token_ids
)
get_prompt_token_ids
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.sequence
import
PoolerOutput
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces_base
import
default_pooling_type
from
.interfaces_base
import
default_pooling_type
...
@@ -212,7 +212,7 @@ class GritLMPooler(Pooler):
...
@@ -212,7 +212,7 @@ class GritLMPooler(Pooler):
)
->
PoolerOutput
:
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
build_output
(
pooled_data
)
return
pooled_data
@
default_pooling_type
(
"MEAN"
)
@
default_pooling_type
(
"MEAN"
)
...
...
vllm/sequence.py
View file @
755ed7b0
...
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
...
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
(
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
(
KVConnectorOutput
)
KVConnectorOutput
)
else
:
else
:
LoRARequest
=
Any
KVConnectorOutput
=
Any
KVConnectorOutput
=
Any
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
...
@@ -48,29 +47,6 @@ class RequestMetrics:
...
@@ -48,29 +47,6 @@ class RequestMetrics:
model_execute_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
class
PoolingSequenceGroupOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
):
"""The model output associated with a pooling sequence group."""
# Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data
:
Any
def
get_data_nbytes
(
self
)
->
int
:
data
:
torch
.
Tensor
=
self
.
data
return
data
.
nbytes
def
__repr__
(
self
)
->
str
:
return
f
"PoolingSequenceGroupOutput(data=
{
self
.
data
}
"
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PoolingSequenceGroupOutput
):
raise
NotImplementedError
()
return
self
.
data
==
other
.
data
# cannot use msgspec.Struct here because Dynamo does not support it
# cannot use msgspec.Struct here because Dynamo does not support it
@
dataclass
@
dataclass
class
IntermediateTensors
:
class
IntermediateTensors
:
...
@@ -119,30 +95,6 @@ class IntermediateTensors:
...
@@ -119,30 +95,6 @@ class IntermediateTensors:
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
class
PoolerOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs
:
list
[
PoolingSequenceGroupOutput
]
def
get_data_nbytes
(
self
)
->
int
:
return
sum
(
o
.
get_data_nbytes
()
for
o
in
self
.
outputs
)
def
__getitem__
(
self
,
idx
:
int
)
->
PoolingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
:
PoolingSequenceGroupOutput
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
return
len
(
self
.
outputs
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
class
ExecuteModelRequest
(
class
ExecuteModelRequest
(
msgspec
.
Struct
,
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
...
...
vllm/v1/outputs.py
View file @
755ed7b0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
,
Union
import
torch
import
torch
...
@@ -65,6 +65,11 @@ class LogprobsTensors(NamedTuple):
...
@@ -65,6 +65,11 @@ class LogprobsTensors(NamedTuple):
)
)
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput
=
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
@
dataclass
@
dataclass
class
SamplerOutput
:
class
SamplerOutput
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
755ed7b0
...
@@ -52,13 +52,14 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
...
@@ -52,13 +52,14 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
GenerationTask
,
PoolingTask
,
SupportedTask
from
vllm.tasks
import
GenerationTask
,
PoolingTask
,
SupportedTask
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
GiB_bytes
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
GiB_bytes
,
cdiv
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
is_pin_memory_available
,
length_from_prompt_token_ids_or_embeds
,
round_up
,
length_from_prompt_token_ids_or_embeds
,
round_up
,
supports_dynamo
)
supports_dynamo
)
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.v1.attention.backends.flash_attn
import
AttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadataBuilder
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
...
@@ -79,7 +80,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
...
@@ -79,7 +80,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
# yapf: enable
# yapf: enable
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
AsyncModelRunnerOutput
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
AsyncModelRunnerOutput
,
DraftTokenIds
,
LogprobsLists
,
LogprobsTensors
,
DraftTokenIds
,
LogprobsLists
,
LogprobsTensors
,
ModelRunnerOutput
,
SamplerOutput
)
ModelRunnerOutput
,
PoolerOutput
,
SamplerOutput
)
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.sample.logits_processor
import
LogitsProcessors
,
build_logitsprocs
from
vllm.v1.sample.logits_processor
import
LogitsProcessors
,
build_logitsprocs
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
...
@@ -1823,15 +1824,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1823,15 +1824,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device
=
hidden_states
.
device
)
device
=
hidden_states
.
device
)
seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
self
.
input_batch
.
num_reqs
]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
self
.
input_batch
.
num_reqs
]
# Pooling models D2H & synchronize occurs in pooler.py:build_output
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
raw_pooler_output
=
self
.
model
.
pooler
(
raw_pooler_output
:
PoolerOutput
=
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
)
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
,
)
raw_pooler_output
=
json_map_leaves
(
lambda
x
:
x
.
to
(
"cpu"
,
non_blocking
=
True
),
raw_pooler_output
,
)
self
.
_sync_device
()
pooler_output
:
list
[
Optional
[
torch
.
Tensor
]]
=
[]
pooler_output
:
list
[
Optional
[
torch
.
Tensor
]]
=
[]
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
for
raw_output
,
seq_len
,
prompt_len
in
zip
(
raw_pooler_output
,
seq_lens_cpu
,
pooling_metadata
.
prompt_lens
):
raw_pooler_output
,
seq_lens_cpu
,
pooling_metadata
.
prompt_lens
):
output
=
raw_output
.
data
if
seq_len
==
prompt_len
else
None
output
=
raw_output
if
seq_len
==
prompt_len
else
None
pooler_output
.
append
(
output
)
pooler_output
.
append
(
output
)
return
ModelRunnerOutput
(
return
ModelRunnerOutput
(
...
@@ -3233,7 +3241,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3233,7 +3241,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for
task
in
self
.
get_supported_pooling_tasks
():
for
task
in
self
.
get_supported_pooling_tasks
():
# Run a full batch with each task to ensure none of them OOMs
# Run a full batch with each task to ensure none of them OOMs
output
=
self
.
_dummy_pooler_run_task
(
hidden_states
,
task
)
output
=
self
.
_dummy_pooler_run_task
(
hidden_states
,
task
)
output_size
[
task
]
=
output
.
get_data_nbytes
(
)
output_size
[
task
]
=
sum
(
o
.
nbytes
for
o
in
output
)
del
output
# Allow GC
del
output
# Allow GC
max_task
=
max
(
output_size
.
items
(),
key
=
lambda
x
:
x
[
1
])[
0
]
max_task
=
max
(
output_size
.
items
(),
key
=
lambda
x
:
x
[
1
])[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment