Unverified commit ff2ce0b8 authored by Mick, committed by GitHub

refactor: move image processors to separate files (#4229)

parent 0f2a2e3c
...@@ -11,11 +11,16 @@ import argparse ...@@ -11,11 +11,16 @@ import argparse
import random import random
import torch import torch
from bench_sglang import EvalArgs, prepare_samples
from data_utils import save_json from data_utils import save_json
from eval_utils import eval_result, get_sampling_params, parse_multi_choice_response from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
prepare_samples,
process_result,
)
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
@torch.no_grad() @torch.no_grad()
...@@ -28,7 +33,6 @@ def eval_mmmu(args): ...@@ -28,7 +33,6 @@ def eval_mmmu(args):
trust_remote_code=True, trust_remote_code=True,
) )
model = model.eval().cuda() model = model.eval().cuda()
model = torch.compile(model)
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
args.model_path, torch_dtype="auto", device_map="auto" args.model_path, torch_dtype="auto", device_map="auto"
...@@ -38,6 +42,10 @@ def eval_mmmu(args): ...@@ -38,6 +42,10 @@ def eval_mmmu(args):
out_samples = dict() out_samples = dict()
sampling_params = get_sampling_params(eval_args) sampling_params = get_sampling_params(eval_args)
generation_config = GenerationConfig(
max_new_tokens=sampling_params["max_new_tokens"],
do_sample=False,
)
answer_dict = {} answer_dict = {}
for sample in tqdm(samples): for sample in tqdm(samples):
...@@ -62,7 +70,6 @@ def eval_mmmu(args): ...@@ -62,7 +70,6 @@ def eval_mmmu(args):
text = processor.apply_chat_template( text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
inputs = processor( inputs = processor(
text=[text], text=[text],
images=[image], images=[image],
...@@ -70,13 +77,16 @@ def eval_mmmu(args): ...@@ -70,13 +77,16 @@ def eval_mmmu(args):
return_tensors="pt", return_tensors="pt",
).to(model.device) ).to(model.device)
generated_ids = model.generate(**inputs, **sampling_params) generated_ids = model.generate(
**inputs, generation_config=generation_config
)
response = processor.decode( response = processor.decode(
generated_ids[0], generated_ids[0],
skip_special_tokens=True, skip_special_tokens=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
)[len(text) :] )[len(text) :]
print(f"response: {response}")
else: # multiple images actually else: # multiple images actually
if sample["question_type"] == "multiple-choice": if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"] all_choices = sample["all_choices"]
...@@ -85,24 +95,11 @@ def eval_mmmu(args): ...@@ -85,24 +95,11 @@ def eval_mmmu(args):
else: else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
if sample["question_type"] == "multiple-choice": process_result(response, sample, answer_dict, out_samples)
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
torch.cuda.empty_cache()
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": sample["answer"],
}
args.output_path = f"{args.model_path}_val_hf.json" args.output_path = f"{args.model_path}_val_hf.json"
save_json(args.output_path, out_samples) save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict) eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,9 +8,9 @@ ...@@ -8,9 +8,9 @@
""" """
import argparse import argparse
import base64
import dataclasses import dataclasses
import random import random
import re
from io import BytesIO from io import BytesIO
from data_utils import save_json from data_utils import save_json
...@@ -18,13 +18,14 @@ from eval_utils import ( ...@@ -18,13 +18,14 @@ from eval_utils import (
EvalArgs, EvalArgs,
eval_result, eval_result,
get_sampling_params, get_sampling_params,
parse_multi_choice_response,
prepare_samples, prepare_samples,
process_result,
) )
from tqdm import tqdm from tqdm import tqdm
from sglang import Engine from sglang import Engine
from sglang.srt.conversation import chat_templates from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -35,61 +36,76 @@ def eval_mmmu(args): ...@@ -35,61 +36,76 @@ def eval_mmmu(args):
if server_args.chat_template is None: if server_args.chat_template is None:
raise ValueError("Chat template must be provided for this benchmark") raise ValueError("Chat template must be provided for this benchmark")
samples = prepare_samples(eval_args)
backend = Engine(**dataclasses.asdict(server_args)) backend = Engine(**dataclasses.asdict(server_args))
out_samples = dict() out_samples = dict()
sampling_params = get_sampling_params(eval_args) sampling_params = get_sampling_params(eval_args)
conv = chat_templates[server_args.chat_template].copy() samples = prepare_samples(eval_args)
image_token = conv.image_token
answer_dict = {} answer_dict = {}
for sample in tqdm(samples): for sample in tqdm(samples):
prompt = sample["final_input_prompt"] prompt = sample["final_input_prompt"]
image = sample["image"] image = sample["image"]
bytes_io = BytesIO() buff = BytesIO()
image.save(bytes_io, format="PNG") image.save(buff, format="PNG")
png_bytes = bytes_io.getvalue() base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
prefix = prompt.split("<")[0]
prompt = re.sub(r"<[^>]*>", image_token, prompt) suffix = prompt.split(">")[1]
request_dict = {
"model": "",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prefix,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_str}"
},
},
{
"type": "text",
"text": suffix,
},
],
}
],
}
conv = generate_chat_conv(
ChatCompletionRequest(**request_dict),
template_name=server_args.chat_template,
)
prompt = conv.get_prompt()
if image is not None: if image is not None:
gen_out = backend.generate( gen_out = backend.generate(
prompt=prompt, image_data=[png_bytes], sampling_params=sampling_params prompt=prompt,
image_data=conv.image_data,
sampling_params=sampling_params,
)["text"] )["text"]
response = gen_out response = gen_out
else: # multiple images actually else: # multiple images actually
if sample["question_type"] == "multiple-choice": if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"] all_choices = sample["all_choices"]
response = random.choice(all_choices) response = random.choice(all_choices)
else: else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS" response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
if sample["question_type"] == "multiple-choice": process_result(response, sample, answer_dict, out_samples)
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": (
sample["correct_choice"]
if "correct_choice" in samples
else sample["answer"]
),
}
args.output_path = f"{args.model_path}_val_sglang.json" args.output_path = f"{args.model_path}_val_sglang.json"
save_json(args.output_path, out_samples) save_json(args.output_path, out_samples)
eval_result(output_path=args.output_path, answer_dict=answer_dict) eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
backend.shutdown()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -143,6 +143,7 @@ def process_single_sample(data): ...@@ -143,6 +143,7 @@ def process_single_sample(data):
# DATA SAVING # DATA SAVING
def save_json(filename, ds): def save_json(filename, ds):
print(f"answers saved to: {filename}")
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(ds, f, indent=4) json.dump(ds, f, indent=4)
......
...@@ -87,6 +87,7 @@ def set_seed(seed_value): ...@@ -87,6 +87,7 @@ def set_seed(seed_value):
def prepare_samples(eval_args: EvalArgs): def prepare_samples(eval_args: EvalArgs):
print("preparing samples...")
# Build prompts # Build prompts
set_seed(eval_args.seed) set_seed(eval_args.seed)
...@@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs): ...@@ -110,6 +111,7 @@ def prepare_samples(eval_args: EvalArgs):
eval_args.dataset_path, subject, split=eval_args.split eval_args.dataset_path, subject, split=eval_args.split
) )
sub_dataset_list.append(sub_dataset) sub_dataset_list.append(sub_dataset)
# break
# merge all dataset # merge all dataset
dataset = concatenate_datasets(sub_dataset_list) dataset = concatenate_datasets(sub_dataset_list)
...@@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict): ...@@ -426,9 +428,26 @@ def calculate_ins_level_acc(results: Dict):
return acc / ins_num return acc / ins_num
def eval_result(output_path, answer_dict): def process_result(response, sample, answer_dict, out_samples):
if sample["question_type"] == "multiple-choice":
pred_ans = parse_multi_choice_response(
response, sample["all_choices"], sample["index2ans"]
)
else: # open question
pred_ans = response
out_samples[sample["id"]] = pred_ans
# set ground truth answer
answer_dict[sample["id"]] = {
"question_type": sample["question_type"],
"ground_truth": sample["answer"],
}
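# Illustrative sketch (editor's note, not part of the diff): how the new
# process_result helper defined directly above is expected to be driven.
# The sample fields below are assumptions modeled on the MMMU fields used
# elsewhere in this benchmark (id, question_type, all_choices, index2ans, answer).
out_samples, answer_dict = {}, {}
sample = {
    "id": "validation_Art_1",
    "question_type": "multiple-choice",
    "all_choices": ["A", "B", "C", "D"],
    "index2ans": {"A": "cubism", "B": "baroque", "C": "fauvism", "D": "dada"},
    "answer": "B",
}
process_result("(B)", sample, answer_dict, out_samples)
# out_samples["validation_Art_1"] now holds the parsed choice letter, and
# answer_dict["validation_Art_1"] records question_type and ground_truth.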
def eval_result(model_answer_path, answer_dict):
print("Evaluating...") print("Evaluating...")
output_dict = json.load(open(output_path)) output_dict = json.load(open(model_answer_path))
# answer_dict = json.load(open(answer_path)) # answer_dict = json.load(open(answer_path))
# group by category # group by category
...@@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict): ...@@ -521,7 +540,7 @@ def eval_result(output_path, answer_dict):
"acc": overall_acc, "acc": overall_acc,
} }
pprint.pprint(printable_results) pprint.pprint(printable_results)
out = output_path out = model_answer_path
with open(out, "w", encoding="utf-8") as outfile: with open(out, "w", encoding="utf-8") as outfile:
json.dump(printable_results, outfile) json.dump(printable_results, outfile)
print(f"eval out saved to {out}") print(f"eval out saved to {out}")
......
...@@ -191,7 +191,7 @@ class Conversation: ...@@ -191,7 +191,7 @@ class Conversation:
for i, (role, message) in enumerate(self.messages): for i, (role, message) in enumerate(self.messages):
if i % 2 == 0: if i % 2 == 0:
ret += f"[Round {i//2 + round_add_n}]{self.sep}" ret += f"[Round {i // 2 + round_add_n}]{self.sep}"
if message: if message:
ret += f"{role}{message}{self.sep}" ret += f"{role}{message}{self.sep}"
...@@ -453,7 +453,6 @@ def generate_chat_conv( ...@@ -453,7 +453,6 @@ def generate_chat_conv(
conv.system_message = getattr(message.content[0], "text", "") conv.system_message = getattr(message.content[0], "text", "")
elif msg_role == "user": elif msg_role == "user":
# Handle the various types of Chat Request content types here. # Handle the various types of Chat Request content types here.
role = conv.roles[0]
if isinstance(message.content, str): if isinstance(message.content, str):
conv.append_message(conv.roles[0], message.content) conv.append_message(conv.roles[0], message.content)
else: else:
......
...@@ -66,6 +66,7 @@ def get_config( ...@@ -66,6 +66,7 @@ def get_config(
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
) )
if config.model_type in _CONFIG_REGISTRY: if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type] config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision) config = config_class.from_pretrained(model, revision=revision)
......
from __future__ import annotations from __future__ import annotations
from functools import lru_cache from functools import lru_cache
from typing import Optional from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig ...@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: # Copied from transformers, modeling_qwen2_vl.py
if not interleaved: def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1) """Rotates half the hidden dims of the input."""
return torch.cat((-x2, x1), dim=-1) x1 = x[..., : x.shape[-1] // 2]
else: x2 = x[..., x.shape[-1] // 2 :]
x1, x2 = x[..., ::2], x[..., 1::2] return torch.cat((-x2, x1), dim=-1)
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch( def apply_rotary_pos_emb_vision(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" orig_q_dtype = q.dtype
x: (batch_size, seqlen, nheads, headdim) orig_k_dtype = k.dtype
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) q, k = q.float(), k.float()
"""
ro_dim = cos.shape[-1] * 2 cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
assert ro_dim <= x.shape[-1] q_embed = (q * cos) + (rotate_half(q) * sin)
cos = repeat( k_embed = (k * cos) + (rotate_half(k) * sin)
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
) q_embed = q_embed.to(orig_q_dtype)
sin = repeat( k_embed = k_embed.to(orig_k_dtype)
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
) return q_embed, k_embed
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
return output
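# Illustrative sketch (editor's note, not part of the diff): how the rewritten
# rotary helpers are meant to be called. Shapes are assumed toy values; the
# import path is taken from this commit (sglang/srt/layers/attention/vision.py),
# and the (cos, sin) pair spans the full head_size, matching the Qwen2.5-VL
# vision transformer change further down.
import torch
from sglang.srt.layers.attention.vision import apply_rotary_pos_emb_vision

s, num_heads, head_size = 16, 4, 64                 # assumed toy sizes
q = torch.randn(s, num_heads, head_size, dtype=torch.float16)
k = torch.randn(s, num_heads, head_size, dtype=torch.float16)
freqs = torch.randn(s, head_size // 2)              # stand-in rotary angles
emb = torch.cat((freqs, freqs), dim=-1)             # -> (s, head_size)
cos, sin = emb.cos(), emb.sin()
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype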
class VisionAttention(nn.Module): class VisionAttention(nn.Module):
...@@ -75,8 +57,8 @@ class VisionAttention(nn.Module): ...@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
use_context_forward (bool, default to True): use_context_forward (bool, default to True):
if ``True``, a flash_attn style attention will be applied if ``True``, a flash_attn style attention will be applied
Otherwise, a full-sequence attention will be applied. Otherwise, a full-sequence attention will be applied.
use_full_precision_softmax (bool, default to False): softmax_in_single_precision (bool, default to False):
if ``True``, the softmax will be performed in full-precision if ``True``, the softmax will be performed in single-precision
Otherwise, it will be performed in half-precision Otherwise, it will be performed in half-precision
""" """
...@@ -90,7 +72,7 @@ class VisionAttention(nn.Module): ...@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
dropout: float = 0.0, dropout: float = 0.0,
use_context_forward: bool = True, use_context_forward: bool = True,
use_full_precision_softmax: bool = False, softmax_in_single_precision: bool = False,
flatten_batch: bool = False, flatten_batch: bool = False,
prefix: str = "", prefix: str = "",
): ):
...@@ -113,7 +95,7 @@ class VisionAttention(nn.Module): ...@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
head_size=self.head_size, head_size=self.head_size,
dropout=dropout, dropout=dropout,
flatten_batch=flatten_batch, flatten_batch=flatten_batch,
use_full_precision_softmax=use_full_precision_softmax, softmax_in_single_precision=softmax_in_single_precision,
) )
self.use_qkv_parallel = use_qkv_parallel self.use_qkv_parallel = use_qkv_parallel
...@@ -143,7 +125,7 @@ class VisionAttention(nn.Module): ...@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None,
rotary_pos_emb: torch.Tensor = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
...@@ -151,21 +133,17 @@ class VisionAttention(nn.Module): ...@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
x: [b, s, embed_dim] x: [b, s, embed_dim]
cu_seqlens: [b] cu_seqlens: [b]
Returns: Returns:
[s, b, num_heads * head] [s, b, head * head_size]
""" """
bsz, s, _ = x.shape bsz, s, _ = x.shape
head = self.num_attention_heads_per_partition
if self.use_qkv_parallel: if self.use_qkv_parallel:
# [b, s, embed_dim] --> [b, s, embed_dim] # [b, s, embed_dim] --> [b, s, embed_dim]
qkv, _ = self.qkv_proj(x) qkv, _ = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1) q, k, v = qkv.chunk(3, dim=-1)
# [b, s, embed_dim] --> [b * s, num_heads, head_size] # [b, s, embed_dim] --> [b * s, head, head_size]
q, k, v = [ q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
x.reshape(
bsz * s, self.num_attention_heads_per_partition, -1
).contiguous()
for x in (q, k, v)
]
else: else:
# [b, s, embed_dim] --> [s, b, embed_dim] # [b, s, embed_dim] --> [s, b, embed_dim]
x = rearrange(x, "b s ... -> s b ...") x = rearrange(x, "b s ... -> s b ...")
...@@ -173,7 +151,7 @@ class VisionAttention(nn.Module): ...@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
qkv, _ = self.qkv_proj(x) qkv, _ = self.qkv_proj(x)
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
new_x_shape = qkv.size()[:-1] + ( new_x_shape = qkv.size()[:-1] + (
self.num_attention_heads_per_partition, head,
3 * self.hidden_size_per_attention_head, 3 * self.hidden_size_per_attention_head,
) )
qkv = qkv.view(*new_x_shape) qkv = qkv.view(*new_x_shape)
...@@ -186,9 +164,12 @@ class VisionAttention(nn.Module): ...@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
] ]
if rotary_pos_emb is not None: if position_embeddings is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) cos, sin = position_embeddings
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
if self.use_qkv_parallel: if self.use_qkv_parallel:
pass pass
...@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module): ...@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
head_size: int, head_size: int,
dropout: float = 0.0, dropout: float = 0.0,
flatten_batch: bool = False, flatten_batch: bool = False,
use_full_precision_softmax: bool = False, softmax_in_single_precision: bool = False,
): ):
super().__init__() super().__init__()
self.head_size = head_size self.head_size = head_size
self.flatten_batch = flatten_batch self.flatten_batch = flatten_batch
self.use_full_precision_softmax = use_full_precision_softmax self.softmax_in_single_precision = softmax_in_single_precision
self.dropout = dropout self.dropout = dropout
@staticmethod @staticmethod
...@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module): ...@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
) )
if attention_mask is None: if attention_mask is None:
if self.use_full_precision_softmax: if self.softmax_in_single_precision:
raise RuntimeError("Empty attention mask") raise RuntimeError("Empty attention mask")
else: else:
attention_mask = attention_mask.to(device=q.device) attention_mask = attention_mask.to(device=q.device)
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
if self.use_full_precision_softmax: if self.softmax_in_single_precision:
scale = self.head_size**-0.5 scale = self.head_size**-0.5
k_transposed = rearrange(k, "b h s d -> b h d s") k_transposed = rearrange(k, "b h s d -> b h d s")
attn_weights = torch.matmul(q, k_transposed) * scale attn_weights = torch.matmul(q, k_transposed) * scale
......
import concurrent
import concurrent.futures
import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
import PIL
import transformers
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image, logger
global global_processor
def get_global_processor():
global global_processor
return global_processor
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[tuple[int, int]]
all_frames: list[PIL.Image.Image]
# input_text, with each frame of video/image represented as an image_token
input_text: str
class BaseImageProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(
self,
server_args,
),
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
def _build_processor(self, server_args):
"""Init the global processor for multi modal models."""
from sglang.srt.hf_transformers_utils import get_processor
return get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
@abstractmethod
async def process_images_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
def get_estimated_frames_list(self, image_data):
"""
estimate the total frame count from all visual input
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
@staticmethod
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
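# Worked example (editor's comment): for a 90-frame video at ~30 fps,
# sample_fps == 30 and frame_indices == [0, 30, 60]. With frame_count_limit=2,
# uniform_sample picks list positions [int(0*1.5 + 0.75), int(1*1.5 + 0.75)] == [0, 2],
# i.e. frames 0 and 60 are kept.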
def load_images(
self,
input_ids: list,
image_data,
image_token: str,
max_req_input_len: int,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseImageProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
"""
image_hashes, image_sizes = [], []
all_frames = []
new_text_parts = []
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
if return_text:
text_parts = input_text.split(image_token)
# roughly calculate the max number of frames under the max_req_input_len limit
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
assert len(image_data) == len(estimated_frames_list)
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
max_frames_to_process = 0
else:
max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
if max_frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = BaseImageProcessor.encode_video(
path, frame_count_limit=max_frames_to_process
)
else:
raw_image, _size = load_image(image)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
assert len(frames) != 0
except FileNotFoundError as e:
print(e)
return None
image_sizes += [frames[0].size] * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
if return_text:
new_text_parts.append(text_parts[image_index])
if max_frames_to_process != 0:
new_text_parts.append(image_token * len(frames))
assert max_frames_to_process >= len(frames)
if return_text:
new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts)
return BaseImageProcessorOutput(
image_hashes, image_sizes, all_frames, input_text
)
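# Worked example (editor's comment) of the frame-budget heuristic above: with
# one 100-frame video plus one still image, total_frame_count == 101, so
# scaling_factor == min(1.0, 30 / 101) ~= 0.297; the video is capped at
# max(1, int(100 * 0.297)) == 29 frames and the image keeps its 1 frame,
# staying within MAX_NUM_FRAMES == 30.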
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
def init_global_processor(
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
"""Init the global processor for multi-modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_image_processor._build_processor(server_args=server_args)
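# Stripped-down sketch (editor's illustration; the names below are made up, not
# from the codebase) of the pattern used above: a module-level global is
# populated once per worker by the ProcessPoolExecutor initializer, so tasks
# reuse the HF processor instead of re-loading it per request.
import concurrent.futures
import multiprocessing as mp

_worker_processor = None  # stands in for global_processor above

def _init_worker(tag):
    global _worker_processor
    _worker_processor = f"processor:{tag}"  # stands in for _build_processor(...)

def _process(item):
    # runs inside a worker process and sees the initializer's result
    return f"{_worker_processor} handled {item}"

if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=2,
        mp_context=mp.get_context("fork"),  # "fork" mirrors the code above
        initializer=_init_worker,
        initargs=("demo",),
    ) as pool:
        print(list(pool.map(_process, range(3))))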
import asyncio
from typing import List, Optional, Union
import numpy as np
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
from sglang.srt.models.llavavid import LlavaVidForCausalLM
from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
processor = get_global_processor()
image_processor = image_processor or processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
*args,
**kwargs,
):
if not image_data:
return None
modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and aspect_ratio is not None
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
)
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
LlavaVidForCausalLM: LlavaImageProcessor,
LlavaQwenForCausalLM: LlavaImageProcessor,
LlavaMistralForCausalLM: LlavaImageProcessor,
}
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
}
async def _process_images(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
input_text,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
)
return image_inputs
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
if not image_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
base_output = self.load_images(
input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
)
if base_output is None:
return None
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
)
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
slice_start_id, slice_end_id = None, None
if tokenizer.slice_start_id:
    slice_start_id = tokenizer.slice_start_id
    slice_end_id = tokenizer.slice_end_id
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
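# Editor's note: with IMAGE_TOKEN == "(<image>./</image>)", the inherited
# load_images splits the prompt on that marker and re-joins it with one
# IMAGE_TOKEN per kept frame, e.g. "Q: (<image>./</image>) What is shown?"
# with a 3-frame video becomes
# "Q: (<image>./</image>)(<image>./</image>)(<image>./</image>) What is shown?".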
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@staticmethod
def _process_single_image_task(images, input_text):
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
return get_global_processor()(images, input_text, return_tensors="pt")
async def _process_single_image(self, images, input_text):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
MllamaImageProcessor._process_single_image_task,
images,
input_text,
)
else:
image_inputs = self._processor(images, input_text, return_tensors="pt")
return image_inputs
async def process_images_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
return None
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
import asyncio
import math
from typing import List, Union
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
self.video_token_id = hf_config.video_token_id
self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
self.MAX_PIXELS = 16384 * 28 * 28
self.MAX_RATIO = 200
@staticmethod
def _process_images_task(images, input_text, _hf_config):
if isinstance(images, list) and len(images) == 0:
images = None
result = get_global_processor().__call__(
text=[input_text], images=images, padding=True, return_tensors="pt"
)
return {
"input_ids": result.input_ids,
"pixel_values": getattr(result, "pixel_values", None),
"image_grid_thw": getattr(result, "image_grid_thw", None),
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Qwen2_5VLImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids,
image_data,
image_token,
max_req_input_len,
)
def smart_resize(
height: int,
width: int,
factor: int = self.IMAGE_FACTOR,
min_pixels: int = self.MIN_PIXELS,
max_pixels: int = self.MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > self.MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
width, height = image.size
min_pixels = self.MIN_PIXELS
max_pixels = self.MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.all_frames]
ret = await self._process_images(images, base_output.input_text)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thw"],
"video_grid_thws": ret["video_grid_thws"],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}
ImageProcessorMapping = {
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}
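# Standalone restatement (editor's sketch) of the smart_resize arithmetic above,
# with concrete numbers. The helpers are re-declared here only because they are
# nested inside process_images_async in the file above; constants mirror the
# IMAGE_FACTOR / MIN_PIXELS / MAX_PIXELS / MAX_RATIO attributes.
import math

def _round_by_factor(n, f): return round(n / f) * f
def _ceil_by_factor(n, f): return math.ceil(n / f) * f
def _floor_by_factor(n, f): return math.floor(n / f) * f

def _smart_resize(height, width, factor=28,
                  min_pixels=4 * 28 * 28, max_pixels=16384 * 28 * 28,
                  max_ratio=200):
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError("aspect ratio too extreme")
    h_bar = max(factor, _round_by_factor(height, factor))
    w_bar = max(factor, _round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = _floor_by_factor(height / beta, factor)
        w_bar = _floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = _ceil_by_factor(height * beta, factor)
        w_bar = _ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

# A 1000x700 image snaps to multiples of 28 within the pixel budget:
print(_smart_resize(1000, 700))  # -> (1008, 700)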
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
"""
Data tokens (like image tokens) often need special handling during padding
to maintain model compatibility. This class provides the interface for
implementing different padding strategies for data tokens
"""
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
"""
pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
"""
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
self.data_token_id_pairs = data_token_pairs
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
This function will replace the data tokens in between with pad_values accordingly
"""
pad_values = image_inputs.pad_values
data_token_pairs = self.data_token_id_pairs
image_inputs.image_offsets = []
if data_token_pairs is None:
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
if data_token_pairs is None:
logger.warning(
    "No data_token_pairs provided; RadixAttention behavior may be affected."
)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_tokens_ids = [e for _s, e in data_token_pairs]
# First start token marks new data
data_start_token = start_token_ids[0]
padded_ids = []
last_idx = 0
data_idx = -1
start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
if len(start_indices) != len(end_indices):
return input_ids
for start_idx, end_idx in zip(start_indices, end_indices):
padded_ids.extend(input_ids[last_idx : start_idx + 1])
if input_ids[start_idx] == data_start_token:
data_idx += 1
image_inputs.image_offsets += [start_idx]
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
padded_ids.extend([pad_value] * num_tokens)
last_idx = end_idx
padded_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(padded_ids)
return padded_ids
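# Worked example (editor's sketch) of the token-pair padding above. A full
# ImageInputs is not required here; a stand-in object carrying only the fields
# the pattern touches (pad_values, im_start_id, im_end_id, image_offsets) is
# assumed, and 901/902 are toy start/end token ids.
from types import SimpleNamespace

_img = SimpleNamespace(pad_values=[7, 8], im_start_id=901, im_end_id=902)
_pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(901, 902)])
_ids = [1, 901, 5, 5, 902, 2, 901, 6, 6, 6, 902, 3]
_padded = _pattern.pad_input_tokens(_ids, _img)
# -> [1, 901, 7, 7, 902, 2, 901, 8, 8, 8, 902, 3]; _img.image_offsets == [1, 6]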
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
which needs first to be expanded to multiple tokens, then replaced with their padding values
This strategy should be used when a single data token represents content that should
be expanded to multiple tokens during processing.
"""
def __init__(
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
) -> None:
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == image_inputs.im_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
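# Worked example (editor's sketch) of the single-token pattern above; the token
# count function passed in is an assumption chosen for illustration (t * h * w),
# and 905 is a toy data token id.
from types import SimpleNamespace

_img = SimpleNamespace(pad_values=[7], im_token_id=905, image_grid_thws=[(1, 2, 2)])
_pattern = MultModalityDataPaddingPatternSingleToken(
    num_data_token_calc_func=lambda thw: thw[0] * thw[1] * thw[2]
)
_padded = _pattern.pad_input_tokens([1, 905, 2], _img)
# The single data token 905 is expanded to 4 positions and filled with the pad
# value: [1, 7, 7, 7, 7, 2]; _img.image_offsets == [1]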
...@@ -158,15 +158,19 @@ class ImageInputs: ...@@ -158,15 +158,19 @@ class ImageInputs:
image_grid_thws: List[Tuple[int, int, int]] = None image_grid_thws: List[Tuple[int, int, int]] = None
mrope_position_delta: Optional[torch.Tensor] = None mrope_position_delta: Optional[torch.Tensor] = None
# MiniCPMV related # The id of the single-image placeholder token
im_token_id: Optional[torch.Tensor] = None
# All the images in the batch should share the same special image # All the images in the batch should share the same special image
# bound token ids. # bound token ids.
im_start_id: Optional[torch.Tensor] = None im_start_id: Optional[int] = None
im_end_id: Optional[torch.Tensor] = None im_end_id: Optional[int] = None
slice_start_id: Optional[torch.Tensor] = None slice_start_id: Optional[int] = None
slice_end_id: Optional[torch.Tensor] = None slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None tgt_sizes: Optional[list] = None
# denotes the number of valid image tokens in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod @staticmethod
def from_dict(obj: dict): def from_dict(obj: dict):
ret = ImageInputs( ret = ImageInputs(
...@@ -186,11 +190,13 @@ class ImageInputs: ...@@ -186,11 +190,13 @@ class ImageInputs:
"aspect_ratio_ids", "aspect_ratio_ids",
"aspect_ratio_mask", "aspect_ratio_mask",
"image_grid_thws", "image_grid_thws",
"im_token_id",
"im_start_id", "im_start_id",
"im_end_id", "im_end_id",
"slice_start_id", "slice_start_id",
"slice_end_id", "slice_end_id",
"tgt_sizes", "tgt_sizes",
"images_emb_mask",
] ]
for arg in optional_args: for arg in optional_args:
if arg in obj: if arg in obj:
......
...@@ -455,7 +455,7 @@ def pt_weights_iterator( ...@@ -455,7 +455,7 @@ def pt_weights_iterator(
disable=not enable_tqdm, disable=not enable_tqdm,
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
): ):
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items() yield from state.items()
del state del state
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
...@@ -41,7 +41,6 @@ from torch import nn ...@@ -41,7 +41,6 @@ from torch import nn
from torch.nn.init import trunc_normal_ from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
...@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import ( ...@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
...@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
num_heads_per_partition = divide(self.num_heads, tp_size)
self.self_attn = VisionAttention( self.self_attn = VisionAttention(
embed_dim=config.hidden_size, embed_dim=config.hidden_size,
num_heads=num_heads_per_partition, num_heads=self.num_heads,
projection_size=config.intermediate_size, projection_size=config.intermediate_size,
use_qkv_parallel=True, use_qkv_parallel=True,
quant_config=quant_config, quant_config=quant_config,
dropout=config.attention_dropout, dropout=config.attention_dropout,
use_context_forward=False, use_context_forward=False,
use_full_precision_softmax=True, softmax_in_single_precision=True,
flatten_batch=False, flatten_batch=False,
prefix=add_prefix("self_attn", prefix), prefix=add_prefix("self_attn", prefix),
) )
...@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
pad_values: List[int], pad_values: List[int],
im_start_id: torch.Tensor, im_start_id: int,
im_end_id: torch.Tensor, im_end_id: int,
slice_start_id: Optional[torch.Tensor] = None, slice_start_id: Optional[int] = None,
slice_end_id: Optional[torch.Tensor] = None, slice_end_id: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Returns a tensor indicating the bounds (start and end token ids) of the images Returns a tensor indicating the bounds (start and end token ids) of the images
""" """
# All the images in the batch should share the same special image # All the images in the batch should share the same special image
# bound token ids. # bound token ids.
start_cond = input_ids == im_start_id[0] start_cond = input_ids == im_start_id
end_cond = input_ids == im_end_id[0] end_cond = input_ids == im_end_id
if slice_start_id is not None: if slice_start_id is not None:
start_cond |= input_ids == slice_start_id[0] start_cond |= input_ids == slice_start_id
end_cond |= input_ids == slice_end_id[0] end_cond |= input_ids == slice_end_id
(image_start_tokens,) = torch.where(start_cond) (image_start_tokens,) = torch.where(start_cond)
image_start_tokens += 1 image_start_tokens += 1
...@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
if ( if (
len(image_start_tokens) + 1 == len(image_end_tokens) len(image_start_tokens) + 1 == len(image_end_tokens)
and input_ids[0] in pad_values and input_ids[0] in pad_values
and len(image_start_tokens) != 0
and len(image_end_tokens) != 0
and image_end_tokens[0] < image_start_tokens[0] and image_end_tokens[0] < image_start_tokens[0]
): ):
image_start_tokens = torch.cat( image_start_tokens = torch.cat(
...@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
if forward_batch.image_inputs is not None and forward_batch.image_inputs != [ if (
None forward_batch.image_inputs is not None
]: and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
):
# TODO: batch
kwargs.update( kwargs.update(
{ {
"pixel_values": ( "pixel_values": (
...@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
return self.resampler(vision_embedding, tgt_sizes) return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
if not isinstance(image_inputs.im_start_id, list) or not isinstance(
image_inputs.im_end_id, list
):
return input_ids
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs # Get all special token IDs
im_start_id = ( im_start_id: int = image_inputs.im_start_id
image_inputs.im_start_id[0].item() im_end_id: int = image_inputs.im_end_id
if isinstance(image_inputs.im_start_id[0], torch.Tensor) slice_start_id: int = image_inputs.slice_start_id
else image_inputs.im_start_id[0] slice_end_id: int = image_inputs.slice_end_id
)
im_end_id = (
image_inputs.im_end_id[0].item()
if isinstance(image_inputs.im_end_id[0], torch.Tensor)
else image_inputs.im_end_id[0]
)
slice_start_id = (
image_inputs.slice_start_id[0].item()
if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
else image_inputs.slice_start_id[0]
)
slice_end_id = (
image_inputs.slice_end_id[0].item()
if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
else image_inputs.slice_end_id[0]
)
# Find all start and end positions for both types
start_indices = [
i
for i, x in enumerate(input_ids)
if x == im_start_id or x == slice_start_id
]
end_indices = [
i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
]
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(
input_ids[last_idx : start_idx + 1]
) # include start token
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
# Update last_idx to after end token media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
last_idx = end_idx pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
# Add remaining tokens after last region return pattern.pad_input_tokens(input_ids, image_inputs)
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6} _SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
......
...@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module): ...@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
quant_config=None, quant_config=None,
dropout=0.0, dropout=0.0,
use_context_forward=False, use_context_forward=False,
use_full_precision_softmax=False, softmax_in_single_precision=False,
flatten_batch=False, flatten_batch=False,
prefix=add_prefix("self_attn", prefix), prefix=add_prefix("self_attn", prefix),
) )
......
...@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
self.attn = VisionAttention(
...@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
...@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
...@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
...@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
x = blk(
x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
)
# adapter
x = self.merger(x)
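Editor's note: both vision transformers now build the (cos, sin) pair once per forward pass (emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) above) and pass it to every block as position_embeddings, instead of handing the raw rotary_pos_emb table down and recomputing cos/sin inside each attention call. A hedged sketch of how such a precomputed pair is typically consumed when rotating the query/key tensors (rotate_half-style, as in the upstream Hugging Face Qwen2-VL code; the exact sglang VisionAttention internals may differ):
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the last dim in two and swap the halves with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision(q: torch.Tensor, k: torch.Tensor, position_embeddings):
    # position_embeddings is the (cos, sin) pair computed once in the
    # vision transformer forward pass and reused by every block
    cos, sin = position_embeddings
    cos = cos.unsqueeze(-2)  # broadcast over the head dimension
    sin = sin.unsqueeze(-2)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed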
...@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
new_input_ids = []
last_idx = 0
image_idx = -1
image_inputs.image_offsets = []
# Get all special token IDs
im_start_id = image_inputs.im_start_id
im_end_id = image_inputs.im_end_id
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
# Find all start and end positions for both types
start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
if len(start_indices) != len(end_indices):
return input_ids
# Process each region (both image and slice)
for start_idx, end_idx in zip(start_indices, end_indices):
# Add non-image tokens before this region
new_input_ids.extend(input_ids[last_idx : start_idx + 1])
is_image_start = input_ids[start_idx] == im_start_id
if is_image_start:
image_inputs.image_offsets += [start_idx]
image_idx += 1
num_tokens = end_idx - start_idx - 1 # exclude start and end tokens
# Generate pad_ids
pad_values = [image_inputs.pad_values[image_idx]]
pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
pad_ids = pad_ids[:num_tokens]
# Add pad_ids
new_input_ids.extend(pad_ids)
# Update last_idx to after end token
last_idx = end_idx
# Add remaining tokens after last region
new_input_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(new_input_ids)
return new_input_ids
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
...@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
...
...@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
mlp_hidden_dim = int(dim * mlp_ratio)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
softmax_in_single_precision = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
softmax_in_single_precision = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
softmax_in_single_precision = True
use_context_forward = False
self.attn = VisionAttention(
...@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
projection_size=dim,
use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
softmax_in_single_precision=softmax_in_single_precision,
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
...@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
)
def forward(
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.norm1(x)
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
attn = self.attn(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
hidden_states,
cu_seqlens=cu_seqlens,
position_embeddings=position_embeddings,
)
attn = rearrange(attn, "b s ... -> s b ...")
x = x + attn
...@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):
# compute position embedding
rotary_pos_emb = self.rot_pos_emb(grid_thw)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
...@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
# transformers
x = x.unsqueeze(1)
for blk in self.blocks:
x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
# adapter
x = self.merger(x)
...@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return num_image_tokens
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == self.config.image_token_id
]
image_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[image_cnt]
)
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
else:
non_image_tokens = input_ids[
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
input_ids_with_image.extend(pad_ids[:num_image_tokens])
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
return input_ids_with_image
def __init__(
self,
config: Qwen2VLConfig,
...@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
# Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
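Editor's note: to make the effect of the new pad_input_ids concrete, here is a small, hypothetical worked example (the token IDs and pad value are made up for illustration; the real vision start/end IDs come from the tokenizer config). Every token between the start/end pair is replaced by the image's pad value, and the offset of the start token is recorded in image_inputs.image_offsets:
# Hypothetical IDs: 151652 = vision start, 151653 = vision end, 71 = image placeholder,
# and pad_values = [9876] (one hash-derived value per image).
input_ids = [1, 2, 151652, 71, 71, 71, 151653, 3]
# After pattern.pad_input_tokens(input_ids, image_inputs):
#   padded input_ids            == [1, 2, 151652, 9876, 9876, 9876, 151653, 3]
#   image_inputs.image_offsets  == [2]   # index of the vision start token
#   the first padded position is therefore image_offsets[0] + 1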
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
...@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
pixel_values = torch.tensor(image.pixel_values, device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
...@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = (
start_idx + (image_offset - prefix_len) + num_image_tokens
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model(
input_ids=input_ids,
...
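Editor's note: one substantive change in the hunk above is the off-by-one adjustment in left_idx. The old Qwen2-VL pad_input_ids (removed further up) recorded each image offset at the position of the first pad token, whereas the shared token-pair pattern appears to record the position of the vision start token, so the first padded embedding now sits one slot later; hence image_offset - prefix_len + 1. Continuing the hypothetical example above (offsets and lengths are illustrative):
# image_offsets == [2] points at the vision start token (id 151652),
# while the image embeddings must overwrite indices 3..5 (the pad tokens).
image_offset, prefix_len, start_idx, num_image_tokens = 2, 0, 0, 3
left_idx = start_idx + (image_offset - prefix_len + 1)   # == 3
right_idx = left_idx + num_image_tokens                  # == 6, i.e. slots 3, 4, 5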