"vscode:/vscode.git/clone" did not exist on "24700c346bee5760f015bf41cdc6fd9ffb5d6aaf"
Unverified Commit b5ca9c35 authored by Ekagra Ranjan's avatar Ekagra Ranjan Committed by GitHub
Browse files

[Models] Cohere ASR (#35809)


Signed-off-by: default avatarEkagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
parent 24575899
......@@ -70,6 +70,29 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
)
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "CohereASR only support single audio input per prompt"
# TODO (ekagra): add HF ckpt after asr release
model_name = "/host/engines/vllm/audio/2b-release"
prompt = (
"<|startofcontext|><|startoftranscript|>"
"<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>"
"<|notimestamp|><|nodiarize|>"
)
engine_args = EngineArgs(
model=model_name,
limit_mm_per_prompt={"audio": audio_count},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/music-flamingo-2601-hf"
......@@ -508,14 +531,15 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"audioflamingo3": run_audioflamingo3,
"musicflamingo": run_musicflamingo,
"cohere_asr": run_cohere_asr,
"funaudiochat": run_funaudiochat,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,
"musicflamingo": run_musicflamingo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
......
......@@ -19,8 +19,10 @@ import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer
from vllm.tokenizers import get_tokenizer
from ....models.registry import HF_EXAMPLE_MODELS
from ....utils import RemoteOpenAIServer
......@@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference):
async def process_dataset(model, client, data, concurrent_request):
sem = asyncio.Semaphore(concurrent_request)
# Load tokenizer once outside the loop
tokenizer = AutoTokenizer.from_pretrained(model)
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
tokenizer = get_tokenizer(
model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
)
# Warmup call as the first `librosa.load` server-side is quite slow.
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
......@@ -144,20 +150,35 @@ def run_evaluation(
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize(
"model_config",
[
("openai/whisper-large-v3", 12.744980),
# TODO (ekagra): add HF ckpt after asr release
# ("/host/engines/vllm/audio/2b-release", 11.73),
],
)
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
)
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
model_config, dataset_repo, n_examples=-1, max_concurrent_request=None
):
model_name, expected_wer = model_config
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_name)
# TODO refactor to use `ASRDataset`
server_args = [
"--enforce-eager",
f"--tokenizer_mode={model_info.tokenizer_mode}",
]
if model_info.trust_remote_code:
server_args.append("--trust-remote-code")
with RemoteOpenAIServer(
model_name, ["--enforce-eager"], max_wait_seconds=480
model_name,
server_args,
) as remote_server:
dataset = load_hf_dataset(dataset_repo)
......@@ -167,7 +188,14 @@ def test_wer_correctness(
client = remote_server.get_async_client()
wer = run_evaluation(
model_name, client, dataset, max_concurrent_request, n_examples
model_name,
client,
dataset,
max_concurrent_request,
n_examples,
)
print(f"Expected WER: {expected_wer}, Actual WER: {wer}")
if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
......@@ -1116,6 +1116,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer_mode="mistral",
),
# [Encoder-decoder]
"CohereASRForConditionalGeneration": _HfExamplesInfo(
"/host/engines/vllm/audio/2b-release",
trust_remote_code=True,
is_available_online=False, # TODO (ekagra): revert after asr release
),
"NemotronParseForConditionalGeneration": _HfExamplesInfo(
"nvidia/NVIDIA-Nemotron-Parse-v1.1", trust_remote_code=True
),
......
......@@ -3157,7 +3157,7 @@ class ASRDataset(HuggingFaceDataset):
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
if "openai" in tokenizer.name_or_path:
if "openai" in getattr(tokenizer, "name_or_path", ""):
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
else:
prompt = ""
......
......@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | bool] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
......
......@@ -365,6 +365,7 @@ def build_enc_dec_inputs(
encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None,
decoder_start_token_id: int,
skip_decoder_start_token: bool = False,
) -> EncoderDecoderInputs:
enc_inputs = _validate_enc_inputs(encoder_inputs)
......@@ -396,10 +397,11 @@ def build_enc_dec_inputs(
else:
assert_never(enc_inputs)
dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
dec_inputs_new["prompt_token_ids"],
decoder_start_token_id,
)
if not skip_decoder_start_token:
dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
dec_inputs_new["prompt_token_ids"],
decoder_start_token_id,
)
if cache_salt := enc_inputs.get("cache_salt"):
dec_inputs_new["cache_salt"] = cache_salt
......
......@@ -261,6 +261,15 @@ class InputPreprocessor:
encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.renderer.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.renderer.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = (
self.renderer.mm_processor.skip_decoder_start_token
)
return build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt,
......@@ -275,6 +284,7 @@ class InputPreprocessor:
)
),
decoder_start_token_id=self.renderer.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def _process_decoder_only_prompt(
......
This diff is collapsed.
......@@ -534,6 +534,10 @@ _MULTIMODAL_MODELS = {
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
"VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"), # noqa: E501
# [Encoder-decoder]
"CohereASRForConditionalGeneration": (
"cohere_asr",
"CohereASRForConditionalGeneration",
),
"NemotronParseForConditionalGeneration": (
"nemotron_parse",
"NemotronParseForConditionalGeneration",
......
......@@ -1682,6 +1682,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
skip_decoder_start_token: bool = False
@abstractmethod
def create_encoder_prompt(
self,
......
......@@ -700,12 +700,20 @@ class BaseRenderer(ABC, Generic[_T]):
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def process_for_engine(
......
......@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_attention_heads(self) -> int:
return self.hf_text_config.transf_decoder["config_dict"]["num_attention_heads"]
def get_head_size(self) -> int:
hidden_size = self.hf_text_config.transf_decoder["config_dict"]["hidden_size"]
num_attention_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
return hidden_size // num_attention_heads
def get_total_num_kv_heads(self) -> int:
enc_num_kv_heads = self.hf_text_config.encoder["n_heads"]
dec_num_kv_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
assert enc_num_kv_heads == dec_num_kv_heads, (
"Encoder and decoder must have the same number of kv heads"
)
return enc_num_kv_heads
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
......@@ -425,6 +447,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor,
"mamba": MambaModelArchConfigConvertor,
"falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,
......
......@@ -12,6 +12,7 @@ import importlib
__all__ = [
"BagelProcessor",
"CohereASRProcessor",
"DeepseekVLV2Processor",
"Eagle2_5_VLProcessor",
"FireRedASR2Processor",
......@@ -38,6 +39,7 @@ __all__ = [
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"Eagle2_5_VLProcessor": "vllm.transformers_utils.processors.eagle2_5_vl",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment