Unverified Commit 64172a97 authored by xwjiang2010's avatar xwjiang2010 Committed by GitHub
Browse files

[Feature] Add vision language model support. (#3042)

parent f408d05c
#!/bin/bash
set -ex
set -o pipefail
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir -p images
cd images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
cd -
......@@ -39,9 +39,15 @@ steps:
- label: Models Test
commands:
- pytest -v -s models --forked
- bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py --forked
soft_fail: true
- label: Llava Test
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py
- label: Prefix Caching Test
commands:
- pytest -v -s prefix_caching
......
import argparse
import os
import subprocess
import torch
from vllm import LLM
from vllm.sequence import MultiModalData
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
def run_llava_pixel_values():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")
outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def run_llava_image_features():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="image_features",
image_token_id=32000,
image_input_shape="1,576,1024",
image_feature_size=576,
)
prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")
outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def main(args):
if args.type == "pixel_values":
run_llava_pixel_values()
else:
run_llava_image_features()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Demo on Llava")
parser.add_argument("--type",
type=str,
choices=["pixel_values", "image_features"],
default="pixel_values",
help="image input type")
args = parser.parse_args()
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)
# Use AWS CLI to sync the directory
subprocess.check_call(
["aws", "s3", "sync", s3_bucket_path, local_directory])
main(args)
......@@ -24,6 +24,10 @@ openai
requests
ray
peft
awscli
# Benchmarking
aiohttp
# Multimodal
pillow
......@@ -3,16 +3,39 @@ from typing import List, Optional, Tuple
import pytest
import torch
from transformers import AutoModelForCausalLM
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor,
LlavaForConditionalGeneration)
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
# Multi modal related
_PIXEL_VALUES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
_IMAGE_FEATURES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
_IMAGE_FILES = [
os.path.join(_TEST_DIR, "images", filename)
for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
_IMAGE_PROMPTS = [
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
_IMAGE_FILES) == len(_IMAGE_PROMPTS)
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
......@@ -20,6 +43,39 @@ def _read_prompts(filename: str) -> List[str]:
return prompts
@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
return _IMAGE_PROMPTS
@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
return [Image.open(filename) for filename in _IMAGE_FILES]
@pytest.fixture()
def vllm_images(request) -> "torch.Tensor":
vision_language_config = request.getfixturevalue("model_and_config")[1]
all_images = []
if vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
filenames = _IMAGE_FEATURES_FILES
else:
filenames = _PIXEL_VALUES_FILES
for filename in filenames:
all_images.append(torch.load(filename))
return torch.concat(all_images, dim=0)
@pytest.fixture()
def vllm_image_prompts(request) -> List[str]:
vision_language_config = request.getfixturevalue("model_and_config")[1]
return [
"<image>" * (vision_language_config.image_feature_size - 1) + p
for p in _IMAGE_PROMPTS
]
@pytest.fixture
def example_prompts() -> List[str]:
prompts = []
......@@ -42,6 +98,10 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"float": torch.float,
}
_VISION_LANGUAGE_MODELS = {
"llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
}
class HfRunner:
......@@ -53,11 +113,24 @@ class HfRunner:
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.model_name = model_name
if model_name not in _VISION_LANGUAGE_MODELS:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = None
else:
self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
)
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)
......@@ -65,13 +138,28 @@ class HfRunner:
def generate(
self,
prompts: List[str],
images: Optional[List[Image.Image]] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if images:
assert len(prompts) == len(images)
for i, prompt in enumerate(prompts):
if self.model_name not in _VISION_LANGUAGE_MODELS:
input_ids = self.tokenizer(prompt,
return_tensors="pt").input_ids
inputs = {"input_ids": input_ids.cuda()}
else:
image = images[i] if images else None
inputs = self.processor(text=prompt,
images=image,
return_tensors="pt")
inputs = {
key: value.cuda() if value is not None else None
for key, value in inputs.items()
}
output_ids = self.model.generate(
input_ids.cuda(),
**inputs,
use_cache=True,
**kwargs,
)
......@@ -88,10 +176,12 @@ class HfRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional["torch.Tensor"] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
max_new_tokens=max_tokens)
max_new_tokens=max_tokens,
images=images)
for i in range(len(outputs)):
output_ids, output_str = outputs[i]
outputs[i] = (output_ids[0], output_str[0])
......@@ -183,9 +273,16 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional["torch.Tensor"] = None,
) -> List[Tuple[List[int], str]]:
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
if images is not None:
assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts,
sampling_params=sampling_params,
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
data=images)
if images is not None else None)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
......@@ -222,9 +319,10 @@ class VllmRunner:
self,
prompts: List[str],
max_tokens: int,
images: Optional[torch.Tensor] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
outputs = self.generate(prompts, greedy_params, images=images)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
......
import gc
from dataclasses import fields
from enum import Enum
from typing import Dict, List, Tuple
import pytest
import torch
from transformers import AutoTokenizer
from vllm.config import VisionLanguageConfig
model_and_vl_config = [
("llava-hf/llava-1.5-7b-hf",
VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_feature_size=576,
image_token_id=32000,
image_input_shape=(1, 3, 336, 336))),
("llava-hf/llava-1.5-7b-hf",
VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.IMAGE_FEATURES,
image_feature_size=576,
image_token_id=32000,
image_input_shape=(1, 576, 1024)))
]
def as_dict(vision_language_config: VisionLanguageConfig) -> Dict:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result = {}
for field in fields(vision_language_config):
value = getattr(vision_language_config, field.name)
if isinstance(value, Enum):
result[field.name] = value.name.lower()
elif isinstance(value, tuple):
result[field.name] = ",".join([str(item) for item in value])
else:
result[field.name] = value
return result
def sanitize_vllm_output(vllm_output: Tuple[List[int], str],
vision_language_config: VisionLanguageConfig,
model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(vision_language_config.image_token_id)
image_token_str_len = len(image_token_str)
input_ids, output_str = vllm_output
sanitized_input_ids = input_ids[0:2] + input_ids[2 + vision_language_config
.image_feature_size - 1:]
sanitzied_output_str = output_str[vision_language_config.
image_feature_size *
image_token_str_len:]
return sanitized_input_ids, sanitzied_output_str
@pytest.mark.parametrize("worker_use_ray", [False])
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
vllm_image_prompts, vllm_images, model_and_config: tuple,
dtype: str, max_tokens: int, worker_use_ray: bool) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the raw images as input.
For vllm runner, we provide image tensors and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vision_language_config = model_and_config
hf_model = hf_runner(model_id, dtype=dtype)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
images=hf_images)
del hf_model
gc.collect()
torch.cuda.empty_cache()
vllm_model = vllm_runner(model_id,
dtype=dtype,
worker_use_ray=worker_use_ray,
**as_dict(vision_language_config))
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)
del vllm_model
gc.collect()
torch.cuda.empty_cache()
for i in range(len(hf_image_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = sanitize_vllm_output(
vllm_outputs[i], vision_language_config, model_id)
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
......@@ -25,7 +25,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
hf_runner,
......
......@@ -109,7 +109,7 @@ def create_worker(cls: type,
)
(model_config, cache_config, parallel_config, scheduler_config,
device_config, _) = engine_args.create_engine_configs()
device_config, _, _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
......
......@@ -35,7 +35,7 @@ def test_prepare_prompt(batch_size):
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_) = (model_runner._prepare_prompt(seq_group_metadata_list))
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts.
......
......@@ -11,7 +11,7 @@ def test_swap() -> None:
dtype="half",
load_format="dummy")
(model_config, cache_config, parallel_config, scheduler_config,
device_config, _) = engine_args.create_engine_configs()
device_config, _, _) = engine_args.create_engine_configs()
cache_config.num_gpu_blocks = 100
cache_config.num_cpu_blocks = 100
......
import enum
import json
import os
from dataclasses import dataclass
......@@ -8,7 +9,7 @@ from packaging.version import Version
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron
if TYPE_CHECKING:
......@@ -118,8 +119,9 @@ class ModelConfig:
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config,
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self._verify_load_format()
self._verify_tokenizer_mode()
......@@ -218,7 +220,7 @@ class ModelConfig:
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
total_num_attention_heads = self.hf_text_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
raise ValueError(
......@@ -226,7 +228,7 @@ class ModelConfig:
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_config.num_hidden_layers
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
......@@ -241,22 +243,23 @@ class ModelConfig:
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_config, "use_sliding_window")
and not self.hf_config.use_sliding_window):
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
return getattr(self.hf_config, "sliding_window", None)
return getattr(self.hf_text_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return self.hf_config.vocab_size
return self.hf_text_config.vocab_size
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
return self.hf_text_config.hidden_size
def get_head_size(self) -> int:
if hasattr(self.hf_config, "head_dim"):
return self.hf_config.head_dim
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
return (self.hf_text_config.hidden_size //
self.hf_text_config.num_attention_heads)
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
......@@ -268,7 +271,7 @@ class ModelConfig:
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config,
if not new_decoder_arch_falcon and getattr(self.hf_text_config,
"multi_query", False):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
......@@ -284,13 +287,13 @@ class ModelConfig:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_config, attr, None)
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_config.num_attention_heads
return self.hf_text_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
......@@ -303,7 +306,7 @@ class ModelConfig:
total_num_kv_heads // parallel_config.tensor_parallel_size)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
......@@ -627,6 +630,48 @@ class LoRAConfig:
"LoRA is enabled.")
@dataclass
class VisionLanguageConfig:
"""Configs the input data format and how models should run for
vision language models."""
class ImageInputType(enum.Enum):
"""Image input type into the vision language model.
An image roughly goes through the following transformation:
Raw image --> pixel values --> image features --> image embeddings.
The difference between different image input types is where the
image encoder (pixel values --> image features) is run.
Different image input types also correspond to different tensor shapes.
For example, for Llava, PIXEL_VALUES: (1, 3, 336, 336).
IMAGE_FEATURES: (1, 576, 1024).
"""
PIXEL_VALUES = enum.auto()
IMAGE_FEATURES = enum.auto()
image_input_type: ImageInputType
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
@classmethod
def get_image_input_enum_type(
cls, value: str) -> "VisionLanguageConfig.ImageInputType":
"""Get the image input type from a string."""
try:
return cls.ImageInputType[value.upper()]
except KeyError as e:
raise ValueError(f"{value} is not a valid choice. "
f"Expecting to choose from "
f"{[x.name for x in cls.ImageInputType]}.") from e
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
......
......@@ -388,6 +388,12 @@ class Scheduler:
computed_block_nums=self.block_manager.
get_common_computed_block_ids(seq_group),
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.prompt_run else None,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
......
......@@ -4,7 +4,9 @@ from dataclasses import dataclass
from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, TokenizerPoolConfig)
ParallelConfig, SchedulerConfig, TokenizerPoolConfig,
VisionLanguageConfig)
from vllm.utils import str_to_int_tuple
@dataclass
......@@ -50,6 +52,11 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0
def __post_init__(self):
......@@ -305,6 +312,31 @@ class EngineArgs:
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.')
# Related to Vision-language models such as llava
parser.add_argument(
'--image-input-type',
type=str,
default=None,
choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType
],
help=('The image input type passed into vLLM. '
'Should be one of "pixel_values" or "image_features".'))
parser.add_argument('--image-token-id',
type=int,
default=None,
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
type=str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
parser.add_argument(
'--image-feature-size',
type=int,
default=None,
help=('The image feature size along the context dimension.'))
parser.add_argument(
'--scheduler-delay-factor',
type=float,
......@@ -324,7 +356,8 @@ class EngineArgs:
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig]]:
DeviceConfig, Optional[LoRAConfig],
Optional[VisionLanguageConfig]]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
......@@ -358,8 +391,25 @@ class EngineArgs:
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')
vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
)
else:
vision_language_config = None
return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config)
device_config, lora_config, vision_language_config)
@dataclass
......
......@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int(
......@@ -240,6 +241,7 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
......@@ -252,14 +254,13 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
)
return self.add_request(request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
async def check_health_async(self) -> None:
self.model_executor.check_health()
......@@ -486,6 +487,7 @@ class AsyncLLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
......@@ -534,7 +536,9 @@ class AsyncLLMEngine:
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request)
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
......@@ -545,6 +549,7 @@ class AsyncLLMEngine:
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
......@@ -560,6 +565,7 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
......@@ -619,6 +625,7 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
async for request_output in stream:
......
......@@ -5,7 +5,7 @@ from transformers import PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
......@@ -15,8 +15,9 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupOutput, SequenceOutput,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
......@@ -62,6 +63,7 @@ class LLMEngine:
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional["VisionLanguageConfig"],
executor_class: Type[ExecutorBase],
log_stats: bool,
) -> None:
......@@ -90,6 +92,7 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
......@@ -102,7 +105,8 @@ class LLMEngine:
self.model_executor = executor_class(model_config, cache_config,
parallel_config, scheduler_config,
device_config, lora_config)
device_config, lora_config,
vision_language_config)
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
......@@ -170,7 +174,6 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
......@@ -212,6 +215,7 @@ class LLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.
......@@ -228,6 +232,7 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
......@@ -288,7 +293,7 @@ class LLMEngine:
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request)
arrival_time, lora_request, multi_modal_data)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
......
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
......@@ -8,6 +9,7 @@ from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.utils import Counter
......@@ -126,6 +128,7 @@ class LLM:
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
......@@ -141,6 +144,7 @@ class LLM:
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the generated
......@@ -160,6 +164,9 @@ class LLM:
# Use default sampling params.
sampling_params = SamplingParams()
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
......@@ -167,10 +174,17 @@ class LLM:
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request)
self._add_request(
prompt,
sampling_params,
token_ids,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
multi_modal_data=MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0))
if multi_modal_data else None,
)
return self._run_engine(use_tqdm)
def _add_request(
......@@ -179,13 +193,15 @@ class LLM:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request)
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
......
......@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
......@@ -24,6 +24,7 @@ class ExecutorBase(ABC):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
raise NotImplementedError
......
from typing import Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
......@@ -23,6 +23,7 @@ class GPUExecutor(ExecutorBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
......@@ -30,6 +31,7 @@ class GPUExecutor(ExecutorBase):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
# Instantiate the worker and load the model to GPU.
self._init_worker()
......@@ -56,6 +58,7 @@ class GPUExecutor(ExecutorBase):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
......
......@@ -6,7 +6,7 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
......@@ -40,6 +40,7 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
......@@ -47,6 +48,7 @@ class RayGPUExecutor(ExecutorBase):
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert self.parallel_config.worker_use_ray
placement_group = self.parallel_config.placement_group
......@@ -181,6 +183,7 @@ class RayGPUExecutor(ExecutorBase):
driver_rank,
distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=True,
)
......
......@@ -7,9 +7,14 @@ import torch.nn as nn
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
......@@ -40,6 +45,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method.
......@@ -76,7 +82,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
model = model_class(model_config.hf_config, linear_method)
if model_class not in _VISION_MODEL_CLASSES:
model = model_class(model_config.hf_config, linear_method)
else:
model = model_class(model_config.hf_config,
vision_language_config, linear_method)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment