Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3104b2a
"vscode:/vscode.git/clone" did not exist on "e8961e963a76feb3e2c080220e79d2d5a9d272f9"
Unverified
Commit
b3104b2a
authored
Apr 10, 2024
by
胡译文
Committed by
GitHub
Apr 10, 2024
Browse files
[Bugfix] Fix logits processor when prompt_logprobs is not None (#3899)
parent
c2e00af5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
1 deletion
+72
-1
tests/samplers/test_logits_processor.py
tests/samplers/test_logits_processor.py
+62
-0
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+10
-1
No files found.
tests/samplers/test_logits_processor.py
0 → 100644
View file @
b3104b2a
import
pytest
import
torch
from
vllm
import
SamplingParams
# Model identifiers the logits-processor test is parametrized over.
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check a logits processor still steers output with prompt_logprobs set.

    Three requests are queued on the same engine: one with a logits
    processor plus ``prompt_logprobs``, one with only ``prompt_logprobs``,
    and one with neither. Only the first request's text is asserted: it
    must equal the enforced answer repeated ``repeat_times`` times.
    """
    vllm_model = vllm_runner(model, dtype=dtype)
    engine = vllm_model.model

    repeat_times = 2
    enforced_answers = " vLLM"
    answer_ids = engine.get_tokenizer().encode(
        enforced_answers,
        add_special_tokens=False,
    )
    answer_len = len(answer_ids)
    max_tokens = answer_len * repeat_times

    def _force_answer(token_ids, logits):
        # Cycle through the answer's token ids based on how many tokens
        # have been passed in so far, and raise that token's logit to the
        # largest finite value of the logits dtype so it dominates.
        wanted = answer_ids[len(token_ids) % answer_len]
        logits[wanted] = torch.finfo(logits.dtype).max
        return logits

    def _submit(prompt_index, params):
        # Queue one example prompt; token ids are left for the engine
        # to derive from the prompt text.
        engine._add_request(
            prompt=example_prompts[prompt_index],
            sampling_params=params,
            prompt_token_ids=None,
        )

    # test logits_processors when prompt_logprobs is not None
    _submit(
        0,
        SamplingParams(
            logits_processors=[_force_answer],
            prompt_logprobs=3,
            max_tokens=max_tokens,
        ),
    )

    # test prompt_logprobs is not None
    _submit(1, SamplingParams(prompt_logprobs=3, max_tokens=max_tokens))

    # test grouped requests
    _submit(2, SamplingParams(max_tokens=max_tokens))

    results = engine._run_engine(False)

    expected_text = enforced_answers * repeat_times
    assert results[0].outputs[0].text == expected_text
vllm/model_executor/layers/logits_processor.py
View file @
b3104b2a
...
...
@@ -86,8 +86,16 @@ def _apply_logits_processors(
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
logits_processors
=
sampling_params
.
logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
logits_row_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
if
logits_processors
:
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
...
...
@@ -100,5 +108,6 @@ def _apply_logits_processors(
else
:
logits_row_idx
+=
len
(
seq_ids
)
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
assert
logits_row_idx
==
logits
.
shape
[
0
]
return
logits
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment