Unverified commit cbbc82b7 authored by Yineng Zhang, committed by GitHub

Support qwen2 vl model (#1721)


Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: ispobock <ISPObaoke@163.com>
parent 8bee20f8
@@ -73,7 +73,7 @@ jobs:
           pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 5 --range-end 17
@@ -93,7 +93,7 @@ jobs:
           pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall
       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 17
...
@@ -133,6 +133,22 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_chat_template(
+    ChatTemplate(
+        name="qwen2-vl",
+        default_system_prompt="You are a helpful assistant.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+
 register_chat_template(
     ChatTemplate(
...
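For orientation, a prompt rendered with this template for a single-image user turn would look roughly like the following (illustrative only; the exact rendering is produced by sglang's chat-template machinery, and the image placeholder is expanded by the Qwen2-VL processor):

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
<|im_start|>assistant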
 from sglang.srt.configs.exaone import ExaoneConfig
+from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig
 
 __all__ = [
     "ExaoneConfig",
+    "Qwen2VLConfig",
+    "Qwen2VLVisionConfig",
 ]
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""

import os
from typing import Union

from transformers import PretrainedConfig


class Qwen2VLVisionConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        depth=32,
        embed_dim=1280,
        hidden_size=3584,
        hidden_act="quick_gelu",
        mlp_ratio=4,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.depth = depth
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.temporal_patch_size = temporal_patch_size

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "qwen2_vl":
            config_dict = config_dict["vision_config"]

        return cls.from_dict(config_dict, **kwargs)


class Qwen2VLConfig(PretrainedConfig):
    model_type = "qwen2_vl"

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        vision_config=None,
        rope_scaling=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = Qwen2VLVisionConfig(**vision_config)
        elif vision_config is None:
            self.vision_config = Qwen2VLVisionConfig()

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling

        # NOTE: the following section from original transformers config
        # for Qwen2-VL is commented out to address rope config loading issue
        #
        # if self.rope_scaling is not None and "type" in self.rope_scaling:
        #     if self.rope_scaling["type"] == "mrope":
        #         self.rope_scaling["type"] = "default"
        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        # rope_config_validation(self)

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
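A minimal sketch of how this config behaves when constructed from a nested dict, and why the transformers rope-rewriting block above is kept commented out (the mrope_section values shown are taken from the published Qwen2-VL checkpoints but should be treated as illustrative here):

# Assumed usage sketch, not part of this patch.
cfg = Qwen2VLConfig(
    vision_config={"depth": 32, "embed_dim": 1280, "spatial_merge_size": 2},
    rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]},  # illustrative values
)
assert isinstance(cfg.vision_config, Qwen2VLVisionConfig)
# Because the rewrite to "default" is skipped, downstream code
# (ModelRunner.model_is_mrope below) can still detect the multimodal rope type:
assert cfg.rope_scaling["type"] == "mrope"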
@@ -530,3 +530,17 @@ register_conv_template(
         stop_str=["<|im_end|>", "<|action_end|>"],
     )
 )
+
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+register_conv_template(
+    Conversation(
+        name="qwen2-vl",
+        system_message="You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
@@ -33,12 +33,13 @@ from transformers import (
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
-    from sglang.srt.configs import ExaoneConfig
+    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig
 
     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
         ChatGLMConfig.model_type: ChatGLMConfig,
         DbrxConfig.model_type: DbrxConfig,
         ExaoneConfig.model_type: ExaoneConfig,
+        Qwen2VLConfig.model_type: Qwen2VLConfig,
     }
 except ImportError:
     # We want this file to run without vllm dependency
...
@@ -50,6 +50,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
@@ -78,7 +79,9 @@ def _fwd_kernel(
     mask_d = offs_d < Lk
 
     q = tl.load(
-        Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+        Q + off_q,
+        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+        other=0.0,
     )
 
     k_ptrs = K + off_k
@@ -91,7 +94,12 @@ def _fwd_kernel(
     block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
 
-    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+    end_n = (
+        cur_batch_seq_len
+        if not IS_CAUSAL
+        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+    )
+    for start_n in range(0, block_mask * end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
@@ -104,7 +112,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
-        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+        if IS_CAUSAL:
+            qk += tl.where(
+                (start_n + offs_n[None, :] < cur_batch_seq_len)
+                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+                0,
+                float("-inf"),
+            )
+        else:
+            qk += tl.where(
+                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+            )
 
         # -- compute m_ij, p, l_ij
         m_ij = tl.max(qk, 1)
@@ -146,7 +165,9 @@ def _fwd_kernel(
     )
 
-def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
+def context_attention_fwd(
+    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+):
     if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
+        IS_CAUSAL=is_causal,
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
...
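The new IS_CAUSAL flag lets the same Triton prefill kernel serve both the causal language-model pass and a non-causal pass (context_attention_fwd(..., is_causal=False)), presumably for the Qwen2-VL vision encoder's bidirectional attention. As a rough reference, the masking it toggles corresponds to the following plain-PyTorch sketch (illustration only, not the kernel code path):

import torch

def masked_scores(q, k, sm_scale, is_causal=True):
    # q: [M, D], k: [N, D] for a single sequence; returns pre-softmax scores.
    qk = (q @ k.transpose(0, 1)) * sm_scale
    if is_causal:
        m = torch.arange(q.shape[0])[:, None]
        n = torch.arange(k.shape[0])[None, :]
        qk = qk.masked_fill(m < n, float("-inf"))  # queries cannot attend to later keys
    return qk  # is_causal=False keeps the full (bidirectional) score matrix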
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""MRotaryEmbedding"""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
class MRotaryEmbedding:
"""Rotary Embedding with Multimodal Sections."""
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
extend_prefix_len: int = 0,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:]
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions += extend_prefix_len
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(
context_len + mrope_position_delta, seq_len + mrope_position_delta
)
)
for _ in range(3)
]
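A worked example of get_input_positions (the token ids below are made up; only their roles matter). One image with patch grid (t=1, h=4, w=4) and spatial_merge_size=2 occupies 1*(4/2)*(4/2) = 4 placeholder tokens:

positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens=[7, 5, 3, 3, 3, 3, 6, 8],  # [text, <vision_start>, img*4, <vision_end>, text]
    image_grid_thw=[[1, 4, 4]],
    video_grid_thw=None,
    image_token_id=3,
    video_token_id=4,
    vision_start_token_id=5,
    vision_end_token_id=6,
    spatial_merge_size=2,
)
# positions[0] (temporal): [0, 1, 2, 2, 2, 2, 4, 5]
# positions[1] (height):   [0, 1, 2, 2, 3, 3, 4, 5]
# positions[2] (width):    [0, 1, 2, 3, 2, 3, 4, 5]
# delta == -2: decoding continues from max(position) + 1 rather than from seq_len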
@@ -177,10 +177,127 @@ class LlavaImageProcessor(BaseImageProcessor):
         }
 
 
+class Qwen2VLImageProcessor(BaseImageProcessor):
+    def __init__(self, hf_config, server_args, _image_processor):
+        self.hf_config = hf_config
+        self._image_processor = _image_processor
+        self.executor = concurrent.futures.ProcessPoolExecutor(
+            initializer=init_global_processor,
+            mp_context=mp.get_context("fork"),
+            initargs=(server_args,),
+            max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
+        )
+
+    @staticmethod
+    def _process_single_image_task(
+        image_data: Union[str, bytes],
+        image_processor=None,
+    ):
+        image_processor = image_processor or global_processor.image_processor
+
+        try:
+            image, image_size = load_image(image_data)
+            if image_size is not None:
+                # It is a video with multiple images
+                image_hash = hash(image_data)
+                process_result = image_processor(image)
+                pixel_values, image_grid_thws = (
+                    process_result["pixel_values"],
+                    process_result["image_grid_thw"][0],
+                )
+                for _ in range(len(pixel_values)):
+                    pixel_values[_] = pixel_values[_].astype(np.float16)
+                pixel_values = np.stack(pixel_values, axis=0)
+                image_grid_thws = np.stack(image_grid_thws, axis=0)
+                return pixel_values, image_hash, image_size, image_grid_thws
+            else:
+                # It is an image
+                image_hash = hash(image_data)
+                process_result = image_processor(image)
+                pixel_values, image_grid_thws = (
+                    process_result["pixel_values"],
+                    process_result["image_grid_thw"][0],
+                )
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = pixel_values.astype(np.float16)
+
+                return pixel_values, image_hash, image.size, image_grid_thws
+        except Exception:
+            logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
+
+    async def _process_single_image(self, image_data: Union[bytes, str]):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(
+                self.executor,
+                Qwen2VLImageProcessor._process_single_image_task,
+                image_data,
+            )
+        else:
+            return self._process_single_image_task(image_data)
+
+    async def process_images_async(
+        self, image_data: List[Union[str, bytes]], request_obj
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(image_data, list) and len(image_data) > 0:
+            # Multiple images
+            if len(image_data) > 1:
+                pixel_values, image_hashes, image_sizes, image_grid_thws = (
+                    [],
+                    [],
+                    [],
+                    [],
+                )
+                res = []
+                for img_data in image_data:
+                    res.append(self._process_single_image(img_data))
+                res = await asyncio.gather(*res)
+                for pixel_v, image_h, image_s, image_thw in res:
+                    pixel_values.append(pixel_v)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+                    image_grid_thws.append(image_thw)
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.concatenate(pixel_values, axis=0)
+            else:
+                # A single image
+                pixel_values, image_hash, image_size, image_grid_thw = (
+                    await self._process_single_image(image_data[0])
+                )
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
+                image_grid_thws = [image_grid_thw]
+        elif isinstance(image_data, str):
+            # A single image
+            pixel_values, image_hash, image_size, image_grid_thw = (
+                await self._process_single_image(image_data)
+            )
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
+            image_grid_thws = [image_grid_thw]
+        else:
+            raise ValueError(f"Invalid image data: {image_data}")
+
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": request_obj.modalities,
+            "image_grid_thws": image_grid_thws,
+        }
+
+
 def get_image_processor(
     hf_config, server_args: ServerArgs, _image_processor
 ) -> BaseImageProcessor:
-    return LlavaImageProcessor(hf_config, server_args, _image_processor)
+    if "Qwen2VLForConditionalGeneration" in hf_config.architectures:
+        return Qwen2VLImageProcessor(hf_config, server_args, _image_processor)
+    else:
+        return LlavaImageProcessor(hf_config, server_args, _image_processor)
 
 
 def get_dummy_image_processor():
...
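For reference, the dict returned by process_images_async for a single image has roughly the following shape (all concrete numbers below are hypothetical; the real shapes come from the Hugging Face Qwen2-VL image processor, where each patch row holds in_channels * temporal_patch_size * patch_size**2 = 3*2*14*14 = 1176 values):

import numpy as np

# Hypothetical illustration only, not produced by real inputs.
example_output = {
    "pixel_values": np.zeros((1036, 1176), dtype=np.float16),  # one row per vision patch
    "image_hashes": [8234567890123456789],                     # hash(image_data)
    "image_sizes": [(1024, 768)],                               # original (width, height)
    "image_grid_thws": [np.array([1, 28, 37])],                 # (t, h, w) patch grid; 1*28*37 = 1036
    "modalities": ["image"],
}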
@@ -128,6 +128,8 @@ class ImageInputs:
     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # QWen2-VL related
+    image_grid_thws: List[Tuple[int, int, int]] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -135,6 +137,7 @@ class ImageInputs:
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
             image_hash=hash(tuple(obj["image_hashes"])),
+            image_grid_thws=obj.get("image_grid_thws"),
         )
         image_hash = ret.image_hash
         ret.pad_values = [
@@ -236,6 +239,9 @@ class Req:
         self.regex_fsm_state: int = 0
         self.jump_forward_map: JumpForwardMap = None
 
+        # For Qwen2-VL
+        self.mrope_position_delta = []  # use mutable object
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -854,6 +860,8 @@ class ScheduleBatch:
         global bid
         bid += 1
+
+        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -869,6 +877,7 @@ class ScheduleBatch:
             image_inputs=image_inputs,
             lora_paths=lora_paths,
             sampling_info=self.sampling_info,
+            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -920,6 +929,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # For Qwen2-VL
+    mrope_positions_delta: List[List[int]]
+
     def copy(self):
         return ModelWorkerBatch(
             bid=self.bid,
@@ -936,4 +948,5 @@ class ModelWorkerBatch:
             image_inputs=self.image_inputs,
             lora_paths=self.lora_paths,
             sampling_info=self.sampling_info.copy(),
+            mrope_positions_delta=self.mrope_positions_delta,
         )
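The "use mutable object" comment is the key detail here: the same list object is shared by the Req and every ModelWorkerBatch built from it, so the delta written during prefill is visible to later decode batches without an explicit copy back. A toy illustration of that sharing (names are only stand-ins for the fields above):

req_delta = []                # Req.mrope_position_delta
batch_delta = [req_delta]     # mrope_positions_delta = [req.mrope_position_delta for req in reqs]
batch_delta[0].append(-2)     # prefill: batch.mrope_positions_delta[i].append(mrope_position_delta)
assert req_delta == [-2]      # a decode batch built later sees the stored delta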
@@ -36,6 +36,8 @@ from typing import TYPE_CHECKING, List, Optional
 import numpy as np
 import torch
 
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
@@ -112,14 +114,88 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
+    def compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    batch.mrope_positions_delta[i][0],
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            for i, image_inputs in enumerate(batch.image_inputs):
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3
+                    mrope_position_delta = 0
+                else:
+                    extend_start_loc, extend_seq_len, extend_prefix_len = (
+                        self.extend_start_loc[i],
+                        self.extend_seq_lens[i],
+                        self.extend_prefix_lens[i],
+                    )
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ].tolist(),
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            video_grid_thw=None,
+                            image_token_id=hf_config.image_token_id,
+                            video_token_id=hf_config.video_token_id,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            vision_end_token_id=hf_config.vision_end_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                            extend_prefix_len=extend_prefix_len.item(),
+                        )
+                    )
+                mrope_positions_list[i] = mrope_positions
+                batch.mrope_positions_delta[i].append(mrope_position_delta)
+
+        self.mrope_positions = torch.tensor(
+            np.concatenate(
+                [np.array(pos) for pos in mrope_positions_list],
+                axis=1,
+            ),
+            device=device,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+
+    def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
+        device = model_runner.device
+        if self.forward_mode.is_decode():
+            self.positions = (self.seq_lens - 1).to(torch.int64)
+        else:
+            self.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
+                    ],
+                    axis=0,
+                ),
+                device=device,
+            ).to(torch.int64)
+
     @classmethod
     def init_new(
         cls,
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
         device = model_runner.device
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -156,6 +232,13 @@ class ForwardBatch:
         ret.extend_seq_lens_cpu = batch.extend_seq_lens
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
+        # Init position information
+        is_mrope = model_runner.model_is_mrope
+        if is_mrope:
+            ret.compute_mrope_positions(model_runner, batch)
+        else:
+            ret.compute_positions(model_runner, batch)
+
         # Init attention information
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
...
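At decode time the stored delta is all that is needed: continuing the worked example above (an 8-token prompt whose delta was -2), the position of the 9th token comes straight from get_next_input_positions, matching the (delta, seq_len - 1, seq_len) call in the decode branch (numbers illustrative):

MRotaryEmbedding.get_next_input_positions(
    mrope_position_delta=-2, context_len=8, seq_len=9
)
# -> [[6], [6], [6]]  (range(8 - 2, 9 - 2) on each of the t/h/w sections)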
@@ -125,6 +125,11 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
+            # TODO: qwen2-vl does not support CUDA graph yet; set disable_cuda_graph=True automatically
+            if self.model_config.hf_config.architectures == [
+                "Qwen2VLForConditionalGeneration"
+            ]:
+                server_args.disable_cuda_graph = True
 
         # Global vars
         if server_args.show_time_cost:
@@ -622,6 +627,15 @@ class ModelRunner:
         return logits
 
+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect whether the model uses the "mrope" rope_scaling type.
+        mrope requires keeping "rope_deltas" between the prefill and decode phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+
 
 @lru_cache()
 def import_model_classes():
...
This diff is collapsed.
@@ -209,6 +209,7 @@ def is_multimodal_model(model_architectures):
         or "LlavaQwenForCausalLM" in model_architectures
         or "LlavaMistralForCausalLM" in model_architectures
         or "LlavaVidForCausalLM" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
     ):
         return True
     else:
...
@@ -344,5 +344,24 @@ class TestOpenAIVisionServer(unittest.TestCase):
             list(executor.map(self.run_decode_with_image, image_ids))
 
 
+class TestQWen2VLServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "Qwen/Qwen2-VL-7B-Instruct"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--chat-template",
+                "qwen2-vl",
+            ],
+        )
+        cls.base_url += "/v1"
+
+
 if __name__ == "__main__":
     unittest.main()
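A minimal client-side sketch of what the inherited vision tests exercise against this server (the port and image URL are placeholders; the API key and model name follow the test setup above):

import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")
response = client.chat.completions.create(
    model="Qwen/Qwen2-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image in one sentence."},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)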