Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ffc8a73
Unverified
Commit
1ffc8a73
authored
Oct 18, 2024
by
Nick Hill
Committed by
GitHub
Oct 18, 2024
Browse files
[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)
parent
944dd8ed
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
14 deletions
+26
-14
vllm/beam_search.py
vllm/beam_search.py
+5
-2
vllm/engine/protocol.py
vllm/engine/protocol.py
+19
-10
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-0
vllm/outputs.py
vllm/outputs.py
+1
-2
No files found.
vllm/beam_search.py
View file @
1ffc8a73
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
vllm.sequence
import
Logprob
@
dataclass
...
...
@@ -11,6 +13,7 @@ class BeamSearchSequence:
"""
# The tokens includes the prompt.
tokens
:
List
[
int
]
logprobs
:
List
[
Dict
[
int
,
Logprob
]]
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
...
...
@@ -28,7 +31,7 @@ class BeamSearchInstance:
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
)
BeamSearchSequence
(
tokens
=
prompt_tokens
,
logprobs
=
[]
)
]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
...
...
vllm/engine/protocol.py
View file @
1ffc8a73
...
...
@@ -59,7 +59,7 @@ class EngineClient(ABC):
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
prompt
:
Union
[
str
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
...
@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
if
isinstance
(
prompt
,
str
):
tokenized_prompt
=
tokenizer
.
encode
(
prompt
)
prompt_text
=
prompt
else
:
tokenized_prompt
=
prompt
prompt_text
=
None
tokenized_length
=
len
(
tokenized_prompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
...
...
@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenized_prompt
,
logprobs
=
[],
cum_logprob
=
0
)
]
completed
=
[]
for
_
in
range
(
max_tokens
):
...
...
@@ -114,6 +122,7 @@ class EngineClient(ABC):
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
...
...
@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
L
ength
:])
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
_l
ength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
_text
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
token_ids
=
beam
.
tokens
[
tokenized_length
:]
,
index
=
i
,
logprobs
=
beam
.
cum_
logprob
,
logprobs
=
beam
.
logprob
s
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenized
P
rompt
,
prompt_token_ids
=
tokenized
_p
rompt
,
prompt_logprobs
=
None
)
yield
beam_search_output
...
...
vllm/entrypoints/llm.py
View file @
1ffc8a73
...
...
@@ -433,6 +433,7 @@ class LLM:
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
...
...
vllm/outputs.py
View file @
1ffc8a73
...
...
@@ -4,7 +4,6 @@ from typing import List, Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.inputs
import
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
...
@@ -93,7 +92,7 @@ class RequestOutput:
def
__init__
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment