Unverified Commit 3d93f84a authored by Mick, committed by GitHub

[Feature] Support minicpmv v2.6 (#2785)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
parent c2f212d6
@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
- MiniCPM / MiniCPM 3 / MiniCPMV
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
...
@@ -88,7 +88,6 @@ register_chat_template(
    )
)
register_chat_template(
    ChatTemplate(
        name="claude",
@@ -101,7 +100,6 @@ register_chat_template(
    )
)
register_chat_template(
    ChatTemplate(
        name="chatml",
@@ -116,7 +114,6 @@ register_chat_template(
    )
)
register_chat_template(
    ChatTemplate(
        name="chatml-llava",
@@ -132,7 +129,6 @@ register_chat_template(
    )
)
# There is default system prompt for qwen
# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ register_chat_template(
    )
)
# https://huggingface.co/openbmb/MiniCPM-V-2_6
register_chat_template(
    ChatTemplate(
        name="minicpmv",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("user:", " "),
            "assistant": ("assistant:", "</s>"),
        },
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)

# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(
...
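For reference, the new "minicpmv" template can be exercised end to end by launching the server the same way the test added at the bottom of this commit does. A minimal sketch (a GPU, downloadable weights, and an arbitrary port are assumed):

python3 -m sglang.launch_server --model-path openbmb/MiniCPM-V-2_6 \
    --chat-template minicpmv --trust-remote-code --port 30000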
@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
        or "LlavaVidForCausalLM" in model_architectures
        or "MllamaForConditionalGeneration" in model_architectures
        or "Qwen2VLForConditionalGeneration" in model_architectures
        or "MiniCPMV" in model_architectures
    ):
        return True
    else:
...
@@ -452,7 +452,6 @@ def generate_chat_conv(
    # Add a blank message for the assistant.
    conv.append_message(conv.roles[1], None)
    return conv
@@ -555,3 +554,17 @@ register_conv_template(
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
register_conv_template(
    Conversation(
        name="minicpmv",
        system_message="You are a helpful assistant",
        system_template="<|im_start|>system\n{system_message}.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=("<|im_end|>", "<|endoftext|>"),
        image_token="(<image>./</image>)",
    )
)
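Rendered through the ADD_NEW_LINE_SINGLE separator style, a single-turn request under this template comes out roughly as below (illustrative only; the user text and the image-token placement are made up):

<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
(<image>./</image>)What is in this picture?<|im_end|>
<|im_start|>assistant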
from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange, repeat
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils

from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.quantization import QuantizationConfig


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
    return output


class VisionAttention(nn.Module):
    """Multi-headed attention without any cache, mostly used for ViT."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
        # self.tp_size = get_tensor_model_parallel_world_size()
        # num_heads = self.num_heads_per_partition
        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
            self.head_dim = embed_dim // num_heads
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
                head_size=self.head_dim,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        else:
            self.qkv_proj = ColumnParallelLinear(
                input_size=embed_dim,
                output_size=3 * projection_size,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
            )
        self.proj = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Input shape: [b, s, embed_dim]
        Output shape: [b, s, num_heads * head_size]
        """
        bsz, s, _ = x.shape
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
            qkv, _ = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)

            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
            q, k, v = [
                x.reshape(
                    bsz * s, self.num_attention_heads_per_partition, -1
                ).contiguous()
                for x in (q, k, v)
            ]
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
            qkv, _ = self.qkv_proj(x)
            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)
            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
            # [s, b, head, head_dim] --> [b, s, head, head_dim]
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
            ]

        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.use_qkv_parallel:
            pass
        else:
            # [b, s, head, head_dim] --> [b * s, head, head_dim]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

        # [b * s, num_heads, head_size]
        output = torch.empty_like(q)

        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
        max_seqlen = seq_lens.max().item()

        context_attention_fwd(
            q,
            k,
            v,
            output,
            cu_seqlens.cuda(),
            seq_lens,
            max_seqlen,
            is_causal=False,
        )

        if self.use_qkv_parallel:
            # [b * s, head, head_dim] --> [b, s, head * head_dim]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

            # [b, s, head * head_dim] --> [b, s, head * head_dim]
            output, _ = self.proj(output)
        else:
            # [b * s, head, head_dim] --> [b, s, head, head_dim]
            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)

            # [b, s, head, head_dim] --> [s, b, num_heads * head_size]
            context_layer = rearrange(
                context_layer, "b s h d -> s b (h d)"
            ).contiguous()

            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
            output, _ = self.proj(context_layer)

            output = output.view(bsz, s, -1)

        return output
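The rotary helpers above are plain PyTorch and can be sanity-checked in isolation. A minimal sketch, assuming sglang and its vLLM/Triton dependencies are importable and using arbitrary sizes:

import torch

from sglang.srt.layers.attention.vision import apply_rotary_emb_torch

b, s, h, d = 2, 16, 4, 64                  # batch, sequence, heads, head_dim
theta = torch.rand(s, d // 2)              # one angle per (position, rotary pair)
cos, sin = theta.cos(), theta.sin()

x = torch.randn(b, s, h, d)
out = apply_rotary_emb_torch(x, cos, sin)  # rotates the first 2 * (d // 2) channels of each head
assert out.shape == x.shape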
@@ -127,7 +127,7 @@ class LogitsProcessor(nn.Module):
        hidden_states,
        lm_head: VocabParallelEmbedding,
        logits_metadata: Union[LogitsMetadata, ForwardBatch],
    ) -> LogitsProcessorOutput:
        if isinstance(logits_metadata, ForwardBatch):
            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
...
@@ -56,6 +56,7 @@ class DataParallelController:
    def __init__(self, server_args, port_args) -> None:
        # Parse args
        self.max_total_num_tokens = None
        self.server_args = server_args
        self.port_args = port_args
        self.load_balance_method = LoadBalanceMethod.from_str(
@@ -96,6 +97,8 @@ class DataParallelController:
            True,
        )

        self.max_req_input_len = None

    def launch_dp_schedulers(self, server_args, port_args):
        base_gpu_id = 0
@@ -189,6 +192,7 @@ class DataParallelController:
            scheduler_info.append(scheduler_pipe_readers[i].recv())

        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
        self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
@@ -231,7 +235,11 @@ def run_data_parallel_controller_process(
    try:
        controller = DataParallelController(server_args, port_args)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": controller.max_total_num_tokens,
                "max_req_input_len": controller.max_req_input_len,
            }
        )
        if server_args.node_rank == 0:
            controller.event_loop()
...
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
import numpy as np
import transformers
from decord import VideoReader, cpu
from PIL import Image

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -36,6 +38,7 @@ class BaseImageProcessor(ABC):
    def __init__(self, hf_config, server_args, _processor):
        self.hf_config = hf_config
        self._processor = _processor
        self.server_args = server_args

        self.executor = concurrent.futures.ProcessPoolExecutor(
            initializer=init_global_processor,
@@ -126,7 +129,12 @@ class LlavaImageProcessor(BaseImageProcessor):
        )

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
@@ -229,6 +237,147 @@ class MllamaImageProcessor(BaseImageProcessor):
        return image_inputs

class MiniCPMVImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_images_task(images, input_text):
        result = global_processor.__call__(
            text=input_text, images=images, return_tensors="pt"
        )
        return {
            "input_ids": result["input_ids"],
            "pixel_values": result["pixel_values"],
            "tgt_sizes": result["tgt_sizes"],
        }

    async def _process_images(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                MiniCPMVImageProcessor._process_images_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=input_text, return_tensors="pt"
            )
        return image_inputs

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
    ):
        if not image_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        image_hashes, image_sizes = [], []
        raw_images = []
        IMAGE_TOKEN = "(<image>./</image>)"

        # roughly calculate the max number of frames
        # TODO: the process should be applied to all the visual inputs
        def calculate_max_num_frames() -> int:
            # Model-specific
            NUM_TOKEN_PER_FRAME = 330

            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
            return min(ret, 100)

        # if cuda OOM set a smaller number
        MAX_NUM_FRAMES = calculate_max_num_frames()
        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")

        def encode_video(video_path):
            if not os.path.exists(video_path):
                logger.error(f"Video {video_path} does not exist")
                return []

            if MAX_NUM_FRAMES == 0:
                return []

            def uniform_sample(l, n):
                gap = len(l) / n
                idxs = [int(i * gap + gap / 2) for i in range(n)]
                return [l[i] for i in idxs]

            vr = VideoReader(video_path, ctx=cpu(0))
            sample_fps = round(vr.get_avg_fps() / 1)  # FPS
            frame_idx = [i for i in range(0, len(vr), sample_fps)]
            if len(frame_idx) > MAX_NUM_FRAMES:
                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
            frames = vr.get_batch(frame_idx).asnumpy()
            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
            return frames

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        # MiniCPMV requires each frame of video as a single image token
        text_parts = input_text.split(IMAGE_TOKEN)
        new_text_parts = []

        for image_index, image in enumerate(image_data):
            try:
                if isinstance(image, str) and image.startswith("video:"):
                    path = image[len("video:") :]
                    frames = encode_video(path)
                else:
                    raw_image, size = load_image(image)
                    frames = [raw_image]
                if len(frames) == 0:
                    continue
            except FileNotFoundError as e:
                print(e)
                return None

            image_sizes += frames[0].size * len(frames)
            image_hashes += [hash(image)] * len(frames)
            raw_images += frames
            new_text_parts.append(text_parts[image_index])
            new_text_parts.append(IMAGE_TOKEN * len(frames))

        new_text_parts.append(text_parts[-1])
        input_text = "".join(new_text_parts)
        if len(raw_images) == 0:
            return None
        res = await self._process_images(images=raw_images, input_text=input_text)
        pixel_values = res["pixel_values"]
        tgt_sizes = res["tgt_sizes"]
        input_ids = res["input_ids"]

        # Collect special token ids
        tokenizer = self._processor.tokenizer
        im_start_id = [tokenizer.im_start_id]
        im_end_id = [tokenizer.im_end_id]
        # Guard against tokenizers without slice tokens so these names are
        # always defined before they are returned.
        slice_start_id, slice_end_id = None, None
        if tokenizer.slice_start_id:
            slice_start_id = [tokenizer.slice_start_id]
            slice_end_id = [tokenizer.slice_end_id]
        return {
            "input_ids": input_ids.flatten().tolist(),
            "pixel_values": pixel_values,
            "tgt_sizes": tgt_sizes,
            "image_hashes": image_hashes,
            "modalities": request_obj.modalities or ["image"],
            "im_start_id": im_start_id,
            "im_end_id": im_end_id,
            "slice_start_id": slice_start_id,
            "slice_end_id": slice_end_id,
        }


class Qwen2VLImageProcessor(BaseImageProcessor):
    def __init__(self, hf_config, server_args, _image_processor):
        self.hf_config = hf_config
@@ -289,7 +438,12 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
            return self._process_single_image_task(image_data)

    async def process_images_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        *args,
        **kwargs,
    ):
        if not image_data:
            return None
@@ -350,6 +504,8 @@ def get_image_processor(
        return MllamaImageProcessor(hf_config, server_args, processor)
    elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
        return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor)
    elif "MiniCPMV" in hf_config.architectures:
        return MiniCPMVImageProcessor(hf_config, server_args, processor)
    else:
        return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
...
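The frame budget in MiniCPMVImageProcessor is a simple heuristic: whatever part of max_req_input_len is not already used by the prompt text is divided by NUM_TOKEN_PER_FRAME (330) and capped at 100, and video inputs passed as "video:<path>" entries are sampled down to that many frames, each expanded into one (<image>./</image>) token. A worked example with made-up request sizes:

max_req_input_len = 4096   # reported by the scheduler and plumbed through the tokenizer manager
prompt_len = 796           # length of the text part of the request
NUM_TOKEN_PER_FRAME = 330  # model-specific constant from the code above

max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)
print(max_num_frames)      # -> 10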
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
@@ -68,7 +67,6 @@ global_server_args_dict = {
    "device": ServerArgs.device,
}

logger = logging.getLogger(__name__)
@@ -149,6 +147,16 @@ class ImageInputs:
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None
    tgt_sizes: Optional[list] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
@@ -168,6 +176,11 @@ class ImageInputs:
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
        ]
        for arg in optional_args:
            if arg in obj:
@@ -1140,7 +1153,6 @@ class ScheduleBatch:
        global bid
        bid += 1
        return ModelWorkerBatch(
            bid=bid,
            forward_mode=self.forward_mode,
...
@@ -274,7 +274,6 @@ class Scheduler:
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
@@ -1729,7 +1728,11 @@ def run_scheduler_process(
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap()
...
@@ -112,6 +112,7 @@ class TokenizerManager:
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
@@ -207,6 +208,8 @@ class TokenizerManager:
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
@@ -281,7 +284,7 @@ class TokenizerManager:
        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
...
@@ -237,7 +237,7 @@ class ModelRunner:
        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

        if not self.is_draft_worker:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size,
...
[One file's diff is collapsed and is not shown here.]
@@ -248,6 +248,9 @@ class Qwen2Model(nn.Module):
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
@@ -296,7 +299,6 @@ class Qwen2Model(nn.Module):
class Qwen2ForCausalLM(nn.Module):
    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
@@ -334,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
...
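The new get_input_embeddings hooks exist so that a multimodal wrapper can compute the token embeddings itself and overwrite the image-placeholder positions with vision features before calling the language model. A self-contained sketch of that pattern (illustrative only; the embedding table, the placeholder id, and the feature tensor are made up, and this is not the MiniCPM-V implementation):

import torch
import torch.nn as nn

vocab_size, hidden, image_token_id = 32000, 64, 5
embed = nn.Embedding(vocab_size, hidden)     # stands in for Qwen2Model.embed_tokens

input_ids = torch.tensor([[1, 5, 5, 5, 2]])  # three image placeholders
image_features = torch.randn(3, hidden)      # one vector per placeholder

with torch.no_grad():
    inputs_embeds = embed(input_ids)         # what get_input_embeddings() returns
    inputs_embeds[input_ids == image_token_id] = image_features
# inputs_embeds, rather than input_ids, is then fed to the language model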
@@ -37,9 +37,7 @@ from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -52,6 +50,7 @@ from sglang.srt.models.qwen2 import Qwen2Model

logger = logging.getLogger(__name__)

# === Vision Inputs === #
@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
        return x


def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    t_ = t.float()
    cos = freqs.cos()
    sin = freqs.sin()
    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
    return output


class Qwen2VisionAttention(nn.Module):

    def __init__(
        self,
        embed_dim: Optional[int] = None,
        num_heads: Optional[int] = None,
        projection_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )

        self.qkv = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=3 * projection_size,
            quant_config=quant_config,
        )
        self.proj = RowParallelLinear(
            input_size=projection_size, output_size=embed_dim, quant_config=quant_config
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        x = x.view(*new_x_shape)

        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
        q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
        batch_size = q.shape[1]

        q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = (seq_lens).max().item()
        q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

        output = torch.empty_like(q)
        context_attention_fwd(
            q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
        )

        context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
        context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output


class Qwen2VisionBlock(nn.Module):
    def __init__(
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=False,
            quant_config=quant_config,
        )
        self.mlp = Qwen2VisionMLP(
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn

        x = x + self.mlp(self.norm2(x))
        return x
@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)

                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
...
@@ -565,6 +565,7 @@ def launch_engine(
    # Assume all schedulers have same scheduler_info
    scheduler_info = scheduler_infos[0]
    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]


def launch_server(
...
@@ -451,6 +451,8 @@ def load_image(image_file: Union[str, bytes]):
    else:
        raise ValueError(f"Invalid image: {image}")

    # if image_size is None:
    #     image_size = image.size
    return image, image_size
...
@@ -406,7 +406,7 @@ def popen_launch_server(
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: list[str] = (),
    env: Optional[dict] = None,
    return_stdout_stderr: Optional[tuple] = None,
):
...
@@ -25,7 +25,7 @@ export OPENAI_API_KEY=sk-*****
python3 test_openai_backend.py

# Run a single test
python3 -m unittest test_openai_backend.TestOpenAIServer.test_few_shot_qa

# Run a suite with multiple files
python3 run_suite.py --suite per-commit
...
@@ -171,7 +171,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
        text = response.choices[0].message.content
        assert isinstance(text, str)
        print(text)
        assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text
        assert "logo" in text or '"S"' in text or "SG" in text, text
        assert response.id
        assert response.created
@@ -444,5 +444,24 @@ class TestMllamaServer(TestOpenAIVisionServer):
    pass


class TestMinicpmvServer(TestOpenAIVisionServer):
    @classmethod
    def setUpClass(cls):
        cls.model = "openbmb/MiniCPM-V-2_6"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--chat-template",
                "minicpmv",
            ],
        )
        cls.base_url += "/v1"

if __name__ == "__main__":
    unittest.main()
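Once a server like the one in TestMinicpmvServer is up, the endpoint can be exercised with the regular OpenAI client. A sketch (the URL, API key, and image URL are placeholders):

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")
response = client.chat.completions.create(
    model="openbmb/MiniCPM-V-2_6",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    temperature=0,
)
print(response.choices[0].message.content)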