"...text-generation-inference.git" did not exist on "c8c7ccd31e1e760d216c9d2f2b17b0d984ed033b"
Unverified commit 5cb552b1 authored by Mick, committed by GitHub

refactor: multimodal data (#4754)

parent c7457191
@@ -72,17 +72,38 @@ def eval_mmmu(args):
         if suffix:
             contents += [{"type": "text", "text": suffix}]
         messages = [{"role": "user", "content": contents}]
-        model_inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            return_dict=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-        ).to(model.device)
-        input_len = model_inputs["input_ids"].shape[-1]
-        generation = model.generate(**model_inputs, generation_config=generation_config)
-        generation = generation[0][input_len:]
-        response = processor.decode(generation, skip_special_tokens=True)
+        try:
+            model_inputs = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_dict=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(model.device)
+            input_len = model_inputs["input_ids"].shape[-1]
+            generation = model.generate(
+                **model_inputs, generation_config=generation_config
+            )
+            generation = generation[0][input_len:]
+            response = processor.decode(generation, skip_special_tokens=True)
+        except:
+            contents = []
+            if prefix:
+                contents += [prefix]
+            image = PIL.Image.open(sample["image_path"])
+            contents += [image]
+            if suffix:
+                contents += [suffix]
+            messages = [{"role": "user", "content": contents}]
+            response = model.chat(
+                msgs=messages,
+                tokenizer=processor.tokenizer,
+                sampling=False,
+                max_new_tokens=sampling_params["max_new_tokens"],
+                use_tts_template=False,
+                generate_audio=False,
+                temperature=0.0,
+            )
         print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)
...
@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
 def process_result(response, sample, answer_dict, out_samples):
+    if response is None:
+        return
     if sample["question_type"] == "multiple-choice":
         pred_ans = parse_multi_choice_response(
             response, sample["all_choices"], sample["index2ans"]
...
This diff is collapsed.
@@ -64,5 +64,3 @@ def get_mm_processor(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
-
-    self.image_proce
@@ -8,18 +8,10 @@ from typing import Optional
 import numpy as np
 import PIL
-import transformers
 from decord import VideoReader, cpu
 from PIL import Image

-from sglang.srt.utils import load_audio, load_image, logger
+from sglang.srt.utils import encode_video, load_audio, load_image, logger

-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor


 @dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

-    mm_data_hashes: Optional[list[int]]
-    # images
-    image_sizes: Optional[list[int]]
     # frames loaded from image and video, in given order
     images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
     audios: Optional[list[np.ndarray]] = None

     def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
+        for field_name in ["image_sizes", "images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

-        # Initialize global processor first
-        init_global_processor(self, server_args)
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
+        self.io_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
+        )
+        self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
             mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+            max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )

-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+    def process_mm_data(
+        self, input_text, images=None, videos=None, audios=None, **kwargs
+    ):
+        """
+        process multimodal data with transformers AutoProcessor
+        """
+        if images is not None:
+            kwargs["images"] = images
+        if videos is not None:
+            kwargs["videos"] = videos
+        if audios is not None:
+            kwargs["audios"] = audios
+
+        processor = self._processor
+        result = processor.__call__(
+            text=[input_text],
+            padding=True,
+            return_tensors="pt",
+            **kwargs,
+        )
+        return result

     @abstractmethod
     async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
         return estimated_frames_list

-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
     def load_mm_data(
         self,
-        input_ids: list[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
         else:
             multimodal_tokens.image_token = multimodal_tokens.image_token

-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
+        assert isinstance(prompt, str)
+
+        if isinstance(prompt, list) and return_text:
+            assert len(prompt) and isinstance(prompt[0], int)
+            prompt = self._processor.tokenizer.decode(prompt)
         else:
-            input_text = input_ids
+            prompt = prompt

         if return_text:
             import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
                 + ")"
             )
             # split text into list of normal text and special tokens
-            text_parts = re.split(pattern, input_text)
+            text_parts = re.split(pattern, prompt)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
                 ):
                     # video
                     path = image_file[len("video:") :]
-                    frames = BaseMultimodalProcessor.encode_video(
+                    frames = encode_video(
                         path, frame_count_limit=frames_to_process
                     )
                 else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
             raise RuntimeError(f"An exception occurred while loading images: {e}")

         out = BaseMultiModalProcessorOutput(
-            mm_data_hashes=hashes,
-            image_sizes=image_sizes,
             images=images,
             audios=audios,
             input_text=new_text,
         )
         out.normalize()
         return out
-
-
-def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
-    """
-    Init the global processor for multimodal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_processor._build_processor(server_args=server_args)
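Editor's note (not part of the commit): a minimal sketch of what a model-specific processor looks like after this refactor, calling the shared process_mm_data helper and load_mm_data shown above instead of a per-model _process_*_task routed through the removed global-processor executor. The class name MyVLMProcessor and its "<image>" placeholder token are illustrative assumptions, not code from this diff.

# A hedged sketch, assuming the refactored base class above; "MyVLMProcessor"
# and the "<image>" token are illustrative only.
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem


class MyVLMProcessor(BaseMultimodalProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<image>"  # model-specific placeholder (assumed)

    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
    ):
        if not image_data:
            return None
        if not isinstance(image_data, list):
            image_data = [image_data]

        # Expand image placeholders in the prompt and load the raw frames.
        base_output = self.load_mm_data(
            prompt=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        # Run the HuggingFace processor in-process; no global-processor executor.
        res = self.process_mm_data(
            input_text=base_output.input_text, images=base_output.images
        )
        # Return per-item multimodal data instead of batch-wide parallel lists.
        return {
            "mm_items": [
                MultimodalDataItem(
                    pixel_values=res["pixel_values"], modality=Modality.IMAGE
                )
            ],
            "input_ids": res["input_ids"].flatten().tolist(),
        }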
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
 from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(
-            images=images, text=input_text, return_tensors="pt"
-        )
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                ClipImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=[input_text], return_tensors="pt"
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
+            )
+        ]

         return image_inputs
@@ -16,15 +16,14 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-import asyncio

 import torch

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<image>"

-    @staticmethod
-    def _process_images_task(image, input_text, max_req_input_len):
-        processor = get_global_processor()
-        res = processor.__call__(
-            conversations=input_text, images=image, max_req_input_len=max_req_input_len
-        )
-
-        image_token_id = processor.image_token_id
-        res["im_token_id"] = image_token_id
-        return res
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-
-        return image_inputs
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
     ):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        images, image_sizes = [], []
-
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
         )
-        res = await self._process_images(
-            base_output.images, base_output.input_text, max_req_input_len
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            max_req_input_len=max_req_input_len,
+            conversations=base_output.input_text,
         )
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             batched_images_spatial_crop.append(images_spatial_crop)
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

+        items = []
+        item = MultimodalDataItem(
+            pixel_values=res["images"],
+            modality=Modality.IMAGE,
+            image_emb_mask=images_seq_mask,
+            image_spatial_crop=batched_images_spatial_crop,
+        )
+        items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].tolist(),
-            "pixel_values": res["images"],
-            "im_token_id": res["im_token_id"],
-            "data_hashes": base_output.mm_data_hashes,
-            "image_sizes": image_sizes,
-            "images_emb_mask": images_seq_mask,
-            "image_spatial_crop": batched_images_spatial_crop,
-            "modalities": request_obj.modalities or ["image"],
+            "im_token_id": self._processor.image_token_id,
         }
@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

-    async def _process_single_image(self, images, input_text) -> dict:
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        processor = get_global_processor()
-        result = processor.__call__(
-            text=[input_text],
-            images=images,
-            padding=True,
-            return_tensors="pt",
-            # if RGBA, this needs to be set
-            # images_kwargs={
-            #     "input_data_format": ChannelDimension.FIRST
-            # }
-        )
-
-        pixel_values = getattr(result, "pixel_values", None)
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": pixel_values,
-        }
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

-        ret = await self._process_single_image(
+        ret = self.process_mm_data(
             input_text=base_output.input_text, images=base_output.images
         )

+        items = []
+        for i, image in enumerate(base_output.images):
+            item = MultimodalDataItem(
+                pixel_values=ret["pixel_values"][i],
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(
-            prompt=input_text, images=images, return_tensors="pt"
-        )
-
-        return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "images_emb_mask": result["images_emb_mask"],
-            "im_start_id": processor.image_start_id,
-            "im_end_id": processor.image_end_id,
-            "im_token_id": processor.image_id,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                JanusProImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

+        processor = self._processor
+
         base_out = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token="<image_placeholder>"
-            ),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
             max_req_input_len=max_req_input_len,
         )

         images = base_out.images
-        res = await self._process_images(images=images, input_text=base_out.input_text)
-        # print(res)
-        # print(base_out)
-        # print("", res["images_emb_mask"].shape)
+        res = self.process_mm_data(
+            input_text=base_out.input_text,
+            prompt=base_out.input_text,
+            images=images,
+        )

         return {
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=res["pixel_values"],
+                    image_emb_mask=res["images_emb_mask"],
+                    modality=Modality.IMAGE,
+                )
+            ],
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "images_emb_mask": res["images_emb_mask"],
-            "data_hashes": base_out.mm_data_hashes,
-            "im_start_id": res["im_start_id"],
-            "im_end_id": res["im_end_id"],
-            "im_token_id": res["im_token_id"],
+            "im_start_id": processor.image_start_id,
+            "im_end_id": processor.image_end_id,
+            "im_token_id": processor.image_id,
         }
@@ -5,8 +5,8 @@ import numpy as np
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
 from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_data: Union[str, bytes],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
-        image_processor=None,
+        processor=None,
     ):
-        processor = get_global_processor()
-        image_processor = image_processor or processor.image_processor
+        image_processor = processor.image_processor

         try:
             image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
     ):
-        if self.executor is not None:
+        if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor,
+                self.cpu_executor,
                 LlavaImageProcessor._process_single_image_task,
                 image_data,
                 aspect_ratio,
                 grid_pinpoints,
+                self._processor,
             )
         else:
             return self._process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                self._processor.image_processor,
             )

     async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data[0], aspect_ratio, grid_pinpoints
             )
-            data_hashes = [image_hash]
             image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

+        modality = Modality.IMAGE
+        if isinstance(request_obj.modalities, list):
+            if request_obj.modalities[0] == "multi-images":
+                modality = Modality.MULTI_IMAGES
+            elif request_obj.modalities[0] == "video":
+                modality = Modality.VIDEO
+
         return {
-            "pixel_values": pixel_values,
-            "data_hashes": data_hashes,
-            "image_sizes": image_sizes,
-            "modalities": request_obj.modalities or ["image"],
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=pixel_values,
+                    image_sizes=image_sizes,
+                    modality=modality,
+                )
+            ],
         }
-import asyncio
 from typing import List, Union

 import torch
+from transformers import BaseImageProcessorFast

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.minicpmo import MiniCPMO
 from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"

-    @staticmethod
-    def _process_data_task(input_text, images=None, audios=None):
+    def process_data_task(self, input_text, images=None, audios=None):
         if isinstance(images, list) and len(images) == 0:
             images = None
         if isinstance(audios, list) and len(audios) == 0:
             audios = None
-        result = get_global_processor().__call__(
+        processor = self._processor
+        args = {}
+        if isinstance(processor, BaseImageProcessorFast):
+            args["device"] = "cuda"
+        result = self._processor.__call__(
             text=input_text,
             images=images,
             audios=audios,
             return_tensors="pt",
             chunk_input=True,
+            **args,
         )
         return {
             "input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             "audio_bounds": getattr(result, "audio_bounds", None),
         }

-    async def _process_data(self, images, input_text, audios=None):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            multimodal_data_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMMultimodalProcessor._process_data_task,
-                input_text,
-                images,
-                audios,
-            )
-        else:
-            multimodal_data_inputs = self._processor(
-                images=images, text=input_text, audios=audios, return_tensors="pt"
-            )
-        return multimodal_data_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if base_output is None:
             return None

-        res = await self._process_data(
-            images=base_output.images,
+        res = self.process_mm_data(
             input_text=base_output.input_text,
+            images=base_output.images,
             audios=base_output.audios,
         )
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
                 tgt_sizes_flat += [tgt_n]

         pixel_values = pixel_values_flat
-        if len(tgt_sizes_flat) == 0:
-            tgt_sizes = None
-        else:
-            tgt_sizes = torch.stack(tgt_sizes_flat)
-        if not isinstance(res["audio_features"], list):
-            res["audio_features"] = [res["audio_features"]]
+
+        items = []
+        if len(pixel_values) != 0:
+            item = MultimodalDataItem(
+                pixel_values=pixel_values,
+                tgt_size=tgt_sizes_flat,
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
+        if (
+            "audio_features" in res
+            and res["audio_features"] is not None
+            and len(res["audio_features"]) != 0
+        ):
+            item = MultimodalDataItem(
+                audio_features=[res["audio_features"]],
+                audio_feature_lens=res["audio_feature_lens"],
+                modality=Modality.AUDIO,
+            )
+            items += [item]

         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
-            "audio_features": res["audio_features"],
-            "audio_bounds": res["audio_bounds"],
-            "audio_feature_lens": res["audio_feature_lens"],
             "im_token_id": im_token_id,
             "im_start_id": tokenizer.im_start_id,
             "im_end_id": tokenizer.im_end_id,
...
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
 from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(images, input_text, return_tensors="pt")
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MllamaImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(images, input_text, return_tensors="pt")
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"],
+                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
+                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
+                modality=Modality.IMAGE,
+            )
+        ]

         return image_inputs
 import asyncio
 import math
-import time
 from typing import List, Union

 import torch
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200

-    @staticmethod
-    def _process_images_task(images, input_text, _hf_config):
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        result = get_global_processor().__call__(
-            text=[input_text], images=images, padding=True, return_tensors="pt"
-        )
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": getattr(result, "pixel_values", None),
-            "image_grid_thw": getattr(result, "image_grid_thw", None),
-            "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
-            "video_grid_thws": getattr(result, "video_grid_thws", None),
-        }
-
-    async def _process_single_image(self, images, input_text) -> dict:
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                Qwen2_5VLImageProcessor._process_images_task,
-                images,
-                input_text,
-                self.hf_config,
-            )
-        else:
-            return self._process_images_task(images, input_text, self.hf_config)
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        prompt,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        start = time.time()
         if not image_data:
             return None
         if isinstance(image_data, str):
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=prompt,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -144,24 +113,32 @@
             """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
             return math.floor(number / factor) * factor

-        images = [resize_image(image) for image in base_output.images]
+        async def resize_image_async(image):
+            return resize_image(image)

-        ret = await self._process_single_image(
-            images=images, input_text=base_output.input_text
+        resize_tasks = [resize_image_async(image) for image in base_output.images]
+        resized_images = await asyncio.gather(*resize_tasks)
+
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=resized_images,
         )

         image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        video_grid_thws = None
         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
-            "image_grid_thws": image_grid_thws,
-            "video_grid_thws": video_grid_thws,
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=image_grid_thws,
+                    # TODO
+                    video_grid_thws=None,
+                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    modality=Modality.IMAGE,
+                )
+            ],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
-            "second_per_grid_ts": ret["second_per_grid_ts"],
         }
 from __future__ import annotations

+from enum import Enum, auto
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
 }

+class Modality(Enum):
+    IMAGE = auto()
+    MULTI_IMAGES = auto()
+    VIDEO = auto()
+    AUDIO = auto()
+
+
 @dataclasses.dataclass
-class MultimodalInputs:
-    """The image related inputs."""
+class MultimodalDataItem:
+    """
+    A single multimodal data, from a single image/video/audio or other
+    """
+
+    modality: Modality
+
+    hash: int = None
+    pad_value: int = None
+
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

-    pixel_values: Union[torch.Tensor, np.array]
-    data_hashes: Optional[list] = None
-    image_sizes: Optional[list] = None
+    image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
+
+    # the real data, pixel_values or audio_features
+    # data: Union[List[torch.Tensor], List[np.array]]
+    pixel_values: Union[torch.Tensor, np.array] = None
+    image_grid_thws: Union[torch.Tensor, np.array] = None
+    video_grid_thws: Union[torch.Tensor, np.array] = None
+
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    audio_features: Union[torch.Tensor, np.array] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def is_empty_list(l):
+        if l is None:
+            return True
+        return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+    def set_pad_value(self):
+        """
+        Set the pad value after first hashing the data
+        """
+
+        def hash_feature(f):
+            if isinstance(f, list):
+                return hash(tuple(flatten_nested_list(f)))
+            elif isinstance(f, np.ndarray):
+                arr = np.ascontiguousarray(f)
+                arr_bytes = arr.tobytes()
+                return hash(arr_bytes)
+            return hash(f)
+
+        if self.is_audio():
+            self.hash = hash_feature(self.audio_features)
+        else:
+            self.hash = hash_feature(self.pixel_values)
+
+        assert self.hash is not None
+        self.pad_value = self.hash % (1 << 30)
+
+    def is_audio(self):
+        return (
+            self.modality == Modality.AUDIO
+        ) and not MultimodalDataItem.is_empty_list(self.audio_features)
+
+    def is_image(self):
+        return (
+            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def is_video(self):
+        return (
+            self.modality == Modality.VIDEO
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def validate(self):
+        ...
+        # TODO
+
+
+@dataclasses.dataclass
+class MultimodalInputs:
+    """The multimodal data related inputs."""
+
+    # items of data
+    mm_items: List[MultimodalDataItem]
     image_pad_len: Optional[list] = None
-    pad_values: Optional[list] = None
-    modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # Llava related
-    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
     # QWen2-VL related
-    # [num_of_images, t, h, w]
-    image_grid_thws: torch.Tensor = None
     mrope_position_delta: Optional[torch.Tensor] = None

-    # Qwen2-VL video related
-    video_token_id: Optional[int] = None
-    video_grid_thws: List[Tuple[int, int, int]] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # deepseek vl2 related
-    images_emb_mask: Optional[List[torch.Tensor]] = None
-    image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-    # The id of the single-image placeholder token
+    # image
     im_token_id: Optional[torch.Tensor] = None
-
-    # All the images in the batch should share the same special image
-    # bound token ids.
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
     slice_end_id: Optional[int] = None

-    # [num_images, 2 (w, h)]
-    tgt_sizes: Optional[list] = None
+    # video
+    video_token_id: Optional[int] = None

     # audio
     audio_start_id: Optional[torch.Tensor] = None
     audio_end_id: Optional[torch.Tensor] = None
-    audio_features: Optional[List[torch.Tensor]] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None

     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
-            pixel_values=obj["pixel_values"],
-            data_hashes=obj["data_hashes"],
+            mm_items=obj["mm_items"],
         )

+        assert isinstance(ret.mm_items, list)
+        ret.mm_items = [
+            item
+            for item in ret.mm_items
+            if item.is_audio() or item.is_image() or item.is_video()
+        ]
+
+        assert len(ret.mm_items) != 0
+
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
+        for item in ret.mm_items:
+            item.set_pad_value()

         optional_args = [
-            "image_sizes",
             "modalities",
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "image_grid_thws",
-            "images_emb_mask",
-            "image_spatial_crop",
             "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
-            "tgt_sizes",
             "audio_start_id",
             "audio_end_id",
-            "audio_features",
-            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])

-        # validate
-        assert (
-            isinstance(ret.pixel_values, torch.Tensor)
-            or isinstance(ret.pixel_values, np.ndarray)
-            or isinstance(ret.pixel_values, list)
-        )
-        assert ret.audio_features is None or isinstance(ret.audio_features, list)
         return ret

     def contains_image_inputs(self) -> bool:
         """ """
-        return self.pixel_values is not None and self.pixel_values != []
+        return any(item.is_image() for item in self.mm_items)

     def contains_audio_inputs(self) -> bool:
         """ """
-        return self.audio_features is not None and self.audio_features != []
+        return any(item.is_audio() for item in self.mm_items)
+
+    def collect_image_inputs(self) -> List[torch.Tensor]:
+        return [item.pixel_values for item in self.mm_items if item.is_image()]

     def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )

         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

         # args needed to be merged
         optional_args = [
-            "audio_features",
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
...
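Editor's note (not part of the commit): a minimal sketch of the data model introduced above. Each image/video/audio becomes one MultimodalDataItem; MultimodalInputs.from_dict drops empty items and calls set_pad_value(), so every item carries a hash-derived pad value that is used as a fake token id for radix-cache prefix matching. The tensor shape below is an illustrative assumption.

# A hedged sketch of the new per-item data model; the pixel tensor shape is a
# dummy value for illustration.
import torch

from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)

# One item per image/video/audio, instead of batch-wide pixel_values/data_hashes.
image_item = MultimodalDataItem(
    modality=Modality.IMAGE,
    pixel_values=torch.randn(1, 3, 336, 336),  # dummy pixels
)

mm_inputs = MultimodalInputs.from_dict({"mm_items": [image_item]})

# from_dict() drops empty items and calls set_pad_value() on the rest, so each
# item now carries a hash-derived pad value used as a fake token id for
# radix-cache prefix matching.
assert mm_inputs.contains_image_inputs()
print(mm_inputs.mm_items[0].pad_value)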
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
-
-import torch
+from typing import Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
...
@@ -355,11 +355,6 @@ class ForwardBatch:
         for mm_input in valid_inputs[1:]:
             merged.merge(mm_input)

-        if isinstance(merged.pixel_values, np.ndarray):
-            merged.pixel_values = torch.from_numpy(merged.pixel_values)
-        if isinstance(merged.audio_features, np.ndarray):
-            merged.audio_features = torch.from_numpy(merged.audio_features)
-
         return merged

     def contains_image_inputs(self) -> bool:
...
@@ -251,17 +251,16 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1

             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -269,18 +268,7 @@
                 "Qwen2_5_VLForConditionalGeneration"
             ]:
                 # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-                )
-                server_args.chunked_prefill_size = -1
-                server_args.disable_radix_cache = True
-
-            if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-                # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-                )
-                server_args.chunked_prefill_size = -1
+                logger.info("Automatically disable radix cache for qwen-vl series.")
                 server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
...
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"

-        image_inputs = None
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-            image_inputs = forward_batch.mm_inputs
-
-        if image_inputs is not None and image_inputs[0] is not None:
-            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            mm_inputs = forward_batch.mm_inputs
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
...
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.model.embed_tokens
+        return self.language_model.get_input_embeddings()

     @torch.no_grad()
     def forward(
@@ -1984,23 +1984,18 @@
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
-        get_embedding: bool = False,
     ) -> torch.Tensor:
-
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        return self.language_model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            get_embedding=False,
         )

+        return hidden_states
+
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
...