Unverified Commit b5ca9c35 authored by Ekagra Ranjan's avatar Ekagra Ranjan Committed by GitHub
Browse files

[Models] Cohere ASR (#35809)


Signed-off-by: default avatarEkagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
parent 24575899
...@@ -70,6 +70,29 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData: ...@@ -70,6 +70,29 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
) )
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "CohereASR only support single audio input per prompt"
# TODO (ekagra): add HF ckpt after asr release
model_name = "/host/engines/vllm/audio/2b-release"
prompt = (
"<|startofcontext|><|startoftranscript|>"
"<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>"
"<|notimestamp|><|nodiarize|>"
)
engine_args = EngineArgs(
model=model_name,
limit_mm_per_prompt={"audio": audio_count},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# MusicFlamingo # MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData: def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/music-flamingo-2601-hf" model_name = "nvidia/music-flamingo-2601-hf"
...@@ -508,14 +531,15 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: ...@@ -508,14 +531,15 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = { model_example_map = {
"audioflamingo3": run_audioflamingo3, "audioflamingo3": run_audioflamingo3,
"musicflamingo": run_musicflamingo, "cohere_asr": run_cohere_asr,
"funaudiochat": run_funaudiochat,
"gemma3n": run_gemma3n, "gemma3n": run_gemma3n,
"glmasr": run_glmasr, "glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech, "granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio, "kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm, "midashenglm": run_midashenglm,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
"musicflamingo": run_musicflamingo,
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio, "qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni, "qwen2_5_omni": run_qwen2_5_omni,
......
...@@ -19,8 +19,10 @@ import soundfile ...@@ -19,8 +19,10 @@ import soundfile
import torch import torch
from datasets import load_dataset from datasets import load_dataset
from evaluate import load from evaluate import load
from transformers import AutoTokenizer
from vllm.tokenizers import get_tokenizer
from ....models.registry import HF_EXAMPLE_MODELS
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
...@@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference): ...@@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference):
async def process_dataset(model, client, data, concurrent_request): async def process_dataset(model, client, data, concurrent_request):
sem = asyncio.Semaphore(concurrent_request) sem = asyncio.Semaphore(concurrent_request)
# Load tokenizer once outside the loop model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
tokenizer = AutoTokenizer.from_pretrained(model) tokenizer = get_tokenizer(
model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
)
# Warmup call as the first `librosa.load` server-side is quite slow. # Warmup call as the first `librosa.load` server-side is quite slow.
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
...@@ -144,20 +150,35 @@ def run_evaluation( ...@@ -144,20 +150,35 @@ def run_evaluation(
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. # alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) # NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize(
"model_config",
[
("openai/whisper-large-v3", 12.744980),
# TODO (ekagra): add HF ckpt after asr release
# ("/host/engines/vllm/audio/2b-release", 11.73),
],
)
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice. # Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"] "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
) )
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness( def test_wer_correctness(
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None model_config, dataset_repo, n_examples=-1, max_concurrent_request=None
): ):
model_name, expected_wer = model_config
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_name)
# TODO refactor to use `ASRDataset` # TODO refactor to use `ASRDataset`
server_args = [
"--enforce-eager",
f"--tokenizer_mode={model_info.tokenizer_mode}",
]
if model_info.trust_remote_code:
server_args.append("--trust-remote-code")
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_name, ["--enforce-eager"], max_wait_seconds=480 model_name,
server_args,
) as remote_server: ) as remote_server:
dataset = load_hf_dataset(dataset_repo) dataset = load_hf_dataset(dataset_repo)
...@@ -167,7 +188,14 @@ def test_wer_correctness( ...@@ -167,7 +188,14 @@ def test_wer_correctness(
client = remote_server.get_async_client() client = remote_server.get_async_client()
wer = run_evaluation( wer = run_evaluation(
model_name, client, dataset, max_concurrent_request, n_examples model_name,
client,
dataset,
max_concurrent_request,
n_examples,
) )
print(f"Expected WER: {expected_wer}, Actual WER: {wer}")
if expected_wer: if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
...@@ -1116,6 +1116,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -1116,6 +1116,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer_mode="mistral", tokenizer_mode="mistral",
), ),
# [Encoder-decoder] # [Encoder-decoder]
"CohereASRForConditionalGeneration": _HfExamplesInfo(
"/host/engines/vllm/audio/2b-release",
trust_remote_code=True,
is_available_online=False, # TODO (ekagra): revert after asr release
),
"NemotronParseForConditionalGeneration": _HfExamplesInfo( "NemotronParseForConditionalGeneration": _HfExamplesInfo(
"nvidia/NVIDIA-Nemotron-Parse-v1.1", trust_remote_code=True "nvidia/NVIDIA-Nemotron-Parse-v1.1", trust_remote_code=True
), ),
......
...@@ -3157,7 +3157,7 @@ class ASRDataset(HuggingFaceDataset): ...@@ -3157,7 +3157,7 @@ class ASRDataset(HuggingFaceDataset):
**kwargs, **kwargs,
) -> list: ) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
if "openai" in tokenizer.name_or_path: if "openai" in getattr(tokenizer, "name_or_path", ""):
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
else: else:
prompt = "" prompt = ""
......
...@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_include_usage: bool | None = False stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field( vllm_xargs: dict[str, str | int | float | bool] | None = Field(
default=None, default=None,
description=( description=(
"Additional request parameters with string or " "Additional request parameters with string or "
......
...@@ -365,6 +365,7 @@ def build_enc_dec_inputs( ...@@ -365,6 +365,7 @@ def build_enc_dec_inputs(
encoder_inputs: SingletonInputs, encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None, decoder_inputs: SingletonInputs | None,
decoder_start_token_id: int, decoder_start_token_id: int,
skip_decoder_start_token: bool = False,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
enc_inputs = _validate_enc_inputs(encoder_inputs) enc_inputs = _validate_enc_inputs(encoder_inputs)
...@@ -396,6 +397,7 @@ def build_enc_dec_inputs( ...@@ -396,6 +397,7 @@ def build_enc_dec_inputs(
else: else:
assert_never(enc_inputs) assert_never(enc_inputs)
if not skip_decoder_start_token:
dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation( dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
dec_inputs_new["prompt_token_ids"], dec_inputs_new["prompt_token_ids"],
decoder_start_token_id, decoder_start_token_id,
......
...@@ -261,6 +261,15 @@ class InputPreprocessor: ...@@ -261,6 +261,15 @@ class InputPreprocessor:
encoder_prompt = prompt["encoder_prompt"] encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"] decoder_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.renderer.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.renderer.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = (
self.renderer.mm_processor.skip_decoder_start_token
)
return build_enc_dec_inputs( return build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs( encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt, encoder_prompt,
...@@ -275,6 +284,7 @@ class InputPreprocessor: ...@@ -275,6 +284,7 @@ class InputPreprocessor:
) )
), ),
decoder_start_token_id=self.renderer.get_dec_start_token_id(), decoder_start_token_id=self.renderer.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
) )
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
......
This diff is collapsed.
...@@ -534,6 +534,10 @@ _MULTIMODAL_MODELS = { ...@@ -534,6 +534,10 @@ _MULTIMODAL_MODELS = {
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
"VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"), # noqa: E501 "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"), # noqa: E501
# [Encoder-decoder] # [Encoder-decoder]
"CohereASRForConditionalGeneration": (
"cohere_asr",
"CohereASRForConditionalGeneration",
),
"NemotronParseForConditionalGeneration": ( "NemotronParseForConditionalGeneration": (
"nemotron_parse", "nemotron_parse",
"NemotronParseForConditionalGeneration", "NemotronParseForConditionalGeneration",
......
...@@ -1682,6 +1682,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1682,6 +1682,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
skip_decoder_start_token: bool = False
@abstractmethod @abstractmethod
def create_encoder_prompt( def create_encoder_prompt(
self, self,
......
...@@ -700,12 +700,20 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -700,12 +700,20 @@ class BaseRenderer(ABC, Generic[_T]):
enc_prompt = prompt["encoder_prompt"] enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs( return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt), encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=( decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt) None if dec_prompt is None else self._process_singleton(dec_prompt)
), ),
decoder_start_token_id=self.get_dec_start_token_id(), decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
) )
def process_for_engine( def process_for_engine(
......
...@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase: ...@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return model_arch_config return model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_attention_heads(self) -> int:
return self.hf_text_config.transf_decoder["config_dict"]["num_attention_heads"]
def get_head_size(self) -> int:
hidden_size = self.hf_text_config.transf_decoder["config_dict"]["hidden_size"]
num_attention_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
return hidden_size // num_attention_heads
def get_total_num_kv_heads(self) -> int:
enc_num_kv_heads = self.hf_text_config.encoder["n_heads"]
dec_num_kv_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
assert enc_num_kv_heads == dec_num_kv_heads, (
"Encoder and decoder must have the same number of kv heads"
)
return enc_num_kv_heads
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase): class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int: def get_head_size(self) -> int:
return 0 return 0
...@@ -425,6 +447,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): ...@@ -425,6 +447,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
# hf_config.model_type -> convertor class # hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = { MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor,
"mamba": MambaModelArchConfigConvertor, "mamba": MambaModelArchConfigConvertor,
"falcon_mamba": MambaModelArchConfigConvertor, "falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor, "timm_wrapper": TerratorchModelArchConfigConvertor,
......
...@@ -12,6 +12,7 @@ import importlib ...@@ -12,6 +12,7 @@ import importlib
__all__ = [ __all__ = [
"BagelProcessor", "BagelProcessor",
"CohereASRProcessor",
"DeepseekVLV2Processor", "DeepseekVLV2Processor",
"Eagle2_5_VLProcessor", "Eagle2_5_VLProcessor",
"FireRedASR2Processor", "FireRedASR2Processor",
...@@ -38,6 +39,7 @@ __all__ = [ ...@@ -38,6 +39,7 @@ __all__ = [
_CLASS_TO_MODULE: dict[str, str] = { _CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel", "BagelProcessor": "vllm.transformers_utils.processors.bagel",
"CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"Eagle2_5_VLProcessor": "vllm.transformers_utils.processors.eagle2_5_vl", "Eagle2_5_VLProcessor": "vllm.transformers_utils.processors.eagle2_5_vl",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment