Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ffc8a73
Unverified
Commit
1ffc8a73
authored
Oct 18, 2024
by
Nick Hill
Committed by
GitHub
Oct 18, 2024
Browse files
[BugFix] Typing fixes to RequestOutput.prompt and beam search (#9473)
parent
944dd8ed
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
14 deletions
+26
-14
vllm/beam_search.py
vllm/beam_search.py
+5
-2
vllm/engine/protocol.py
vllm/engine/protocol.py
+19
-10
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-0
vllm/outputs.py
vllm/outputs.py
+1
-2
No files found.
vllm/beam_search.py
View file @
1ffc8a73
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
vllm.sequence
import
Logprob
@
dataclass
@
dataclass
...
@@ -11,6 +13,7 @@ class BeamSearchSequence:
...
@@ -11,6 +13,7 @@ class BeamSearchSequence:
"""
"""
# The tokens includes the prompt.
# The tokens includes the prompt.
tokens
:
List
[
int
]
tokens
:
List
[
int
]
logprobs
:
List
[
Dict
[
int
,
Logprob
]]
cum_logprob
:
float
=
0.0
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
text
:
Optional
[
str
]
=
None
...
@@ -28,7 +31,7 @@ class BeamSearchInstance:
...
@@ -28,7 +31,7 @@ class BeamSearchInstance:
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
)
BeamSearchSequence
(
tokens
=
prompt_tokens
,
logprobs
=
[]
)
]
]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
...
...
vllm/engine/protocol.py
View file @
1ffc8a73
...
@@ -59,7 +59,7 @@ class EngineClient(ABC):
...
@@ -59,7 +59,7 @@ class EngineClient(ABC):
async
def
beam_search
(
async
def
beam_search
(
self
,
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
prompt
:
Union
[
str
,
List
[
int
]],
request_id
:
str
,
request_id
:
str
,
params
:
BeamSearchParams
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@@ -71,9 +71,13 @@ class EngineClient(ABC):
...
@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty
=
params
.
length_penalty
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
if
isinstance
(
prompt
,
str
):
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenized_prompt
=
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
prompt_text
=
prompt
else
:
tokenized_prompt
=
prompt
prompt_text
=
None
tokenized_length
=
len
(
tokenized_prompt
)
sort_beams_key
=
create_sort_beams_key_function
(
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
.
eos_token_id
,
length_penalty
)
...
@@ -81,7 +85,11 @@ class EngineClient(ABC):
...
@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
max_tokens
=
1
,
temperature
=
temperature
)
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenized_prompt
,
logprobs
=
[],
cum_logprob
=
0
)
]
completed
=
[]
completed
=
[]
for
_
in
range
(
max_tokens
):
for
_
in
range
(
max_tokens
):
...
@@ -114,6 +122,7 @@ class EngineClient(ABC):
...
@@ -114,6 +122,7 @@ class EngineClient(ABC):
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
logprob_obj
.
logprob
)
...
@@ -131,22 +140,22 @@ class EngineClient(ABC):
...
@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams
=
sorted_completed
[:
beam_width
]
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
L
ength
:])
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
_l
ength
:])
beam_search_output
=
RequestOutput
(
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
_text
,
outputs
=
[
outputs
=
[
CompletionOutput
(
CompletionOutput
(
text
=
beam
.
text
,
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
token_ids
=
beam
.
tokens
[
tokenized_length
:]
,
index
=
i
,
index
=
i
,
logprobs
=
beam
.
cum_
logprob
,
logprobs
=
beam
.
logprob
s
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
],
finished
=
True
,
finished
=
True
,
prompt_token_ids
=
tokenized
P
rompt
,
prompt_token_ids
=
tokenized
_p
rompt
,
prompt_logprobs
=
None
)
prompt_logprobs
=
None
)
yield
beam_search_output
yield
beam_search_output
...
...
vllm/entrypoints/llm.py
View file @
1ffc8a73
...
@@ -433,6 +433,7 @@ class LLM:
...
@@ -433,6 +433,7 @@ class LLM:
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
logprob_obj
.
logprob
)
...
...
vllm/outputs.py
View file @
1ffc8a73
...
@@ -4,7 +4,6 @@ from typing import List, Optional
...
@@ -4,7 +4,6 @@ from typing import List, Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
Union
from
vllm.inputs
import
PromptType
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
@@ -93,7 +92,7 @@ class RequestOutput:
...
@@ -93,7 +92,7 @@ class RequestOutput:
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_token_ids
:
Optional
[
List
[
int
]],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
outputs
:
List
[
CompletionOutput
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment