Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06b2550c
Unverified
Commit
06b2550c
authored
Jun 04, 2024
by
Toshiki Kataoka
Committed by
GitHub
Jun 03, 2024
Browse files
[Bugfix] Support `prompt_logprobs==0` (#5217)
parent
f775a07e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
11 additions
and
7 deletions
+11
-7
tests/entrypoints/test_openai_server.py
tests/entrypoints/test_openai_server.py
+8
-4
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-1
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-1
No files found.
tests/entrypoints/test_openai_server.py
View file @
06b2550c
...
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
...
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert
choice
.
logprobs
is
not
None
assert
choice
.
logprobs
is
not
None
assert
choice
.
logprobs
.
token_logprobs
is
not
None
assert
choice
.
logprobs
.
token_logprobs
is
not
None
assert
choice
.
logprobs
.
top_logprobs
is
not
None
assert
choice
.
logprobs
.
top_logprobs
is
not
None
assert
len
(
choice
.
logprobs
.
top_logprobs
[
0
])
<
=
1
assert
len
(
choice
.
logprobs
.
top_logprobs
[
0
])
=
=
1
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
...
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
assert
choice
.
logprobs
is
not
None
assert
choice
.
logprobs
is
not
None
assert
choice
.
logprobs
.
token_logprobs
is
not
None
assert
choice
.
logprobs
.
token_logprobs
is
not
None
assert
choice
.
logprobs
.
top_logprobs
is
not
None
assert
choice
.
logprobs
.
top_logprobs
is
not
None
assert
len
(
choice
.
logprobs
.
top_logprobs
[
0
])
<=
6
assert
5
<=
len
(
choice
.
logprobs
.
top_logprobs
[
0
])
<=
6
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -1217,8 +1217,9 @@ number: "1" | "2"
...
@@ -1217,8 +1217,9 @@ number: "1" | "2"
"model_name"
,
"model_name"
,
[
MODEL_NAME
,
"zephyr-lora"
,
"zephyr-lora2"
],
[
MODEL_NAME
,
"zephyr-lora"
,
"zephyr-lora2"
],
)
)
@
pytest
.
mark
.
parametrize
(
"logprobs_arg"
,
[
1
,
0
])
async
def
test_echo_logprob_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
async
def
test_echo_logprob_completion
(
server
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
model_name
:
str
,
logprobs_arg
:
int
):
tokenizer
=
get_tokenizer
(
tokenizer_name
=
MODEL_NAME
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
MODEL_NAME
)
# test using text and token IDs
# test using text and token IDs
for
prompt
in
(
"Hello, my name is"
,
[
0
,
0
,
0
,
0
,
0
]):
for
prompt
in
(
"Hello, my name is"
,
[
0
,
0
,
0
,
0
,
0
]):
...
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
...
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
max_tokens
=
5
,
max_tokens
=
5
,
temperature
=
0.0
,
temperature
=
0.0
,
echo
=
True
,
echo
=
True
,
logprobs
=
1
)
logprobs
=
logprobs_arg
)
prompt_text
=
tokenizer
.
decode
(
prompt
)
if
isinstance
(
prompt
,
prompt_text
=
tokenizer
.
decode
(
prompt
)
if
isinstance
(
prompt
,
list
)
else
prompt
list
)
else
prompt
...
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
...
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
and
logprobs
.
token_logprobs
[
0
]
is
None
)
and
logprobs
.
token_logprobs
[
0
]
is
None
)
assert
(
len
(
logprobs
.
top_logprobs
)
>
5
assert
(
len
(
logprobs
.
top_logprobs
)
>
5
and
logprobs
.
top_logprobs
[
0
]
is
None
)
and
logprobs
.
top_logprobs
[
0
]
is
None
)
for
top_logprobs
in
logprobs
.
top_logprobs
[
1
:]:
assert
max
(
logprobs_arg
,
1
)
<=
len
(
top_logprobs
)
<=
logprobs_arg
+
1
assert
len
(
logprobs
.
tokens
)
>
5
assert
len
(
logprobs
.
tokens
)
>
5
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
06b2550c
...
@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
elif
request
.
echo
and
request
.
max_tokens
>
0
:
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
token_ids
=
prompt_token_ids
+
output
.
token_ids
top_logprobs
=
(
prompt_logprobs
+
output
.
logprobs
top_logprobs
=
(
prompt_logprobs
+
output
.
logprobs
if
request
.
logprobs
else
None
)
if
request
.
logprobs
is
not
None
else
None
)
output_text
=
prompt_text
+
output
.
text
output_text
=
prompt_text
+
output
.
text
else
:
else
:
token_ids
=
output
.
token_ids
token_ids
=
output
.
token_ids
...
...
vllm/model_executor/sampling_metadata.py
View file @
06b2550c
...
@@ -233,7 +233,7 @@ def _prepare_seq_groups(
...
@@ -233,7 +233,7 @@ def _prepare_seq_groups(
logits = hidden_states[selected_token_indices]
logits = hidden_states[selected_token_indices]
"""
"""
if
sampling_params
.
prompt_logprobs
:
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
selected_token_indices
.
extend
(
range
(
model_output_idx
,
model_output_idx
+
prompt_logprob_len
))
range
(
model_output_idx
,
model_output_idx
+
prompt_logprob_len
))
model_output_idx
+=
prompt_logprob_len
model_output_idx
+=
prompt_logprob_len
...
...
vllm/worker/model_runner.py
View file @
06b2550c
...
@@ -427,7 +427,7 @@ class ModelRunner:
...
@@ -427,7 +427,7 @@ class ModelRunner:
[
lora_id
]
*
[
lora_id
]
*
(
query_len
if
seq_group_metadata
.
sampling_params
(
query_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
is
not
None
else
1
))
mm_data
=
seq_group_metadata
.
multi_modal_data
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment