Unverified Commit fa93cd9f authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

[Model] Add Granite Speech Support (#16246)


Signed-off-by: default avatarAlex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: default avatarAlex-Brooks <Alex.Brooks@ibm.com>
parent aec9674d
......@@ -895,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `GraniteSpeechForConditionalGeneration`
* Granite Speech
* T + A
* `ibm-granite/granite-speech-3.3-8b`
* ✅︎
* ✅︎
* ✅︎
- * `H2OVLChatModel`
* H2OVL
* T + I<sup>E+</sup>
......
......@@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is
# optimal for granite speech, and it is generally recommended to use beam
# search. Check the model README for suggested settings.
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b
model_name = "ibm-granite/granite-speech-3.3-8b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=2048,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=64,
limit_mm_per_prompt={"audio": audio_count},
)
# The model has an audio-specific lora directly in its model dir;
# it should be enabled whenever you pass audio inputs to the model.
speech_lora_path = model_name
audio_placeholder = "<|audio|>" * audio_count
prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6"
......@@ -209,6 +240,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
......
......@@ -21,6 +21,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import (TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, _get_and_verify_dtype
......@@ -103,10 +104,25 @@ class _VideoAssets(_VideoAssetsBase):
return [prompts["sample_demo_1"]]
class _AudioAssetsBase(UserList[AudioAsset]):
pass
class _AudioAssets(_AudioAssetsBase):
def __init__(self) -> None:
super().__init__([
AudioAsset("mary_had_lamb"),
AudioAsset("winning_call"),
])
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
AUDIO_ASSETS = _AudioAssets()
"""Singleton instance of :class:`_AudioAssets`."""
@pytest.fixture(scope="function", autouse=True)
......@@ -263,6 +279,11 @@ def video_assets() -> _VideoAssets:
return VIDEO_ASSETS
@pytest.fixture(scope="session")
def audio_assets() -> _AudioAssets:
return AUDIO_ASSETS
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
......@@ -390,10 +411,15 @@ class HfRunner:
processor_kwargs["images"] = image
if videos is not None and (video := videos[i]) is not None:
processor_kwargs["videos"] = video
if audios is not None and (audio_tuple := audios[i]) is not None:
audio, sr = audio_tuple
if audios is not None and (audio_inputs := audios[i]) is not None:
# HACK - not all processors take sampling_rate; we should
# clean this up in the future.
if len(audio_inputs) == 2:
audio, sr = audio_inputs
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
else:
processor_kwargs["audio"] = audio_inputs
inputs = self.processor(**processor_kwargs)
if isinstance(inputs, BatchFeature):
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Optional
import pytest
from transformers import AutoModelForSpeechSeq2Seq
from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
def vllm_to_hf_output(
vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
) -> tuple[list[int], str, Optional[SampleLogprobs]]:
"""Sanitize hf output to be comparable with vllm output."""
output_ids, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "<|end_of_text|>"
return output_ids, hf_output_str, out_logprobs
MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
# Audio lora co-exists directly in the model directory, but
# currently still needs to be passed directly to vLLM.
audio_lora_path = MODEL_NAME
models = [MODEL_NAME]
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
inputs: Sequence[tuple[list[str], PromptAudioInput]],
model: str,
*,
max_model_len: int,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the audio fixtures for the test are from AUDIO_ASSETS.
For huggingface runner, we provide the audio as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with vllm_runner(
model,
task="generate",
max_model_len=max_model_len,
max_num_seqs=1,
dtype=dtype,
limit_mm_per_prompt={"audio": 1},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enable_lora=True,
max_lora_rank=64,
enforce_eager=True,
) as vllm_model:
lora_request = LoRARequest("audio", 1, audio_lora_path)
vllm_outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
audios=audios,
lora_request=lora_request)
for prompts, audios in inputs
]
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
hf_processor = hf_model.processor
eos_token_id = hf_processor.tokenizer.eos_token_id
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
audios=[audios],
eos_token_id=eos_token_id)
for prompts, audios in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
vllm_outputs_per_case):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(output) for output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
dtype: str, max_model_len: int, max_tokens: int,
num_logprobs: int) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
audio, sr = audio_assets[0].audio_and_sample_rate
# This model expects 16k sample rate, which our test audio
# already is; if this changes, it may break this test,
# so we check it directly
assert sr == 16000
run_test(
hf_runner,
vllm_runner,
[
([HF_AUDIO_PROMPT], [audio]),
],
model,
dtype=dtype,
max_model_len=max_model_len,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
......@@ -11,7 +11,7 @@ from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
from ....conftest import HfRunner, VllmRunner, _AudioAssets
from ....utils import RemoteOpenAIServer
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
......@@ -31,12 +31,6 @@ CHUNKED_PREFILL_KWARGS = {
}
@pytest.fixture(scope="session")
def audio_assets():
from vllm.assets.audio import AudioAsset
return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
......@@ -59,7 +53,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def server(request, audio_assets):
def server(request, audio_assets: _AudioAssets):
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
......@@ -230,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int, num_logprobs: int,
def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
dtype: str, max_tokens: int,
num_logprobs: int,
vllm_kwargs: dict) -> None:
vllm_prompt = _get_prompt(len(audio_assets),
......@@ -250,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
@pytest.mark.asyncio
async def test_online_serving(client, audio_assets):
async def test_online_serving(client, audio_assets: _AudioAssets):
"""Exercises online serving with/without chunked prefill enabled."""
messages = [{
......
......@@ -254,6 +254,7 @@ def _test_processing_correctness_mistral(
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"THUDM/glm-4v-9b",
"ibm-granite/granite-speech-3.3-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"HuggingFaceM4/Idefics3-8B-Llama3",
......
......@@ -301,6 +301,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b", # noqa: E501
min_transformers_version="4.52.0"), # noqa: E501
"GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
......
......@@ -517,7 +517,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
if model_type in ("ultravox", "granite_speech"):
return "<|audio|>"
if model_type == "phi4mm":
return f"<|audio_{current_count}|>"
......
......@@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
prefix: str = "",
) -> None:
super().__init__()
......@@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
class Blip2QFormerSelfOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
......@@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
is_cross_attention: bool = False,
prefix: str = "",
) -> None:
super().__init__()
......@@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=is_cross_attention,
prefix=f"{prefix}.attention",
)
self.output = Blip2QFormerSelfOutput(config)
self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
def forward(
self,
......@@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
class Blip2QFormerIntermediate(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
......@@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
class Blip2QFormerOutput(nn.Module):
def __init__(self, config: Blip2QFormerConfig) -> None:
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
......@@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
layer_idx: int,
prefix: str = "",
) -> None:
super().__init__()
......@@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config,
quant_config=quant_config,
cache_config=cache_config)
cache_config=cache_config,
prefix=f"{prefix}.attention")
self.layer_idx = layer_idx
......@@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
config,
quant_config=quant_config,
cache_config=cache_config,
is_cross_attention=True)
is_cross_attention=True,
prefix=f"{prefix}.crossattention")
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
self.intermediate_query = Blip2QFormerIntermediate(
config, prefix=f"{prefix}.intermediate_query")
self.output_query = Blip2QFormerOutput(config,
prefix=f"{prefix}.output_query")
def forward(
self,
......@@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
prefix: str = "",
) -> None:
super().__init__()
......@@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
Blip2QFormerLayer(config,
quant_config=quant_config,
cache_config=cache_config,
layer_idx=layer_idx)
layer_idx=layer_idx,
prefix=f"{prefix}.layer.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])
......@@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
*,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
prefix: str = "",
) -> None:
super().__init__()
......@@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
self.encoder = Blip2QFormerEncoder(config,
quant_config=quant_config,
cache_config=cache_config)
cache_config=cache_config,
prefix=f"{prefix}.encoder")
def forward(
self,
......@@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.qformer = Blip2QFormerModel(config.qformer_config,
cache_config=cache_config,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.qformer")
self.language_projection = nn.Linear(
config.qformer_config.hidden_size,
......
This diff is collapsed.
......@@ -178,6 +178,7 @@ _MULTIMODAL_MODELS = {
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"), # noqa: E501
"H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment