Unverified Commit 1ca0d4f8 authored by Peter Salas's avatar Peter Salas Committed by GitHub
Browse files

[Model] Add UltravoxModel and UltravoxConfig (#7615)

parent dd53c4b0
......@@ -186,7 +186,7 @@ Multimodal Language Models
* - Architecture
- Models
- Supported Modality(ies)
- Supported Modalities
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
......@@ -234,6 +234,11 @@ Multimodal Language Models
- Image
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code: `UltravoxModel`
- Ultravox
- Audio
- :code: `fixie-ai/ultravox-v0_3`
-
.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
......
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser
# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
question = "What is recited in the audio?"
# Ultravox 0.3
def run_ultravox(question):
model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': f"<|reserved_special_token_0|>\n{question}"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
llm = LLM(model=model_name)
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = {
"ultravox": run_ultravox,
}
def main(args):
model = args.model_type
if model not in model_example_map:
raise ValueError(f"Model type {model} is not supported.")
llm, prompt, stop_token_ids = model_example_map[model](question)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=stop_token_ids)
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
}
else:
# Batch inference
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"audio": audio_and_sample_rate
},
} for _ in range(args.num_prompts)]
outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--model-type',
'-m',
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
args = parser.parse_args()
main(args)
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
vllm serve fixie-ai/ultravox-v0_3
"""
import base64
import requests
from openai import OpenAI
from vllm.assets.audio import AudioAsset
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Any format supported by librosa is supported
audio_url = AudioAsset("winning_call").url
# Use audio url in the payload
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
"url": audio_url
},
},
],
}],
model=model,
max_tokens=64,
)
result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")
# Use base64 encoded audio in the payload
def encode_audio_base64_from_url(audio_url: str) -> str:
"""Encode an audio retrieved from a remote url to base64 format."""
with requests.get(audio_url) as response:
response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8')
return result
audio_base64 = encode_audio_base64_from_url(audio_url=audio_url)
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
},
},
],
}],
model=model,
max_tokens=64,
)
result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")
......@@ -9,14 +9,14 @@ from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature)
from vllm import LLM, SamplingParams
......@@ -216,8 +216,7 @@ class HfRunner:
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
is_vision_model: bool = False,
is_encoder_decoder_model: bool = False,
auto_cls=AutoModelForCausalLM,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
......@@ -234,13 +233,6 @@ class HfRunner:
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
elif is_encoder_decoder_model:
auto_cls = AutoModelForSeq2SeqLM
else:
auto_cls = AutoModelForCausalLM
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
......@@ -432,6 +424,7 @@ class HfRunner:
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
......@@ -446,6 +439,11 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
if audios is not None:
audio, sr = audios[i]
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr
inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)
......@@ -627,6 +625,8 @@ class VllmRunner:
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
......@@ -638,6 +638,10 @@ class VllmRunner:
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}
if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
......@@ -674,6 +678,8 @@ class VllmRunner:
num_logprobs: int,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
......@@ -682,7 +688,8 @@ class VllmRunner:
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)
images=images,
audios=audios)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
......
......@@ -10,6 +10,7 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
"""
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.utils import cuda_device_count_stateless
......@@ -85,7 +86,7 @@ def test_models(
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_prompts,
max_tokens,
......
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch
import librosa
import numpy as np
from typing import Dict, List
import openai
import pytest
import requests
import torch
from vllm import ModelRegistry
from vllm.config import MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
from vllm.utils import get_open_port
from ...utils import VLLM_PATH
from vllm.assets.audio import AudioAsset
from vllm.multimodal.utils import encode_audio_base64, fetch_audio
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
from ...utils import RemoteOpenAIServer
MODEL_NAME = "facebook/opt-125m"
MODEL_NAME = "fixie-ai/ultravox-v0_3"
TEST_AUDIO_URLS = [
"https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
AudioAsset("winning_call").url,
]
def server_function(port):
def fake_input_mapper(ctx: InputContext, data: object):
assert isinstance(data, tuple)
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
# Resample it to 1 sample per second
audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "audio" not in multi_modal_data:
return llm_inputs
audio, sr = multi_modal_data.get("audio")
audio_duration = math.ceil(len(audio) / sr)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
cached_get_tokenizer(ctx.model_config.tokenizer),
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=62, # "_"
repeat_count=audio_duration)
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", lambda *_, **__: 100)
@INPUT_REGISTRY.register_input_processor(fake_input_processor)
class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
def __init__(self, *args, multimodal_config: MultiModalConfig,
**kwargs):
assert multimodal_config is not None
super().__init__(*args, **kwargs)
def forward(
self,
*args,
processed_audio: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return super().forward(*args, **kwargs)
ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
with patch(
"vllm.entrypoints.chat_utils._mm_token_str",
lambda *_, **__: "_"), patch(
"vllm.model_executor.models.ModelRegistry.is_multimodal_model"
) as mock:
mock.return_value = True
sys.argv = ["placeholder.py"] + \
(f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
"--dtype bfloat16 --enforce-eager --api-key token-abc123 "
f"--port {port} --chat-template {chatml_jinja_path} "
"--disable-frontend-multiprocessing").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server',
run_name='__main__')
@pytest.fixture(scope="module")
def server():
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"4096",
"--enforce-eager",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client():
port = get_open_port()
ctx = torch.multiprocessing.get_context("spawn")
server = ctx.Process(target=server_function, args=(port, ))
server.start()
MAX_SERVER_START_WAIT_S = 60
client = openai.AsyncOpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
# run health check
health_url = f"http://localhost:{port}/health"
start = time.time()
while True:
try:
if requests.get(health_url).status_code == 200:
break
except Exception as err:
result = server.exitcode
if result is not None:
raise RuntimeError("Server exited unexpectedly.") from err
time.sleep(0.5)
if time.time() - start > MAX_SERVER_START_WAIT_S:
raise RuntimeError("Server failed to start in time.") from err
try:
yield client
finally:
server.kill()
def client(server):
return server.get_async_client()
@pytest.fixture(scope="session")
......@@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
completion_tokens=10, prompt_tokens=202, total_tokens=212)
message = choice.message
message = chat_completion.choices[0].message
......@@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded(
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=36, total_tokens=46)
completion_tokens=10, prompt_tokens=202, total_tokens=212)
message = choice.message
message = chat_completion.choices[0].message
......
......@@ -12,6 +12,7 @@ if not is_cpu():
# (xFormers, etc.)
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.sequence import SampleLogprobs
......@@ -131,7 +132,7 @@ if not is_cpu():
}
with hf_runner(model, dtype=dtype,
is_encoder_decoder_model=True) as hf_model:
auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_case_prompts,
......
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from transformers import AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -80,7 +80,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from typing import List, Optional, Type
import pytest
from transformers import BatchEncoding
from transformers import AutoModelForVision2Seq, BatchEncoding
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
......@@ -74,7 +74,7 @@ def run_test(
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer, BatchEncoding
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -124,7 +125,7 @@ def run_test(
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
is_vision_model=True) as hf_model:
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.sequence import SampleLogprobs
......@@ -105,7 +105,8 @@ def run_test(
for prompts, images in vllm_inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -129,7 +129,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
......@@ -2,7 +2,7 @@ import os
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -102,7 +102,8 @@ def run_test(
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
with hf_runner(model, dtype=dtype,
auto_cls=AutoModelForVision2Seq) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
......
......@@ -26,7 +26,7 @@ def test_text_only_qwen_model(
# for qwen-vl is still unsupported in VLLM. In the near-future, the
# implementation and this test will be extended to consider
# visual inputs as well.
with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts,
max_tokens,
......
from typing import List, Optional, Tuple, Type
import librosa
import numpy as np
import pytest
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from vllm.assets.audio import AudioAsset
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ..conftest import HfRunner, VllmRunner
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
MODEL_NAME = "fixie-ai/ultravox-v0_3"
AudioTuple = Tuple[np.ndarray, int]
@pytest.fixture(scope="session")
def audio_and_sample_rate():
return AudioAsset("mary_had_lamb").audio_and_sample_rate
@pytest.fixture
def prompts_and_audios(audio_and_sample_rate):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
vllm_placeholder = "<|reserved_special_token_0|>"
hf_placeholder = "<|audio|>"
question = "What's in the audio?"
vllm_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{vllm_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
hf_prompt = tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{hf_placeholder}\n{question}"
}],
tokenize=False,
add_generation_prompt=True)
return [(vllm_prompt, hf_prompt, audio_and_sample_rate)]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id
hf_output_ids = output_ids[:]
hf_output_str = output_str
if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
return hf_output_ids, hf_output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
prompts_and_audios: List[Tuple[str, str, AudioTuple]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm."""
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_audio = [
vllm_model.generate_greedy_logprobs([vllm_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[audio])
for vllm_prompt, _, audio in prompts_and_audios
]
def process(hf_inputs: BatchEncoding):
hf_inputs["audio_values"] = hf_inputs["audio_values"] \
.to(torch_dtype) # type: ignore
return hf_inputs
with hf_runner(model,
dtype=dtype,
postprocess_inputs=process,
auto_cls=AutoModel) as hf_model:
hf_outputs_per_audio = [
hf_model.generate_greedy_logprobs_limit(
[hf_prompt],
max_tokens,
num_logprobs=num_logprobs,
audios=[(librosa.resample(audio[0],
orig_sr=audio[1],
target_sr=16000), 16000)])
for _, hf_prompt, audio in prompts_and_audios
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
vllm_outputs_per_audio):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str,
max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
prompts_and_audios,
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
from dataclasses import dataclass
from typing import Literal, Tuple
from urllib.parse import urljoin
import librosa
import numpy as np
from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL
ASSET_DIR = "multimodal_asset"
@dataclass(frozen=True)
class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
@property
def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
@property
def url(self) -> str:
return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
......@@ -117,8 +117,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if modality == "image":
model_type = model_config.hf_config.model_type
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
......@@ -134,7 +134,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
raise TypeError("No audio models are supported yet.")
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
......
......@@ -61,7 +61,7 @@ _GENERATION_MODELS = {
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}
_EMBEDDING_MODELS = {
......@@ -83,6 +83,7 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
......
......@@ -15,8 +15,8 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
......@@ -97,11 +97,11 @@ def input_processor_for_blip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
......
......@@ -30,8 +30,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from vllm.utils import print_warning_once
......@@ -124,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment