Unverified Commit cbbc82b7 authored by Yineng Zhang, committed by GitHub

Support qwen2 vl model (#1721)


Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: ispobock <ISPObaoke@163.com>
parent 8bee20f8
@@ -73,7 +73,7 @@ jobs:
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
- timeout-minutes: 20
+ timeout-minutes: 30
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
@@ -93,7 +93,7 @@ jobs:
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
- name: Run test
- timeout-minutes: 20
+ timeout-minutes: 30
run: |
cd test/srt
python3 run_suite.py --suite minimal --range-begin 17
...
@@ -133,6 +133,22 @@ register_chat_template(
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_chat_template(
ChatTemplate(
name="qwen2-vl",
default_system_prompt="You are a helpful assistant.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
register_chat_template(
ChatTemplate(
...
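As a side note on the new "qwen2-vl" template above: a minimal sketch (not part of the commit) of how a one-turn prompt renders from the registered role prefixes/suffixes. The message contents are made up, and the sketch assumes ChatTemplateStyle.PLAIN simply concatenates prefix + content + suffix and then opens the assistant turn:

ROLES = {
    "system": ("<|im_start|>system\n", "<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
}
IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"

def render(messages):
    # prefix + content + suffix per turn, then an open assistant prefix for generation
    out = "".join(ROLES[role][0] + content + ROLES[role][1] for role, content in messages)
    return out + ROLES["assistant"][0]

print(render([
    ("system", "You are a helpful assistant."),
    ("user", IMAGE_TOKEN + "Describe this image."),
]))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
# <|im_start|>assistant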
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
__all__ = [
"ExaoneConfig",
"Qwen2VLConfig",
"Qwen2VLVisionConfig",
]
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
import os
from typing import Union
from transformers import PretrainedConfig
class Qwen2VLVisionConfig(PretrainedConfig):
model_type = "qwen2_vl"
def __init__(
self,
depth=32,
embed_dim=1280,
hidden_size=3584,
hidden_act="quick_gelu",
mlp_ratio=4,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
temporal_patch_size=2,
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.embed_dim = embed_dim
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.mlp_ratio = mlp_ratio
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)
if config_dict.get("model_type") == "qwen2_vl":
config_dict = config_dict["vision_config"]
return cls.from_dict(config_dict, **kwargs)
class Qwen2VLConfig(PretrainedConfig):
model_type = "qwen2_vl"
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if isinstance(vision_config, dict):
self.vision_config = Qwen2VLVisionConfig(**vision_config)
elif vision_config is None:
self.vision_config = Qwen2VLVisionConfig()
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# NOTE: the following section from the original Transformers config
# for Qwen2-VL is commented out to work around a rope config loading issue
#
# if self.rope_scaling is not None and "type" in self.rope_scaling:
# if self.rope_scaling["type"] == "mrope":
# self.rope_scaling["type"] = "default"
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self)
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
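A small usage sketch (not part of the commit) of the config class above; the keyword values are illustrative, and the import path is the one added to sglang.srt.configs in this diff:

from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig

# A dict vision_config is promoted to Qwen2VLVisionConfig; None falls back to the defaults.
cfg = Qwen2VLConfig(vision_config={"embed_dim": 1280, "spatial_merge_size": 2})
assert isinstance(cfg.vision_config, Qwen2VLVisionConfig)
assert cfg.vision_config.spatial_merge_size == 2

cfg_default = Qwen2VLConfig()           # vision_config=None -> default vision settings
print(cfg_default.vision_config.depth)  # 32, per the defaults above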
@@ -530,3 +530,17 @@ register_conv_template(
stop_str=["<|im_end|>", "<|action_end|>"],
)
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
Conversation(
name="qwen2-vl",
system_message="You are a helpful assistant.",
system_template="<|im_start|>system\n{system_message}",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_start|><|image_pad|><|vision_end|>",
)
)
@@ -33,12 +33,13 @@ from transformers import (
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
- from sglang.srt.configs import ExaoneConfig
+ from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
Qwen2VLConfig.model_type: Qwen2VLConfig,
}
except ImportError:
# We want this file to run without vllm dependency
...
@@ -50,6 +50,7 @@ def _fwd_kernel(
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
Lk: tl.constexpr,
):
cur_batch = tl.program_id(0)
@@ -78,7 +79,9 @@ def _fwd_kernel(
mask_d = offs_d < Lk
q = tl.load(
- Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+ Q + off_q,
+ mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+ other=0.0,
)
k_ptrs = K + off_k
@@ -91,7 +94,12 @@ def _fwd_kernel(
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
- for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ end_n = (
+ cur_batch_seq_len
+ if not IS_CAUSAL
+ else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+ )
+ for start_n in range(0, block_mask * end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
@@ -104,7 +112,18 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
- qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+ if IS_CAUSAL:
+ qk += tl.where(
+ (start_n + offs_n[None, :] < cur_batch_seq_len)
+ & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+ 0,
+ float("-inf"),
+ )
+ else:
+ qk += tl.where(
+ (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+ )
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
@@ -146,7 +165,9 @@ def _fwd_kernel(
)
- def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+ def context_attention_fwd(
+ q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+ ):
if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
BLOCK = 128
else:
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK_M=BLOCK,
BLOCK_DMODEL=triton.next_power_of_2(Lk),
BLOCK_N=BLOCK,
IS_CAUSAL=is_causal,
num_warps=num_warps,
num_stages=1,
Lk=Lk,
...
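For orientation, a rough PyTorch reference (not part of the commit) of what the new IS_CAUSAL switch controls: with is_causal=True each query position may only attend to keys at or before it (the previous hard-coded behaviour), while is_causal=False, used by the Qwen2-VL vision encoder later in this diff, keeps attention bidirectional and only masks out-of-sequence keys. Toy sizes only:

import torch

def allowed_mask(seq_len: int, is_causal: bool) -> torch.Tensor:
    # True where a query (row) may attend to a key (column).
    full = torch.ones(seq_len, seq_len, dtype=torch.bool)
    return torch.tril(full) if is_causal else full

print(allowed_mask(4, is_causal=True))   # lower-triangular: token i sees keys 0..i
print(allowed_mask(4, is_causal=False))  # all True: every token sees every key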
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""MRotaryEmbedding"""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
class MRotaryEmbedding:
"""Rotary Embedding with Multimodal Sections."""
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
extend_prefix_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions += extend_prefix_len
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(
context_len + mrope_position_delta, seq_len + mrope_position_delta
)
)
for _ in range(3)
]
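A toy walk-through (not part of the commit) of get_input_positions, using made-up token ids rather than real Qwen2-VL vocabulary ids. With a single image whose grid_thw is (1, 4, 4) and spatial_merge_size=2, the image occupies 1 * (4/2) * (4/2) = 4 slots in the LLM sequence; text advances all three position rows together, while the image slots enumerate the merged (t, h, w) grid:

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

TEXT, VS, VE, IMG, VID = 11, 12, 13, 14, 15  # placeholder ids, not real vocab ids
input_tokens = [TEXT, TEXT, VS, IMG, IMG, IMG, IMG, VE, TEXT]

positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens=input_tokens,
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=[],
    image_token_id=IMG,
    video_token_id=VID,
    vision_start_token_id=VS,
    vision_end_token_id=VE,
    spatial_merge_size=2,
)
# Expected, per the logic above: rows are (temporal, height, width) position ids.
# positions == [[0, 1, 2, 3, 3, 3, 3, 5, 6],
#               [0, 1, 2, 3, 3, 4, 4, 5, 6],
#               [0, 1, 2, 3, 4, 3, 4, 5, 6]]
# delta == -2  (= max position + 1 - len(input_tokens)), carried into decoding.
print(positions, delta)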
@@ -177,10 +177,127 @@ class LlavaImageProcessor(BaseImageProcessor):
}
class Qwen2VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
self._image_processor = _image_processor
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_processor=None,
):
image_processor = image_processor or global_processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
process_result = image_processor(image)
pixel_values, image_grid_thws = (
process_result["pixel_values"],
process_result["image_grid_thw"][0],
)
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
image_grid_thws = np.stack(image_grid_thws, axis=0)
return pixel_values, image_hash, image_size, image_grid_thws
else:
# It is an image
image_hash = hash(image_data)
process_result = image_processor(image)
pixel_values, image_grid_thws = (
process_result["pixel_values"],
process_result["image_grid_thw"][0],
)
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size, image_grid_thws
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(self, image_data: Union[bytes, str]):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2VLImageProcessor._process_single_image_task,
image_data,
)
else:
return self._process_single_image_task(image_data)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
):
if not image_data:
return None
if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
pixel_values, image_hashes, image_sizes, image_grid_thws = (
[],
[],
[],
[],
)
res = []
for img_data in image_data:
res.append(self._process_single_image(img_data))
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s, image_thw in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
image_grid_thws.append(image_thw)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.concatenate(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size, image_grid_thw = (
await self._process_single_image(image_data[0])
)
image_hashes = [image_hash]
image_sizes = [image_size]
image_grid_thws = [image_grid_thw]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size, image_grid_thw = (
await self._process_single_image(image_data)
)
image_hashes = [image_hash]
image_sizes = [image_size]
image_grid_thws = [image_grid_thw]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
"image_grid_thws": image_grid_thws,
}
def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
- return LlavaImageProcessor(hf_config, server_args, _image_processor)
+ if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
+ return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
+ else:
+ return LlavaImageProcessor(hf_config, server_args, _image_processor)
def get_dummy_image_processor():
...
@@ -128,6 +128,8 @@ class ImageInputs:
image_embeds: Optional[List[torch.Tensor]] = None
aspect_ratio_ids: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# Qwen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
@staticmethod
def from_dict(obj, vocab_size):
@@ -135,6 +137,7 @@ class ImageInputs:
ret = ImageInputs(
pixel_values=obj["pixel_values"],
image_hash=hash(tuple(obj["image_hashes"])),
image_grid_thws=obj.get("image_grid_thws"),
)
image_hash = ret.image_hash
ret.pad_values = [
@@ -236,6 +239,9 @@ class Req:
self.regex_fsm_state: int = 0
self.jump_forward_map: JumpForwardMap = None
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object
# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -854,6 +860,8 @@ class ScheduleBatch:
global bid
bid += 1
mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
@@ -869,6 +877,7 @@ class ScheduleBatch:
image_inputs=image_inputs,
lora_paths=lora_paths,
sampling_info=self.sampling_info,
mrope_positions_delta=mrope_positions_delta,
)
def copy(self):
@@ -920,6 +929,9 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
# For Qwen2-VL
mrope_positions_delta: List[List[int]]
def copy(self):
return ModelWorkerBatch(
bid=self.bid,
@@ -936,4 +948,5 @@ class ModelWorkerBatch:
image_inputs=self.image_inputs,
lora_paths=self.lora_paths,
sampling_info=self.sampling_info.copy(),
mrope_positions_delta=self.mrope_positions_delta,
)
@@ -36,6 +36,8 @@ from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
@@ -112,14 +114,88 @@ class ForwardBatch:
token_to_kv_pool: BaseTokenToKVPool = None
attn_backend: AttentionBackend = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
device = model_runner.device
hf_config = model_runner.model_config.hf_config
mrope_positions_list = [None] * self.seq_lens.shape[0]
if self.forward_mode.is_decode():
for i, _ in enumerate(mrope_positions_list):
mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
batch.mrope_positions_delta[i][0],
int(self.seq_lens[i]) - 1,
int(self.seq_lens[i]),
)
elif self.forward_mode.is_extend():
for i, image_inputs in enumerate(batch.image_inputs):
if image_inputs is None:
# text only
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
mrope_position_delta = 0
else:
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[
extend_start_loc : extend_start_loc + extend_seq_len
].tolist(),
image_grid_thw=image_inputs.image_grid_thws,
video_grid_thw=None,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.spatial_merge_size,
context_len=0,
extend_prefix_len=extend_prefix_len.item(),
)
)
mrope_positions_list[i] = mrope_positions
batch.mrope_positions_delta[i].append(mrope_position_delta)
self.mrope_positions = torch.tensor(
np.concatenate(
[np.array(pos) for pos in mrope_positions_list],
axis=1,
),
device=device,
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
device = model_runner.device
if self.forward_mode.is_decode():
self.positions = (self.seq_lens - 1).to(torch.int64)
else:
self.positions = torch.tensor(
np.concatenate(
[
np.arange(prefix_len, prefix_len + extend_len)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
),
device=device,
).to(torch.int64)
@classmethod
def init_new(
cls,
batch: ModelWorkerBatch,
model_runner: ModelRunner,
):
device = model_runner.device
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
@@ -156,6 +232,13 @@ class ForwardBatch:
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
# Init position information
is_mrope = model_runner.model_is_mrope
if is_mrope:
ret.compute_mrope_positions(model_runner, batch)
else:
ret.compute_positions(model_runner, batch)
# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool
ret.token_to_kv_pool = model_runner.token_to_kv_pool
...
@@ -125,6 +125,11 @@ class ModelRunner:
)
server_args.chunked_prefill_size = None
server_args.mem_fraction_static *= 0.95
# TODO: Qwen2-VL does not support CUDA graph yet, so disable_cuda_graph is set to True automatically
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
]:
server_args.disable_cuda_graph = True
# Global vars
if server_args.show_time_cost:
@@ -622,6 +627,15 @@ class ModelRunner:
return logits
@property
def model_is_mrope(self) -> bool:
"""Detect if the model has "mrope" rope_scaling type.
mrope requires keep "rope_deltas" between prompt and decoding phases."""
rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
if rope_scaling is None:
return False
return rope_scaling.get("type", None) == "mrope"
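For context, a hedged sketch (not part of the commit) of what this property keys on. Hugging Face Qwen2-VL checkpoints ship a rope_scaling block with type "mrope" (the mrope_section values below are illustrative); stock Transformers would rewrite that type to "default", which is exactly what the commented-out snippet in sglang/srt/configs/qwen2vl.py avoids so that this detection keeps working:

from types import SimpleNamespace

def model_is_mrope(hf_config) -> bool:
    # Mirrors the property above, with a plain object standing in for hf_config.
    rope_scaling = getattr(hf_config, "rope_scaling", {})
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"

qwen2_vl_like = SimpleNamespace(rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
plain_llm = SimpleNamespace(rope_scaling=None)
print(model_is_mrope(qwen2_vl_like))  # True  -> ForwardBatch.compute_mrope_positions path
print(model_is_mrope(plain_llm))      # False -> regular compute_positions path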
@lru_cache()
def import_model_classes():
...
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, Type, TypedDict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2Model
logger = init_logger(__name__)
# === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
image_grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
class Qwen2VLVideoInputs(TypedDict):
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
# === Vision Encoder === #
class Qwen2VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int = None,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.fc1 = ColumnParallelLinear(
in_features, hidden_features, quant_config=quant_config
)
self.act = act_layer()
self.fc2 = RowParallelLinear(
hidden_features, in_features, quant_config=quant_config
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel, _ = self.fc1(x)
x_parallel = self.act(x_parallel)
x, _ = self.fc2(x_parallel)
return x
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
embed_dim: Optional[int] = None,
num_heads: Optional[int] = None,
projection_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size
)
self.qkv = ColumnParallelLinear(
input_size=embed_dim,
output_size=3 * projection_size,
quant_config=quant_config,
)
self.proj = RowParallelLinear(
input_size=projection_size, output_size=embed_dim, quant_config=quant_config
)
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
new_x_shape = x.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
x = x.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
batch_size = q.shape[1]
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = (seq_lens).max().item()
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
output = torch.empty_like(q)
context_attention_fwd(
q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
)
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
output, _ = self.proj(context_layer)
return output
class Qwen2VisionBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.attn = Qwen2VisionAttention(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
)
self.mlp = Qwen2VisionMLP(
dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
)
x = x + self.mlp(self.norm2(x))
return x
class Qwen2VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_chans: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.embed_dim)
return x
class Qwen2VisionPatchMerger(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
norm_layer: Type[nn.Module] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.ln_q = norm_layer(context_dim)
self.mlp = nn.ModuleList(
[
ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
),
nn.GELU(),
RowParallelLinear(
self.hidden_size, d_model, bias=True, quant_config=quant_config
),
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x)
x = x.view(-1, self.hidden_size)
mlp_fc1, mlp_act, mlp_fc2 = self.mlp
x_parallel, _ = mlp_fc1(x)
x_parallel = mlp_act(x_parallel)
out, _ = mlp_fc2(x_parallel)
return out
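A quick shape check (not part of the commit), using the defaults from the Qwen2VLVisionConfig earlier in this diff (embed_dim=1280, spatial_merge_size=2, hidden_size=3584): the merger flattens every 2x2 group of patch features into one vector before projecting it to the size the language model consumes:

context_dim = 1280                   # vision embed_dim
spatial_merge_size = 2
d_model = 3584                       # vision_config.hidden_size, fed to the LLM
merger_hidden = context_dim * spatial_merge_size**2
print(merger_hidden)                 # 5120: width after x.view(-1, self.hidden_size)
# mlp: Linear(5120 -> 5120) -> GELU -> Linear(5120 -> 3584),
# so every 4 patches (a 2x2 spatial group) become one 3584-dim visual token.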
class Qwen2VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._seq_len_cached = 0
self._freqs_cached = None
def update_freqs_cache(self, seqlen: int) -> None:
if seqlen > self._seq_len_cached:
seqlen *= 2
self._seq_len_cached = seqlen
self.inv_freq = 1.0 / (
self.theta
** (
torch.arange(
0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
)
/ self.dim
)
)
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
self._freqs_cached = freqs
def forward(self, seqlen: int) -> torch.Tensor:
self.update_freqs_cache(seqlen)
return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
def __init__(
self,
vision_config: Qwen2VLVisionConfig,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
patch_size: int = vision_config.patch_size
temporal_patch_size: int = vision_config.temporal_patch_size
spatial_merge_size: int = vision_config.spatial_merge_size
in_chans: int = vision_config.in_chans
hidden_size: int = vision_config.hidden_size
embed_dim: int = vision_config.embed_dim
depth: int = vision_config.depth
num_heads: int = vision_config.num_heads
mlp_ratio: float = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Qwen2VisionBlock(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
)
for _ in range(depth)
]
)
self.merger = Qwen2VisionPatchMerger(
d_model=hidden_size,
context_dim=embed_dim,
norm_layer=norm_layer,
quant_config=quant_config,
)
@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
@property
def device(self) -> torch.device:
return self.blocks[0].mlp.fc2.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = (
hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
.permute(0, 2, 1, 3)
.flatten()
)
wpos_ids = (
wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
.permute(0, 2, 1, 3)
.flatten()
)
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
# patchify
x = x.to(device=self.device, dtype=self.dtype)
x = self.patch_embed(x)
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
# adapter
x = self.merger(x)
return x
cached_get_processor = lru_cache(get_processor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
num_image_tokens = (
grid_t
* grid_h
* grid_w
// processor.image_processor.merge_size
// processor.image_processor.merge_size
)
return num_image_tokens
# Use grid_t * grid_h * grid_w to pad tokens for each image
# and replace the padding with pad values derived from each image's hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == self.config.image_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[image_cnt]
)
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
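A worked example (not part of the commit) of the token count that pad_input_ids splices in. Assuming an image_grid_thw of (1, 16, 16) and a processor merge_size of 2 (matching spatial_merge_size elsewhere in this diff), the single <|image_pad|> position in the prompt expands to 64 pad ids:

grid_t, grid_h, grid_w = 1, 16, 16   # illustrative image_grid_thw entry
merge_size = 2                       # processor.image_processor.merge_size
num_image_tokens = grid_t * grid_h * grid_w // merge_size // merge_size
print(num_image_tokens)              # 64 pad ids replace one image placeholder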
def __init__(
self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.multimodal_config = multimodal_config
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
# NOTE: Qwen2-VL vision encoder does not support any
# quantization method now.
quant_config=None,
)
self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
video_embeds = self.visual(
pixel_values_videos, grid_thw=video_input["video_grid_thw"]
)
return video_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
):
"""Run forward pass for Qwen2-VL.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
positions: Flattened (concatenated) position ids corresponding to a
batch.
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,)`.
(Use forward_batch.mrope_positions to replace it)
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
"""
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
positions = forward_batch.mrope_positions
if image_inputs is None or len(image_inputs) == 0:
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = torch.tensor(image.pixel_values, device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = (
start_idx + (image_offset - prefix_len) + num_image_tokens
)
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(
3, visual_num_heads, head_size, visual_embed_dim
)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
raise
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = Qwen2VLForConditionalGeneration
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
or "Qwen2VLForConditionalGeneration" in model_architectures
):
return True
else:
...
@@ -344,5 +344,24 @@ class TestOpenAIVisionServer(unittest.TestCase):
list(executor.map(self.run_decode_with_image, image_ids))
class TestQWen2VLServer(TestOpenAIVisionServer):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--chat-template",
"qwen2-vl",
],
)
cls.base_url += "/v1"
if __name__ == "__main__":
unittest.main()