Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
37970105
Unverified
Commit
37970105
authored
Sep 18, 2025
by
Jee Jee Li
Committed by
GitHub
Sep 18, 2025
Browse files
[Model] Improve Pooling Model (#25149)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
cc935fdd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+6
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-0
No files found.
vllm/model_executor/layers/pooler.py
View file @
37970105
...
@@ -12,8 +12,9 @@ import torch.nn as nn
...
@@ -12,8 +12,9 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.sequence
import
PoolerOutput
,
PoolingSequenceGroupOutput
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
...
@@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
...
@@ -377,7 +378,6 @@ class PoolerClassify(PoolerActivation):
super
().
__init__
()
super
().
__init__
()
if
static_num_labels
:
if
static_num_labels
:
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
num_labels
=
getattr
(
vllm_config
.
model_config
.
hf_config
,
self
.
num_labels
=
getattr
(
vllm_config
.
model_config
.
hf_config
,
"num_labels"
,
0
)
"num_labels"
,
0
)
...
@@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -427,8 +427,6 @@ class EmbeddingPoolerHead(PoolerHead):
super
().
__init__
(
activation
=
PoolerNormalize
())
super
().
__init__
(
activation
=
PoolerNormalize
())
# Load ST projector if available
# Load ST projector if available
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.models.adapters
import
_load_st_projector
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
projector
:
Optional
[
nn
.
Module
]
=
_load_st_projector
(
self
.
projector
:
Optional
[
nn
.
Module
]
=
_load_st_projector
(
...
@@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
...
@@ -489,7 +487,6 @@ class RewardPoolerHead(PoolerHead):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
...
@@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
...
@@ -638,7 +635,6 @@ class ClassifierPooler(Pooler):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
from
vllm.config
import
get_current_vllm_config
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
pooling
=
pooling
self
.
pooling
=
pooling
...
@@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
...
@@ -730,3 +726,7 @@ class DispatchPooler(Pooler):
offset
+=
num_items
offset
+=
num_items
return
PoolerOutput
(
outputs
)
return
PoolerOutput
(
outputs
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"supported_task=
{
self
.
get_supported_tasks
()
}
"
return
s
vllm/v1/worker/gpu_model_runner.py
View file @
37970105
...
@@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -3151,6 +3151,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model
=
cast
(
VllmModelForPooling
,
self
.
get_model
())
model
=
cast
(
VllmModelForPooling
,
self
.
get_model
())
dummy_pooling_params
=
PoolingParams
(
task
=
task
)
dummy_pooling_params
=
PoolingParams
(
task
=
task
)
dummy_pooling_params
.
verify
(
task
=
task
,
model_config
=
self
.
model_config
)
to_update
=
model
.
pooler
.
get_pooling_updates
(
task
)
to_update
=
model
.
pooler
.
get_pooling_updates
(
task
)
to_update
.
apply
(
dummy_pooling_params
)
to_update
.
apply
(
dummy_pooling_params
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment