Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15bb8330
Unverified
Commit
15bb8330
Authored Nov 14, 2024 by Isotr0py
Committed by GitHub on Nov 14, 2024
Browse files
[Bugfix] Fix tensor parallel for qwen2 classification model (#10297)
Signed-off-by: Isotr0py <2037008807@qq.com>
parent
ac49b59d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 9 additions and 4 deletions
+9
-4
tests/models/embedding/language/test_cls_models.py
tests/models/embedding/language/test_cls_models.py
+3
-3
vllm/model_executor/models/qwen2_cls.py
vllm/model_executor/models/qwen2_cls.py
+6
-1
No files found.
tests/models/embedding/language/test_cls_models.py
View file @
15bb8330
...
...
@@ -21,14 +21,14 @@ def test_classification_models(
model
:
str
,
dtype
:
str
,
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
with
hf_runner
(
model
,
dtype
=
dtype
,
auto_cls
=
AutoModelForSequenceClassification
)
as
hf_model
:
hf_outputs
=
hf_model
.
classify
(
example_prompts
)
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
print
(
hf_outputs
,
vllm_outputs
)
# check logits difference
...
...
vllm/model_executor/models/qwen2_cls.py
View file @
15bb8330
...
...
@@ -69,9 +69,14 @@ class Qwen2ForSequenceClassification(nn.Module):
self
.
model
=
Qwen2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
# hidden_states from Qwen2Model has been reduced,
# the input of score layer is not parallelized.
self
.
score
=
RowParallelLinear
(
config
.
hidden_size
,
config
.
num_labels
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
input_is_parallel
=
False
,
bias
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"score"
))
self
.
_pooler
=
Pooler
.
from_config_with_defaults
(
pooler_config
,
pooling_type
=
PoolingType
.
LAST
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.