Unverified Commit 3b7fea77 authored by Yang Fan's avatar Yang Fan Committed by GitHub
Browse files

[Model][VLM] Add Qwen2-VL model support (#7905)


Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent cea95dfb
...@@ -252,6 +252,11 @@ Multimodal Language Models ...@@ -252,6 +252,11 @@ Multimodal Language Models
- Image\ :sup:`E` - Image\ :sup:`E`
- :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc.
- -
* - :code:`Qwen2VLForConditionalGeneration`
- Qwen2-VL (see note)
- Image\ :sup:`+` / Video\ :sup:`+`
- :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc.
-
* - :code:`UltravoxModel` * - :code:`UltravoxModel`
- Ultravox - Ultravox
- Audio\ :sup:`E+` - Audio\ :sup:`E+`
...@@ -265,15 +270,14 @@ Multimodal Language Models ...@@ -265,15 +270,14 @@ Multimodal Language Models
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. .. note::
For :code:`LLaVA-NeXT-Video` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
This can be installed by running the following command: This can be installed by running the following command:
.. code-block:: bash .. code-block:: bash
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
---- ----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
......
...@@ -179,6 +179,23 @@ def run_qwen_vl(question): ...@@ -179,6 +179,23 @@ def run_qwen_vl(question):
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids
# Qwen2-VL
def run_qwen2_vl(question):
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
model=model_name,
max_num_seqs=5,
)
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n")
stop_token_ids = None
return llm, prompt, stop_token_ids
model_example_map = { model_example_map = {
"llava": run_llava, "llava": run_llava,
"llava-next": run_llava_next, "llava-next": run_llava_next,
...@@ -191,6 +208,7 @@ model_example_map = { ...@@ -191,6 +208,7 @@ model_example_map = {
"blip-2": run_blip2, "blip-2": run_blip2,
"internvl_chat": run_internvl, "internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl, "qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
} }
......
...@@ -6,7 +6,7 @@ by the model. ...@@ -6,7 +6,7 @@ by the model.
from argparse import Namespace from argparse import Namespace
from typing import List from typing import List
from transformers import AutoTokenizer from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
...@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]): ...@@ -30,7 +30,7 @@ def load_phi3v(question, image_urls: List[str]):
for i, _ in enumerate(image_urls, start=1)) for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
stop_token_ids = None stop_token_ids = None
return llm, prompt, stop_token_ids return llm, prompt, stop_token_ids, None
def load_internvl(question, image_urls: List[str]): def load_internvl(question, image_urls: List[str]):
...@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]): ...@@ -60,18 +60,72 @@ def load_internvl(question, image_urls: List[str]):
# https://huggingface.co/OpenGVLab/InternVL2-2B#service # https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return llm, prompt, stop_token_ids
return llm, prompt, stop_token_ids, None
def load_qwen2_vl(question, image_urls: List[str]):
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
process_vision_info = None
model_name = "Qwen/Qwen2-VL-7B-Instruct"
llm = LLM(
model=model_name,
max_num_seqs=5,
max_model_len=32768 if process_vision_info is None else 4096,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
stop_token_ids = None
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages)
return llm, prompt, stop_token_ids, image_data
model_example_map = { model_example_map = {
"phi3_v": load_phi3v, "phi3_v": load_phi3v,
"internvl_chat": load_internvl, "internvl_chat": load_internvl,
"qwen2_vl": load_qwen2_vl,
} }
def run_generate(model, question: str, image_urls: List[str]): def run_generate(model, question: str, image_urls: List[str]):
llm, prompt, stop_token_ids = model_example_map[model](question, llm, prompt, stop_token_ids, image_data = model_example_map[model](
image_urls) question, image_urls)
if image_data is None:
image_data = [fetch_image(url) for url in image_urls]
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0,
max_tokens=128, max_tokens=128,
...@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]): ...@@ -81,7 +135,7 @@ def run_generate(model, question: str, image_urls: List[str]):
{ {
"prompt": prompt, "prompt": prompt,
"multi_modal_data": { "multi_modal_data": {
"image": [fetch_image(url) for url in image_urls] "image": image_data
}, },
}, },
sampling_params=sampling_params) sampling_params=sampling_params)
...@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]): ...@@ -92,7 +146,7 @@ def run_generate(model, question: str, image_urls: List[str]):
def run_chat(model: str, question: str, image_urls: List[str]): def run_chat(model: str, question: str, image_urls: List[str]):
llm, _, stop_token_ids = model_example_map[model](question, image_urls) llm, _, stop_token_ids, _ = model_example_map[model](question, image_urls)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0,
max_tokens=128, max_tokens=128,
......
...@@ -28,3 +28,4 @@ importlib_metadata ...@@ -28,3 +28,4 @@ importlib_metadata
mistral_common >= 1.3.4 mistral_common >= 1.3.4
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.
import pytest import pytest
import transformers
from vllm.model_executor.models import _MODELS, ModelRegistry from vllm.model_executor.models import _MODELS, ModelRegistry
@pytest.mark.parametrize("model_cls", _MODELS) @pytest.mark.parametrize("model_cls", _MODELS)
def test_registry_imports(model_cls): def test_registry_imports(model_cls):
if (model_cls == "Qwen2VLForConditionalGeneration"
and transformers.__version__ < "4.45"):
pytest.skip("Waiting for next transformers release")
# Ensure all model classes can be imported successfully # Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls([model_cls]) ModelRegistry.resolve_model_cls([model_cls])
...@@ -1733,6 +1733,9 @@ def _get_and_verify_max_len( ...@@ -1733,6 +1733,9 @@ def _get_and_verify_max_len(
"with rope_scaling. Please raise an issue so we can " "with rope_scaling. Please raise an issue so we can "
"investigate.") "investigate.")
if rope_type == "mrope":
scaling_factor = 1
else:
assert "factor" in rope_scaling assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
if rope_type == "yarn": if rope_type == "yarn":
......
...@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False): ...@@ -108,7 +108,7 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls.""" """The tool calls generated by the model, such as function calls."""
ModalityStr = Literal["image", "audio"] ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -158,12 +158,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -158,12 +158,18 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index) hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"): if model_type in ("chameleon", "internvl_chat"):
return "<image>" return "<image>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type == "ultravox":
return "<|reserved_special_token_0|>" return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}") raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
......
...@@ -712,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -712,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs return new_freqs
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat([
m[i]
for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
],
dim=-1)
sin = torch.cat([
m[i]
for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
],
dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(context_len + mrope_position_delta,
seq_len + mrope_position_delta)) for _ in range(3)
]
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
...@@ -752,7 +925,7 @@ def get_rope( ...@@ -752,7 +925,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here # The correct one should be "longrope" but keep "su" here
# for backward compatible # for backward compatible
if scaling_type not in {"su", "longrope"}: if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling.get("factor", 1.0)
if scaling_type == "llama3": if scaling_type == "llama3":
low_freq_factor = rope_scaling["low_freq_factor"] low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"]
...@@ -816,6 +989,16 @@ def get_rope( ...@@ -816,6 +989,16 @@ def get_rope(
head_size, rotary_dim, max_position, original_max_position, head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor, base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "mrope":
return MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
......
...@@ -53,6 +53,8 @@ _GENERATION_MODELS = { ...@@ -53,6 +53,8 @@ _GENERATION_MODELS = {
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
...@@ -90,6 +92,8 @@ _MULTIMODAL_MODELS = { ...@@ -90,6 +92,8 @@ _MULTIMODAL_MODELS = {
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"), "UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
} }
_CONDITIONAL_GENERATION_MODELS = { _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from array import array
from functools import lru_cache, partial
from typing import (Iterable, List, Mapping, Optional, Tuple, Type, TypedDict,
Union)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from transformers import Qwen2VLConfig
from transformers.image_utils import (get_image_size,
infer_channel_dimension_format,
to_numpy_array)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)
import vllm.envs as envs
from vllm.attention import AttentionMetadata
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
logger = init_logger(__name__)
# === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
# === Vision Encoder === #
class Qwen2VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int = None,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.fc1 = ColumnParallelLinear(in_features,
hidden_features,
quant_config=quant_config)
self.act = act_layer()
self.fc2 = RowParallelLinear(hidden_features,
in_features,
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1),
"... d two -> ... (d two)",
two=2)
def apply_rotary_emb_torch(x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
interleaved: bool = False) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[
x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size)
self.qkv = ColumnParallelLinear(input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config)
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config)
# Detect attention implementation.
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.get_device_capability()[0] >= 8
if device_available:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
self._use_flash_attn = True
else:
logger.warning(
"Current Qwen2-VL implementation has a bug with "
"`vllm-flash-attn` inside vision module, so we use "
"xformers backend instead. You can run `pip install "
"flash-attn to use flash-attention backend.")
self._use_flash_attn = False
else:
self._use_flash_attn = False
else:
if selected_backend == _Backend.FLASH_ATTN:
self._use_flash_attn = True
elif selected_backend == _Backend.XFORMERS:
self._use_flash_attn = False
else:
raise RuntimeError(
f"Qwen2-VL does not support {selected_backend} backend now."
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
batch_size = q.shape[1]
q, k, v = [
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
]
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self._use_flash_attn:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
output = flash_attn_varlen_func(q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
else:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class Qwen2VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.attn = Qwen2VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config)
self.mlp = Qwen2VisionMLP(dim,
mlp_hidden_dim,
act_layer=act_layer,
quant_config=quant_config)
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb)
x = x + self.mlp(self.norm2(x))
return x
class Qwen2VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_chans: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_chans,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x
class Qwen2VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Type[nn.Module] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)
self.mlp = nn.ModuleList([
ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config),
nn.GELU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
quant_config=quant_config),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
return out
class Qwen2VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta
**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (self.theta**(torch.arange(
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
/ self.dim))
seq = torch.arange(seqlen,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
patch_size: int = vision_config.patch_size
temporal_patch_size: int = vision_config.temporal_patch_size
spatial_merge_size: int = vision_config.spatial_merge_size
in_chans: int = vision_config.in_chans
hidden_size: int = vision_config.hidden_size
embed_dim: int = vision_config.embed_dim
depth: int = vision_config.depth
num_heads: int = vision_config.num_heads
mlp_ratio: float = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
Qwen2VisionBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
) for _ in range(depth)
])
self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size,
context_dim=embed_dim,
norm_layer=norm_layer,
quant_config=quant_config,
)
@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
@property
def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
).permute(0, 2, 1, 3).flatten()
pos_ids.append(
torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
# adapter
x = self.merger(x)
return x
# === Vision input helpers === #
cached_get_processor = lru_cache(get_processor)
def mm_input_mapper_for_qwen2_vl(
ctx: InputContext,
data: MultiModalData[object],
data_type_key: str,
) -> MultiModalInputs:
"""Input mapper for Qwen2-VL."""
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model, trust_remote_code=model_config.trust_remote_code)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
images = None
videos = None
if data_type_key == "image":
images = data
else:
assert data_type_key == "video"
videos = data
try:
batch_data = image_processor \
.preprocess(images=images, videos=videos, return_tensors="pt") \
.data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
return MultiModalInputs(batch_data)
image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="image")
video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="video")
def _get_vision_info(
image_processor,
height: int,
width: int,
min_pixels: int,
max_pixels: int,
do_resize: bool = True,
data_type_key: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
else:
resized_height, resized_width = height, width
if data_type_key == "image":
grid_t = mm_count
else:
assert data_type_key == "video"
grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
grid_h = resized_height // image_processor.patch_size
grid_w = resized_width // image_processor.patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = (vision_tokens // image_processor.merge_size //
image_processor.merge_size)
return resized_height, resized_width, llm_num_vision_tokens
def _get_max_image_info(
image_processor,
data_type_key: str = "image",
mm_count: int = 1,
):
return _get_vision_info(
image_processor,
height=9999999,
width=9999999,
# Limit min / max pixels.
min_pixels=max(image_processor.min_pixels, 28 * 28),
max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
data_type_key=data_type_key,
mm_count=mm_count,
)
def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
image_processor = cached_get_image_processor(ctx.model_config.model)
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key=data_type_key,
mm_count=1)
return max_llm_image_tokens
get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="image")
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video")
def dummy_data_for_qwen2_vl(
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
image_processor = cached_get_image_processor(ctx.model_config.model)
num_images = mm_counts["image"]
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key="image",
mm_count=num_images)
if seq_len - max_llm_image_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")
# Check video counts.
num_videos = mm_counts["video"]
max_resized_height, max_resized_width, max_llm_video_tokens = \
_get_max_image_info(image_processor, data_type_key="video",
mm_count=num_videos)
if seq_len - max_llm_video_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_images} videos in a prompt, "
"please increase max_model_len or reduce video limit by "
"--limit-mm-per-prompt.")
hf_config = ctx.get_hf_config(Qwen2VLConfig)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_start_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.image_token_id]) * max_llm_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[hf_config.vision_end_token_id])
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - max_llm_image_tokens - 2)
dummy_seqdata = SequenceData(token_ids)
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)
return dummy_seqdata, {
"image": dummy_image if num_images == 1 else [dummy_image] * num_images
}
def _get_llm_num_vision_tokens(
mm_inputs: list,
data_type_key: str,
image_processor,
):
"""Get number of vision tokens of multimodal inputs.
This method is derived from `transformers.models.qwen2_vl.
image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
"""
image = to_numpy_array(mm_inputs[0])
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, channel_dim=input_data_format)
_, _, llm_num_vision_tokens = _get_vision_info(
image_processor,
height=height,
width=width,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
do_resize=image_processor.do_resize,
data_type_key=data_type_key,
mm_count=len(mm_inputs),
)
return llm_num_vision_tokens
def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None)
if multi_modal_data is None:
return llm_inputs
image_inputs = multi_modal_data.get("image", None)
video_inputs = multi_modal_data.get("video", None)
processor = cached_get_processor(ctx.model_config.model)
image_processor = processor.image_processor
hf_config = ctx.get_hf_config(Qwen2VLConfig)
# To avoid redundant processing of vision objects (resize, rescale, etc.),
# we extract code of calculating number of vision tokens from
# `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
#
# The following code is equivalent to:
# prompt = llm_inputs["prompt"]
# inputs = processor(text=[prompt],
# images=image_inputs,
# videos=video_inputs,
# padding=True,
# return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist()
prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
if prompt_token_ids is None:
prompt = llm_inputs["prompt"]
prompt_token_ids = processor.tokenizer(
prompt,
padding=True,
return_tensors=None,
)["input_ids"]
# Expand image pad tokens.
if image_inputs is not None:
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
image_inputs = make_batched_images(image_inputs)
assert len(image_indices) == len(image_inputs)
prompt_token_ids_with_image = []
for image_cnt, image in enumerate(image_inputs):
num_image_tokens = _get_llm_num_vision_tokens(
[image],
data_type_key="image",
image_processor=image_processor,
)
if image_cnt == 0:
non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
else:
non_image_tokens = prompt_token_ids[image_indices[image_cnt -
1] +
1:image_indices[image_cnt]]
prompt_token_ids_with_image.extend(non_image_tokens)
prompt_token_ids_with_image.extend(
hf_config.image_token_id for _ in range(num_image_tokens))
prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_image
# Expand video pad tokens.
if video_inputs is not None:
video_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.video_token_id
]
video_inputs = make_batched_videos(video_inputs)
assert len(video_indices) == len(video_inputs)
prompt_token_ids_with_video = []
for video_cnt, video in enumerate(video_inputs):
num_video_tokens = _get_llm_num_vision_tokens(
video,
data_type_key="video",
image_processor=image_processor,
)
if video_cnt == 0:
non_video_tokens = prompt_token_ids[:video_indices[video_cnt]]
else:
non_video_tokens = prompt_token_ids[video_indices[video_cnt -
1] +
1:video_indices[video_cnt]]
prompt_token_ids_with_video.extend(non_video_tokens)
prompt_token_ids_with_video.extend(
hf_config.video_token_id for _ in range(num_video_tokens))
prompt_token_ids_with_video.extend(prompt_token_ids[video_indices[-1] +
1:])
prompt_token_ids = prompt_token_ids_with_video
return LLMInputs(
prompt_token_ids=prompt_token_ids,
prompt=llm_inputs["prompt"],
multi_modal_data=multi_modal_data,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper(
image_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_input_mapper("video",
video_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching"
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen2-VL vision encoder does not support any
# quantization method now.
quant_config=None,
)
self.model = Qwen2Model(config, cache_config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
List[torch.Tensor]],
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim}")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None:
return None
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return Qwen2VLImageInputs(pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None:
return None
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return Qwen2VLVideoInputs(
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
def _process_image_input(self,
image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values,
grid_thw=image_input["image_grid_thw"])
return image_embeds
def _process_video_input(self,
video_input: Qwen2VLVideoInputs) -> torch.Tensor:
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
video_embeds = self.visual(pixel_values_videos,
grid_thw=video_input["video_grid_thw"])
return video_embeds
def _merge_multimodal_embeddings(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: torch.Tensor,
placeholder_token_id: int,
) -> torch.Tensor:
mask = (input_ids == placeholder_token_id)
inputs_embeds[mask, :] = multimodal_embeddings
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Qwen2-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,).
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
pixel_values_videos: Pixel values of videos to be fed to a model.
`None` if no videos are passed.
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed.
"""
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
if image_input is None and video_input is None:
inputs_embeds = None
else:
if getattr(self.config, "rope_scaling", {}).get("type",
None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.model.embed_tokens(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
try:
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if len(inputs_list) == 0: if len(inputs_list) == 0:
return {} return {}
keys = inputs_list[0].keys()
item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
for inputs in inputs_list: for inputs in inputs_list:
if inputs.keys() != keys: # For models that supports multiple modalities (e.g. Qwen2-VL),
msg = f"Inputs do not share the same keys ({keys})" # different modalities will return different data keys,
raise ValueError(msg) # so batch() should skip the same key check.
for k, v in inputs.items(): for k, v in inputs.items():
item_lists[k].append(v) item_lists[k].append(v)
......
...@@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct, ...@@ -165,6 +165,9 @@ class SequenceData(msgspec.Struct,
# is called. # is called.
_new_appended_tokens: List[int] = msgspec.field(default_factory=list) _new_appended_tokens: List[int] = msgspec.field(default_factory=list)
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l" assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l"
...@@ -219,6 +222,14 @@ class SequenceData(msgspec.Struct, ...@@ -219,6 +222,14 @@ class SequenceData(msgspec.Struct,
assert isinstance(self._output_token_ids, array) assert isinstance(self._output_token_ids, array)
return self._output_token_ids return self._output_token_ids
@property
def mrope_position_delta(self) -> Optional[int]:
return self._mrope_position_delta
@mrope_position_delta.setter
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id) self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id) self._new_appended_tokens.append(token_id)
......
from typing import cast
def get_processor(
processor_name: str,
*args,
trust_remote_code: bool = False,
**kwargs,
):
"""Gets a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
from transformers.processing_utils import ProcessorMixin
try:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
# Unlike AutoTokenizer, AutoProcessor does not separate such errors
if not trust_remote_code:
err_msg = (
"Failed to load the processor. If the processor is "
"a custom processor not yet available in the HuggingFace "
"transformers library, consider setting "
"`trust_remote_code=True` in LLM or using the "
"`--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
return cast(ProcessorMixin, processor)
...@@ -30,6 +30,7 @@ from vllm.lora.layers import LoRAMapping ...@@ -30,6 +30,7 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
...@@ -181,6 +182,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -181,6 +182,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self): def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore
...@@ -206,6 +208,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -206,6 +208,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions. # Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
# The sequence length (may be capped to the sliding window). # The sequence length (may be capped to the sliding window).
seq_lens: Optional[List[int]] = None, seq_lens: Optional[List[int]] = None,
...@@ -266,6 +269,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -266,6 +269,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear() self.input_positions[seq_id].clear()
self.mrope_input_positions = None
if seq_lens: if seq_lens:
self.seq_lens = seq_lens self.seq_lens = seq_lens
else: else:
...@@ -327,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -327,6 +332,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or [] self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or [] self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or [] self.query_lens = query_lens or []
...@@ -357,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -357,6 +363,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_tokens = [[] for _ in range(self.n_seqs)]
self.input_positions = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)]
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs
...@@ -493,6 +500,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -493,6 +500,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.query_lens[ inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
if seq_data.mrope_position_delta is not None:
if inter_data.mrope_input_positions is None:
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
inter_data.mrope_input_positions[
seq_idx] = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta,
context_len,
seq_len,
)
def _compute_for_prefix_cache_hit( def _compute_for_prefix_cache_hit(
self, inter_data: InterDataForSeqGroup, seq_idx: int, self, inter_data: InterDataForSeqGroup, seq_idx: int,
seq_group_metadata: SequenceGroupMetadata): seq_group_metadata: SequenceGroupMetadata):
...@@ -636,6 +654,40 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -636,6 +654,40 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper(mm_data) mm_kwargs = self.multi_modal_input_mapper(mm_data)
inter_data.multi_modal_inputs = mm_kwargs inter_data.multi_modal_inputs = mm_kwargs
# special processing for mrope position deltas.
if self.runner.model_is_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, (
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
hf_config = self.runner.model_config.hf_config
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
for seq_idx in range(inter_data.n_seqs):
seq_data = seq_group_metadata.seq_data[
inter_data.seq_ids[seq_idx]]
token_ids = seq_data.get_token_ids()
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=inter_data.context_lens[seq_idx],
)
seq_data.mrope_position_delta = mrope_position_delta
inter_data.mrope_input_positions[
seq_idx] = mrope_input_positions
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder.""" """Add a sequence group to the builder."""
seq_ids = seq_group_metadata.seq_data.keys() seq_ids = seq_group_metadata.seq_data.keys()
...@@ -684,6 +736,23 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -684,6 +736,23 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# prefix caching and there is no decode request. # prefix caching and there is no decode request.
return self.model_input_cls() return self.model_input_cls()
mrope_input_positions: Optional[List[List[int]]] = None
if any(inter_data.mrope_input_positions is not None
for inter_data in self.inter_data_list):
mrope_input_positions = [[] for _ in range(3)]
for idx in range(3):
for inter_data in self.inter_data_list:
msections = inter_data.mrope_input_positions
if msections is None:
for _seq_input_positions in inter_data.input_positions:
mrope_input_positions[idx].extend(
_seq_input_positions)
else:
for _seq_mrope_input_positions in msections:
mrope_input_positions[idx].extend(
_seq_mrope_input_positions[idx])
input_positions = None
else:
input_positions = [] input_positions = []
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_positions in inter_data.input_positions: for cur_input_positions in inter_data.input_positions:
...@@ -724,15 +793,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): ...@@ -724,15 +793,24 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Tokens and positions. # Tokens and positions.
if cuda_graph_pad_size: if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
input_positions_tensor = async_tensor_h2d(input_positions, torch.long, if mrope_input_positions is not None:
for idx in range(3):
mrope_input_positions[idx].extend(
itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(mrope_input_positions,
torch.long,
self.runner.device,
self.runner.pin_memory)
else:
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions_tensor = async_tensor_h2d(input_positions,
torch.long,
self.runner.device, self.runner.device,
self.runner.pin_memory) self.runner.pin_memory)
# Sequence and query lengths. # Sequence and query lengths.
if cuda_graph_pad_size: if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
...@@ -1199,6 +1277,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1199,6 +1277,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
raise RuntimeError("PromptAdapter is not enabled.") raise RuntimeError("PromptAdapter is not enabled.")
return self.prompt_adapter_manager.list_adapters() return self.prompt_adapter_manager.list_adapters()
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model. """Cuda graph capture a model.
...@@ -1229,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1229,7 +1316,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size = self.max_batchsize_to_capture max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
if self.model_is_mrope:
input_positions = torch.tile(input_positions, (3, 1))
# Prepare dummy previous_hidden_states only if needed by the model. # Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE. # This is used by draft models such as EAGLE.
previous_hidden_states = None previous_hidden_states = None
...@@ -1293,7 +1381,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1293,7 +1381,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"input_ids": "input_ids":
input_tokens[:batch_size], input_tokens[:batch_size],
"positions": "positions":
input_positions[:batch_size], input_positions[..., :batch_size],
"hidden_or_intermediate_states": "hidden_or_intermediate_states":
hidden_or_intermediate_states[ hidden_or_intermediate_states[
virtual_engine] # type: ignore virtual_engine] # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment