Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f856c33c
Unverified
Commit
f856c33c
authored
Aug 19, 2025
by
wang.yuqi
Committed by
GitHub
Aug 19, 2025
Browse files
[Model] Add multi_label_classification support (#23173)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
03752dba
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
1 deletion
+57
-1
tests/conftest.py
tests/conftest.py
+9
-1
tests/models/language/pooling/test_multilabel_classification_support.py
...anguage/pooling/test_multilabel_classification_support.py
+33
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+15
-0
No files found.
tests/conftest.py
View file @
f856c33c
...
...
@@ -456,6 +456,14 @@ class HfRunner:
outputs
=
[]
for
inputs
in
all_inputs
:
output
=
self
.
model
(
**
self
.
wrap_device
(
inputs
))
problem_type
=
getattr
(
self
.
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
logits
=
output
.
logits
[
0
].
tolist
()
elif
problem_type
==
"multi_label_classification"
:
logits
=
output
.
logits
.
sigmoid
()[
0
].
tolist
()
else
:
logits
=
output
.
logits
.
softmax
(
dim
=-
1
)[
0
].
tolist
()
outputs
.
append
(
logits
)
...
...
tests/models/language/pooling/test_multilabel_classification_support.py
0 → 100644
View file @
f856c33c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
transformers
import
AutoModelForSequenceClassification
@pytest.mark.parametrize(
    "model",
    ["Rami/multi-label-class-classification-on-github-issues"],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_classify_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that vLLM's classify output matches the Hugging Face reference.

    Both runners classify the same prompts with the same model and dtype;
    every pair of per-prompt outputs must agree within a relative tolerance
    that depends on ``dtype``.
    """
    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

    with hf_runner(
        model,
        dtype=dtype,
        auto_cls=AutoModelForSequenceClassification,
    ) as hf_model:
        hf_outputs = hf_model.classify(example_prompts)

    # zip() stops at the shorter sequence, so without this check a missing or
    # empty result list would let the test pass without comparing anything.
    assert len(hf_outputs) == len(vllm_outputs)

    # Passed by keyword below: the third positional parameter of
    # torch.allclose is rtol, which is easy to misread as atol.
    rtol = 1e-3 if dtype == "float" else 1e-2

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output)
        vllm_output = torch.tensor(vllm_output)

        assert torch.allclose(hf_output, vllm_output, rtol=rtol)
vllm/model_executor/layers/pooler.py
View file @
f856c33c
...
...
@@ -172,6 +172,15 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]:
def get_classification_activation_function(config: PretrainedConfig):
    """Select the pooler activation for a sequence-classification model.

    The choice follows ``config.problem_type`` to stay aligned with
    transformers' ForSequenceClassificationLoss:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92

    Args:
        config: Model configuration; ``problem_type`` is read from it and
            treated as ``""`` when the attribute is absent.

    Returns:
        ``PoolerIdentity()`` for ``"regression"``,
        ``PoolerMultiLabelClassify()`` for ``"multi_label_classification"``,
        and ``PoolerClassify()`` for ``"single_label_classification"`` or
        any other value.
    """
    kind = getattr(config, "problem_type", "")

    if kind == "regression":
        activation = PoolerIdentity()
    elif kind == "multi_label_classification":
        activation = PoolerMultiLabelClassify()
    else:
        # Covers "single_label_classification" as well as every unset or
        # unrecognized problem type.
        activation = PoolerClassify()

    return activation
...
...
@@ -409,6 +418,12 @@ class PoolerNormalize(PoolerActivation):
return
x
.
to
(
pooled_data
.
dtype
)
class PoolerMultiLabelClassify(PoolerActivation):
    """Pooler activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of ``pooled_data`` in its original dtype.

        The input is cast with ``.float()`` before the sigmoid is applied,
        and the result is cast back to the dtype of ``pooled_data``.
        """
        source_dtype = pooled_data.dtype
        as_float = pooled_data.float()
        activated = F.sigmoid(as_float)
        return activated.to(source_dtype)
class
PoolerClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment