Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
168cab6b
Unverified
Commit
168cab6b
authored
Oct 05, 2024
by
Brendan Wong
Committed by
GitHub
Oct 05, 2024
Browse files
[Frontend] API support for beam search (#9087)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
23fea871
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
275 additions
and
68 deletions
+275
-68
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+8
-4
tests/conftest.py
tests/conftest.py
+4
-1
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+24
-19
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+103
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+10
-10
vllm/entrypoints/logger.py
vllm/entrypoints/logger.py
+3
-2
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+34
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+31
-12
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+34
-12
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+3
-2
vllm/sampling_params.py
vllm/sampling_params.py
+12
-0
vllm/utils.py
vllm/utils.py
+9
-0
No files found.
benchmarks/benchmark_throughput.py
View file @
168cab6b
...
...
@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from
vllm.entrypoints.openai.api_server
import
(
build_async_engine_client_from_engine_args
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.utils
import
FlexibleArgumentParser
,
merge_async_iterators
...
...
@@ -145,10 +146,13 @@ def run_vllm(
for
prompt
,
input_len
,
_output_len
in
requests
:
assert
_output_len
==
output_len
start
=
time
.
perf_counter
()
llm
.
beam_search
(
prompts
,
beam_width
=
n
,
max_tokens
=
output_len
,
ignore_eos
=
True
)
llm
.
beam_search
(
prompts
,
BeamSearchParams
(
beam_width
=
n
,
max_tokens
=
output_len
,
ignore_eos
=
True
,
))
end
=
time
.
perf_counter
()
return
end
-
start
...
...
tests/conftest.py
View file @
168cab6b
...
...
@@ -35,6 +35,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
identity
,
is_cpu
)
...
...
@@ -812,7 +813,9 @@ class VllmRunner:
beam_width
:
int
,
max_tokens
:
int
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
outputs
=
self
.
model
.
beam_search
(
prompts
,
beam_width
,
max_tokens
)
outputs
=
self
.
model
.
beam_search
(
prompts
,
BeamSearchParams
(
beam_width
=
beam_width
,
max_tokens
=
max_tokens
))
returned_outputs
=
[]
for
output
in
outputs
:
token_ids
=
[
x
.
tokens
for
x
in
output
.
sequences
]
...
...
tests/entrypoints/openai/test_completion.py
View file @
168cab6b
...
...
@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
assert
len
(
batch
.
choices
)
==
2
assert
batch
.
choices
[
0
].
text
==
batch
.
choices
[
1
].
text
# test n = 2
batch
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompts
,
n
=
2
,
max_tokens
=
5
,
temperature
=
0.0
,
extra_body
=
dict
(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search
=
True
),
)
assert
len
(
batch
.
choices
)
==
4
assert
batch
.
choices
[
0
].
text
!=
batch
.
choices
[
1
].
text
,
"beam search should be different"
assert
batch
.
choices
[
0
].
text
==
batch
.
choices
[
2
].
text
,
"two copies of the same prompt should be the same"
assert
batch
.
choices
[
1
].
text
==
batch
.
choices
[
3
].
text
,
"two copies of the same prompt should be the same"
try
:
# test n = 2
batch
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompts
,
n
=
2
,
max_tokens
=
5
,
temperature
=
0.0
,
extra_body
=
dict
(
# NOTE: this has to be true for n > 1 in vLLM, but
# not necessary for official client.
use_beam_search
=
True
),
)
assert
len
(
batch
.
choices
)
==
4
assert
batch
.
choices
[
0
].
text
!=
batch
.
choices
[
1
].
text
,
"beam search should be different"
assert
batch
.
choices
[
0
].
text
==
batch
.
choices
[
2
].
text
,
"two copies of the same prompt should be the same"
assert
batch
.
choices
[
1
].
text
==
batch
.
choices
[
3
].
text
,
"two copies of the same prompt should be the same"
except
BadRequestError
as
e
:
# the only allowed exception is when beam search is not supported
# in the default mqllmengine
assert
"--disable-frontend-multiprocessing"
in
str
(
e
)
# test streaming
batch
=
await
client
.
completions
.
create
(
...
...
vllm/engine/async_llm_engine.py
View file @
168cab6b
...
...
@@ -14,23 +14,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.entrypoints.llm
import
BeamSearchSequence
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
PromptType
from
vllm.inputs
import
PromptType
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
from
vllm.utils
import
(
collect_from_async_generator
,
deprecate_kwargs
,
random_uuid
,
weak_bind
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -1036,6 +1039,102 @@ class AsyncLLMEngine:
):
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
tokenizer
=
await
self
.
get_tokenizer
()
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
logger
.
info
(
output
)
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
completed
.
append
(
new_beam
)
else
:
new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
new_beams
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
lambda
x
:
x
.
cum_logprob
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenizedLength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
index
=
i
,
logprobs
=
beam
.
cum_logprob
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenizedPrompt
,
prompt_logprobs
=
None
)
yield
LLMEngine
.
validate_output
(
beam_search_output
,
RequestOutput
)
async
def
encode
(
self
,
prompt
:
PromptType
,
...
...
vllm/entrypoints/llm.py
View file @
168cab6b
...
...
@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
(
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
...
@@ -394,10 +394,7 @@ class LLM:
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]]],
beam_width
:
int
,
max_tokens
:
int
,
ignore_eos
:
bool
=
False
,
temperature
:
float
=
0.0
,
params
:
BeamSearchParams
,
)
->
List
[
BeamSearchOutput
]:
"""
Generate sequences using beam search.
...
...
@@ -405,14 +402,17 @@ class LLM:
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
temperature: The temperature to use for generation.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
temperature
=
params
.
temperature
ignore_eos
=
params
.
ignore_eos
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
...
...
vllm/entrypoints/logger.py
View file @
168cab6b
...
...
@@ -4,7 +4,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
logger
=
init_logger
(
__name__
)
...
...
@@ -21,7 +21,8 @@ class RequestLogger:
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
...
...
vllm/entrypoints/openai/protocol.py
View file @
168cab6b
...
...
@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
(
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sequence
import
Logprob
from
vllm.utils
import
random_uuid
...
...
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
return
BeamSearchParams
(
beam_width
=
n
,
max_tokens
=
max_tokens
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
...
...
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
return
BeamSearchParams
(
beam_width
=
n
,
max_tokens
=
max_tokens
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
168cab6b
...
...
@@ -9,6 +9,7 @@ from typing import Union
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_hf_chat_template
,
...
...
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
...
...
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
assert
prompt_inputs
is
not
None
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
])
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
)
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
...
...
@@ -227,15 +235,26 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
not
isinstance
(
self
.
engine_client
,
AsyncLLMEngine
):
raise
ValueError
(
"Beam search in the API server is only supported with"
" AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search."
)
result_generator
=
self
.
engine_client
.
beam_search
(
engine_inputs
[
'prompt_token_ids'
],
request_id
,
sampling_params
)
else
:
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
168cab6b
...
...
@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
...
...
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
...
...
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
])
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
)
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
log_tracing_disabled_warning
()
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
if
not
isinstance
(
self
.
engine_client
,
AsyncLLMEngine
):
raise
ValueError
(
"Beam search in the API server is only supported"
" with AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search."
)
generator
=
self
.
engine_client
.
beam_search
(
prompt_inputs
[
"prompt_token_ids"
],
request_id_item
,
sampling_params
)
else
:
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]
},
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
168cab6b
...
...
@@ -29,7 +29,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
AtomicCounter
...
...
@@ -371,7 +371,8 @@ class OpenAIServing:
self
,
request_id
:
str
,
inputs
:
Union
[
str
,
List
[
int
],
TextTokensPrompt
],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
...
...
vllm/sampling_params.py
View file @
168cab6b
...
...
@@ -530,3 +530,15 @@ class SamplingParams(
f
"
{
self
.
spaces_between_special_tokens
}
, "
f
"truncate_prompt_tokens=
{
self
.
truncate_prompt_tokens
}
), "
f
"guided_decoding=
{
self
.
guided_decoding
}
"
)
class
BeamSearchParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
# required for @cached_property.
dict
=
True
):
# type: ignore[call-arg]
"""Beam search parameters for text generation."""
beam_width
:
int
max_tokens
:
int
ignore_eos
:
bool
=
False
temperature
:
float
=
0.0
vllm/utils.py
View file @
168cab6b
...
...
@@ -504,6 +504,15 @@ async def merge_async_iterators(
await
it
.
aclose
()
async
def
collect_from_async_generator
(
iterator
:
AsyncGenerator
[
T
,
None
])
->
List
[
T
]:
"""Collect all items from an async generator into a list."""
items
=
[]
async
for
item
in
iterator
:
items
.
append
(
item
)
return
items
def
get_ip
()
->
str
:
host_ip
=
envs
.
VLLM_HOST_IP
if
host_ip
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment