"vscode:/vscode.git/clone" did not exist on "eef921f45e7d3efb2ed2ccab80ee20ee2e4ebe38"
Unverified Commit 65a4da15 authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

[Frontend] Add Support for MM Encoder/Decoder Beam Search (Online Transcriptions) (#36160)


Signed-off-by: default avatarAlex Brooks <albrooks@redhat.com>
parent 217f2759
......@@ -439,6 +439,8 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
Code example: [examples/online_serving/openai_transcription_client.py](../../examples/online_serving/openai_transcription_client.py)
NOTE: Beam search is currently supported in the transcriptions endpoint for encoder-decoder multimodal models (e.g., Whisper), but it is highly inefficient because the work on handling the encoder/decoder cache is still ongoing. This is an active area of optimization and will be handled properly in the very near future.
#### API Enforced Limits
Set the maximum audio file size (in MB) that vLLM will accept, via the
......
......@@ -317,3 +317,72 @@ async def test_language_auto_detect(
assert any(word.lower() in text_lower for word in expected_text), (
f"Expected {expected_lang} text but got: {transcription.text}"
)
@pytest.mark.asyncio
async def test_whisper_beam_search_single_beam(mary_had_lamb, whisper_client):
    """Test beam search with encoder-decoder model (Whisper) on transcriptions with
    one beam aligns with greedy decoding.

    Both requests pin ``language="en"`` so that the only difference between
    them is the decoding strategy; otherwise the greedy request would not
    specify a language while the beam request does, and the comparison
    would not be like-for-like.
    """
    beam_transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
        extra_body=dict(
            use_beam_search=True,
            n=1,
        ),
    )
    greedy_transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        # Keep the language identical to the beam search request above.
        language="en",
        response_format="text",
        temperature=0.0,
    )
    # The responses are parsed as JSON and compared on their "text" field.
    greedy_res = json.loads(greedy_transcription)["text"]
    beam_res = json.loads(beam_transcription)["text"]
    assert greedy_res == beam_res
@pytest.mark.asyncio
async def test_whisper_beam_search_multibeam(mary_had_lamb, whisper_client):
    """Check that beam search with n>1 yields one transcription (best beam)."""
    beam_options = {"use_beam_search": True, "n": 2}
    raw_response = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=mary_had_lamb,
        language="en",
        response_format="text",
        temperature=0.0,
        extra_body=beam_options,
    )
    # The response is parsed as JSON and the "text" field is inspected.
    best_text = json.loads(raw_response)["text"]
    assert best_text is not None
    assert len(best_text) > 0
    assert "mary had a little lamb" in best_text.lower()
@pytest.mark.asyncio
async def test_stream_with_beams_raises(winning_call, whisper_client):
    """Check that combining stream=True with beam search is rejected as a
    bad request for now.
    """
    beam_options = {"use_beam_search": True, "n": 2}
    with pytest.raises(openai.BadRequestError):
        await whisper_client.audio.transcriptions.create(
            model=MODEL_NAME,
            file=winning_call,
            language="en",
            stream=True,
            extra_body=beam_options,
        )
......@@ -129,6 +129,11 @@ class OpenAIServingCompletion(OpenAIServing):
- suffix (the language models we currently support do not support
suffix)
"""
if request.stream and request.use_beam_search:
return self.create_error_response(
"Streaming is not currently supported with beam search"
)
result = await self.render_completion_request(request)
if isinstance(result, ErrorResponse):
return result
......@@ -211,13 +216,10 @@ class OpenAIServingCompletion(OpenAIServing):
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# We do not stream the results when using beam search.
stream = request.stream and not request.use_beam_search
# Streaming response
tokenizer = self.renderer.tokenizer
if stream:
if request.stream:
return self.completion_stream_generator(
request,
engine_prompts,
......
......@@ -237,13 +237,14 @@ class OpenAIServing:
if prompt["type"] == "embeds":
raise NotImplementedError("Embedding prompt not supported for beam search")
if prompt["type"] == "enc_dec":
raise NotImplementedError(
"Encoder-decoder prompt not supported for beam search"
# Extract prompt tokens and text based on model type
decoder_prompt = (
prompt if prompt["type"] != "enc_dec" else prompt["decoder_prompt"]
)
prompt_text = decoder_prompt.get("prompt")
prompt_token_ids = decoder_prompt["prompt_token_ids"]
prompt_text = prompt.get("prompt")
prompt_token_ids = prompt["prompt_token_ids"]
tokenized_length = len(prompt_token_ids)
logprobs_num = 2 * beam_width
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
SamplingParams,
)
......@@ -123,6 +124,18 @@ class TranscriptionRequest(OpenAIBaseModel):
"""
# --8<-- [start:transcription-sampling-params]
use_beam_search: bool = False
"""Whether or not beam search should be used."""
n: int = 1
"""The number of beams to be used in beam search."""
length_penalty: float = 1.0
"""Length penalty to be used for beam search."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
......@@ -170,6 +183,29 @@ class TranscriptionRequest(OpenAIBaseModel):
"min_p": 0.0,
}
def to_beam_search_params(
    self,
    default_max_tokens: int,
    default_sampling_params: dict | None = None,
) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this transcription request.

    Args:
        default_max_tokens: Used as-is for ``max_tokens``.
        default_sampling_params: Optional defaults; only its
            ``"temperature"`` entry is consulted, and only when the
            request's own temperature is ``None``.

    Returns:
        Beam search parameters whose beam width is the request's ``n``
        (1 when ``n`` is ``None``).
    """
    fallbacks = {} if default_sampling_params is None else default_sampling_params

    beam_width = self.n
    if beam_width is None:
        beam_width = 1

    # NOTE: Temp 0 is a different fallback than completions
    temperature = self.temperature
    if temperature is None:
        temperature = fallbacks.get("temperature", 0)

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=default_max_tokens,
        temperature=temperature,
        length_penalty=self.length_penalty,
        include_stop_str_in_output=self.include_stop_str_in_output,
    )
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
......@@ -376,6 +412,18 @@ class TranslationRequest(OpenAIBaseModel):
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
use_beam_search: bool = False
"""Whether or not beam search should be used."""
n: int = 1
"""The number of beams to be used in beam search."""
length_penalty: float = 1.0
"""Length penalty to be used for beam search."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
......@@ -424,6 +472,29 @@ class TranslationRequest(OpenAIBaseModel):
"temperature": 0,
}
def to_beam_search_params(
    self,
    default_max_tokens: int,
    default_sampling_params: dict | None = None,
) -> BeamSearchParams:
    """Build the ``BeamSearchParams`` for this translation request.

    Args:
        default_max_tokens: Used as-is for ``max_tokens``.
        default_sampling_params: Optional defaults; only its
            ``"temperature"`` entry is consulted, and only when the
            request's own temperature is ``None``.

    Returns:
        Beam search parameters whose beam width is the request's ``n``
        (1 when ``n`` is ``None``).
    """
    defaults = default_sampling_params
    if defaults is None:
        defaults = {}

    num_beams = 1 if self.n is None else self.n

    # NOTE: Temp 0 is a different fallback than completions
    request_temperature = self.temperature
    temperature = (
        request_temperature
        if request_temperature is not None
        else defaults.get("temperature", 0)
    )

    return BeamSearchParams(
        beam_width=num_beams,
        max_tokens=default_max_tokens,
        temperature=temperature,
        length_penalty=self.length_penalty,
        include_stop_str_in_output=self.include_stop_str_in_output,
    )
def to_sampling_params(
self, default_max_tokens: int, default_sampling_params: dict | None = None
) -> SamplingParams:
......
......@@ -39,7 +39,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import (
......@@ -50,6 +50,7 @@ from vllm.multimodal.audio import split_audio
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
......@@ -264,8 +265,6 @@ class OpenAISpeechToText(OpenAIServing):
via ``get_language_detection_prompt`` and
``parse_language_detection_output``.
"""
from vllm.sampling_params import SamplingParams
prompt = self.model_cls.get_language_detection_prompt(
audio_chunk,
self.asr_config,
......@@ -403,6 +402,26 @@ class OpenAISpeechToText(OpenAIServing):
return prompt
@staticmethod
def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int:
    """Return the decoder prompt length of the first engine prompt.

    Only the first entry of ``engine_prompts`` is inspected. If its
    ``"type"`` is ``"enc_dec"``, the number of decoder prompt token ids is
    returned; otherwise the result is 0.

    The offset is currently needed when running beam search because the mm
    encoder is not currently cached and runs on decode calls, so the
    redundant encoder calls must not exceed the context.

    FIXME (Alex) - this will be removed in the very near future once the
    encoder/decoder caching is implemented.
    """
    assert len(engine_prompts) > 0
    head_prompt = engine_prompts[0]

    # Anything other than an encoder-decoder prompt contributes no offset.
    if head_prompt.get("type") != "enc_dec":
        return 0

    enc_dec_prompt = cast(EncoderDecoderInputs, head_prompt)
    decoder_token_ids = enc_dec_prompt["decoder_prompt"]["prompt_token_ids"]
    return len(decoder_token_ids)
def _get_verbose_segments(
self,
tokens: tuple,
......@@ -481,6 +500,11 @@ class OpenAISpeechToText(OpenAIServing):
) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
"""Base method for speech-to-text operations like transcription and
translation."""
if request.stream and request.use_beam_search:
return self.create_error_response(
"Streaming is not currently supported with beam search"
)
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
......@@ -526,6 +550,13 @@ class OpenAISpeechToText(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
input_len = (
OpenAISpeechToText._get_decoder_prompt_len(engine_prompts)
if request.use_beam_search
else 0
)
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel spectrogram. Still, allow for fewer tokens to be
......@@ -533,14 +564,20 @@ class OpenAISpeechToText(OpenAIServing):
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens,
0,
input_len,
self.default_sampling_params,
)
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
......@@ -561,6 +598,15 @@ class OpenAISpeechToText(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
params=sampling_params,
request_id=request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment