Unverified commit e1020dc5 authored by Mick, committed by GitHub

refactor: simplify MultimodalTokens logic (#7924)

parent 3586b4ce
@@ -21,7 +21,7 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

-    # frames loaded from image and video, in given order
+    # frames loaded from image, in given order
     images: Optional[list[Union[Image.Image, dict]]] = None

     # videos
@@ -44,14 +44,26 @@ class BaseMultiModalProcessorOutput:
 @dataclasses.dataclass
 class MultimodalSpecialTokens:
-    image_token: Optional[Union[int, str, List[str]]] = None
-    video_token: Optional[Union[int, str, List[str]]] = None
-    audio_token: Optional[Union[int, str, List[str]]] = None
+    image_token: Optional[Union[str, List[str]]] = None
+    video_token: Optional[Union[str, List[str]]] = None
+    audio_token: Optional[Union[str, List[str]]] = None
+    image_token_id: Optional[int] = None
+    video_token_id: Optional[int] = None
+    audio_token_id: Optional[int] = None
     image_token_regex: Optional[re.Pattern] = None
     video_token_regex: Optional[re.Pattern] = None
     audio_token_regex: Optional[re.Pattern] = None
+    combined_regex: Optional[re.Pattern] = None
+
+    def build(self, processor):
+        self.convert_to_strs(processor)
+        self.parse_regex()
+        self.get_combined_regex()
+        return self

     def convert_to_str(self, token: Union[str, int], processor) -> str:
         if token is None:
             return token
@@ -60,11 +72,14 @@ class MultimodalSpecialTokens:
         return processor.tokenizer.convert_ids_to_tokens([token])[0]

     def convert_to_strs(self, processor):
-        self.image_token = self.convert_to_str(self.image_token, processor)
-        self.video_token = self.convert_to_str(self.video_token, processor)
-        self.audio_token = self.convert_to_str(self.audio_token, processor)
+        if not self.image_token:
+            self.image_token = self.convert_to_str(self.image_token_id, processor)
+        if not self.video_token:
+            self.video_token = self.convert_to_str(self.video_token_id, processor)
+        if not self.audio_token:
+            self.audio_token = self.convert_to_str(self.audio_token_id, processor)

-    def get_modality_of_token(self, token) -> Optional[Modality]:
+    def get_modality_of_token(self, token: str) -> Optional[Modality]:
         """
         :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
         """
@@ -94,7 +109,12 @@ class MultimodalSpecialTokens:
         if self.audio_token_regex is None and self.audio_token is not None:
             self.audio_token_regex = re.compile(re.escape(self.audio_token))

-    def combine_regex(self) -> re.Pattern:
+    def get_combined_regex(self) -> re.Pattern:
+        """
+        Builds and returns a regex, used to split input str into tokens (with mm special tokens)
+        """
+        if self.combined_regex:
+            return self.combined_regex
         tokens = [
             self.image_token_regex,
             self.video_token_regex,
@@ -107,7 +127,8 @@ class MultimodalSpecialTokens:
             patterns.append(t.pattern)
             flags |= t.flags
         combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
-        return re.compile(combined, flags)
+        self.combined_regex = re.compile(combined, flags)
+        return self.combined_regex
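The outer group in `combined` is a capturing group, so when the cached pattern is used with `re.split` the matched special tokens are kept as standalone chunks instead of being dropped. A minimal, illustrative sketch of that behaviour (the token strings below are placeholders, not values taken from this PR):

```python
# Illustrative only: how a combined, capturing pattern like the one built in
# get_combined_regex() splits a prompt while keeping the special tokens.
import re

image_token_regex = re.compile(re.escape("<image>"))  # placeholder token
video_token_regex = re.compile(re.escape("<video>"))  # placeholder token

patterns = [p.pattern for p in (image_token_regex, video_token_regex)]
combined = re.compile("(" + "|".join(f"(?:{p})" for p in patterns) + ")")

prompt = "Describe <image> and then summarize <video> briefly."
chunks = [chunk for chunk in combined.split(prompt) if chunk]
print(chunks)
# ['Describe ', '<image>', ' and then summarize ', '<video>', ' briefly.']
```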
 class BaseMultimodalProcessor(ABC):
@@ -341,9 +362,8 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
-        multimodal_tokens.convert_to_strs(self._processor)
-        multimodal_tokens.parse_regex()
-        multimodal_tokens_pattern = multimodal_tokens.combine_regex()
+        multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()
         if isinstance(prompt, list) and return_text:
             assert len(prompt) and isinstance(prompt[0], int)
             prompt = self._processor.tokenizer.decode(prompt)
@@ -445,7 +465,6 @@ class BaseMultimodalProcessor(ABC):
             return result = [(2,4),(6,7)]
         """
         mask = input_ids == mm_token_id
-
         start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
         end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
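The two `torch.roll` lines above locate each contiguous run of `mm_token_id`: a position starts a run when it is set while its left neighbour is not, and ends a run when its right neighbour is not. A standalone sketch with made-up token ids, reproducing the `[(2,4),(6,7)]` example from the docstring:

```python
# Standalone sketch of the span-finding logic above; the ids are made up.
import torch

input_ids = torch.tensor([5, 9, 7, 7, 7, 3, 7, 7, 2])
mm_token_id = 7

mask = input_ids == mm_token_id
# Run starts: set positions whose left neighbour is unset (roll by +1).
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
# Run ends: set positions whose right neighbour is unset (roll by -1).
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]

print(list(zip(start_positions.tolist(), end_positions.tolist())))
# [(2, 4), (6, 7)]
```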
@@ -554,7 +573,9 @@ class BaseMultimodalProcessor(ABC):
         return collected_items, input_ids, ret

     def process_and_combine_mm_data(
-        self, base_output: BaseMultiModalProcessorOutput
+        self,
+        base_output: BaseMultiModalProcessorOutput,
+        mm_tokens: MultimodalSpecialTokens,
     ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Process multimodal data and return the combined multimodal items and input_ids.
@@ -618,22 +639,14 @@ class BaseMultimodalProcessor(ABC):
         # Add offsets to all items
         for mm_item in all_collected_items:
-            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.IM_TOKEN_ID,
-                )
-            elif mm_item.modality == Modality.AUDIO:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.AUDIO_TOKEN_ID,
-                )
-            elif mm_item.modality == Modality.VIDEO:
-                mm_item.offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.VIDEO_TOKEN_ID,
-                )
-            else:
-                raise ValueError(f"Unknown modality: {mm_item.modality}")
+            mm_item.offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id={
+                    Modality.IMAGE: mm_tokens.image_token_id,
+                    Modality.MULTI_IMAGES: mm_tokens.image_token_id,
+                    Modality.VIDEO: mm_tokens.video_token_id,
+                    Modality.AUDIO: mm_tokens.audio_token_id,
+                }.get(mm_item.modality, None),
+            )

         return all_collected_items, input_ids, ret
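With offsets resolved from the token ids carried by `MultimodalSpecialTokens`, every model processor in the rest of this diff follows the same shape: build `mm_tokens` once in `__init__`, reuse it in `load_mm_data`, and pass it to `process_and_combine_mm_data`. A condensed sketch of that pattern (the class name, token string and `hf_config` field are illustrative, not taken from any one model below):

```python
# Condensed, illustrative sketch of the pattern the per-model diffs below adopt.
class ExampleImageProcessor(BaseMultimodalProcessor):
    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # Built once at init time; convert_to_strs/parse_regex/get_combined_regex
        # all run inside build().
        self.mm_tokens = MultimodalSpecialTokens(
            image_token="<image>",                    # placeholder token string
            image_token_id=hf_config.image_token_id,  # placeholder config field
        ).build(_processor)

    async def process_mm_data_async(
        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
    ):
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=self.mm_tokens,  # reuse the prebuilt tokens
            max_req_input_len=max_req_input_len,
        )
        # Offsets are derived from mm_tokens' *_token_id fields in the base class.
        mm_items, input_ids, _ = self.process_and_combine_mm_data(
            base_output, self.mm_tokens
        )
        return {
            "input_ids": input_ids.tolist(),
            "mm_items": mm_items,
            "im_token_id": self.mm_tokens.image_token_id,
        }
```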
@@ -33,7 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "<image>"
+        self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
+            _processor
+        )

     async def process_mm_data_async(
         self,
@@ -47,7 +49,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
         )
         res = self.process_mm_data(
@@ -4,7 +4,6 @@ from typing import Dict, List, Union
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
@@ -17,15 +16,17 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        # The single, pre-expanded image token.
-        self.IMAGE_TOKEN = "<start_of_image>"
-        # The regex that matches expanded image tokens.
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
-        )
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
-        self.IM_TOKEN_ID = hf_config.image_token_index
+        self.mm_tokens = MultimodalSpecialTokens(
+            # The single, pre-expanded image token.
+            image_token="<start_of_image>",
+            image_token_id=hf_config.image_token_index,
+            # The regex that matches expanded image tokens.
+            image_token_regex=re.compile(
+                r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+            ),
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -39,14 +40,14 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
-            ),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
@@ -30,23 +30,23 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "<image_soft_token>"
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
-        )
-        self.AUDIO_TOKEN = "<audio_soft_token>"
-        self.AUDIO_TOKEN_REGEX = re.compile(
-            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
-        )
-        self.IM_TOKEN_ID = hf_config.image_token_id
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
-        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
         self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
         self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<image_soft_token>",
+            image_token_id=hf_config.image_token_id,
+            image_token_regex=re.compile(
+                r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
+            ),
+            audio_token="<audio_soft_token>",
+            audio_token_id=hf_config.audio_token_id,
+            audio_token_regex=re.compile(
+                r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
+            ),
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -64,19 +64,17 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             image_data=image_data,
             audio_data=audio_data,
             max_req_input_len=max_req_input_len,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-                audio_token=self.AUDIO_TOKEN,
-                audio_token_regex=self.AUDIO_TOKEN_REGEX,
-            ),
+            multimodal_tokens=self.mm_tokens,
         )
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
-            "im_token_id": self.IM_TOKEN_ID,
-            "audio_token_id": self.AUDIO_TOKEN_ID,
+            # TODO(mick): could we return MultimodalSpecialTokens directly?
+            "im_token_id": self.mm_tokens.image_token_id,
+            "audio_token_id": self.mm_tokens.audio_token_id,
         }
@@ -24,7 +24,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
         self.IMG_START_TOKEN = "<img>"
         self.IMG_END_TOKEN = "</img>"
-        self.IMG_TOKEN = "<image>"
         self.num_image_token = int(
             (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
         )
@@ -32,9 +31,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         tokenizer = self._processor
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
-        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
-            self.IMG_CONTEXT_TOKEN
-        )
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<image>",
+            image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
+        ).build(_image_processor)

     @staticmethod
     def build_transform(input_size):
@@ -175,7 +175,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )
@@ -219,7 +219,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
-            mm_token_id=self.img_context_token_id,
+            mm_token_id=self.mm_tokens.image_token_id,
         )
         items = [
             MultimodalDataItem(
@@ -234,5 +234,5 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
-            "im_token_id": self.img_context_token_id,
+            "im_token_id": self.mm_tokens.image_token_id,
         }
@@ -11,8 +11,12 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, processor):
+        super().__init__(hf_config, server_args, processor)
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=processor.image_token
+        ).build(processor)

     async def process_mm_data_async(
         self,
@@ -27,9 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=processor.image_token
-            ),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
         )
 import re
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Union

-import torch
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,9 +14,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "<|media_pad|>"
-        self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
-        self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<|media_pad|>",
+            # TODO: could we convert in MultimodalSpecialTokens?
+            image_token_id=hf_config.media_placeholder_token_id,
+            image_token_regex=re.compile(r"(?:<\|media_pad\|>)+"),
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -33,16 +33,16 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
-            ),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
         )
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
-            "im_token_id": self.IM_TOKEN_ID,
+            "im_token_id": self.mm_tokens.image_token_id,
         }
@@ -17,9 +17,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.image_token = "(<image>./</image>)"
-        self.audio_token = "(<audio>./</audio>)"
-        self.video_token = "(<video>./</video>)"
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="(<image>./</image>)",
+            audio_token="(<audio>./</audio>)",
+            video_token="(<video>./</video>)",
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -35,11 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.image_token,
-                video_token=self.video_token,
-                audio_token=self.audio_token,
-            ),
+            multimodal_tokens=self.mm_tokens,
         )
         if base_output is None:
             return None
@@ -26,8 +26,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         self.eoi_token_index = hf_config.eoi_token_index
         self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
-            image_token=_processor.image_token
-        )
+            image_token=_processor.image_token,
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -21,7 +21,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_IMAGE_SPECIAL_TOKEN,
-        )
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -55,7 +55,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
         self.patch_size = self.vision_config.patch_size
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
-        )
+        ).build(_processor)
         _processor.tokenizer.add_special_tokens(
             {
                 "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
@@ -203,16 +203,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        # The single, pre-expanded image token.
-        self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
         # The regex that matches expanded image tokens.
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
-        )
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
-        self.IM_TOKEN_ID = hf_config.image_token_id
-        self.VIDEO_TOKEN_ID = hf_config.video_token_id
         self.vision_start_token_id = hf_config.vision_start_token_id
         self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
@@ -220,12 +213,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MIN_PIXELS = 4 * 28 * 28
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200
-        # TODO(mick): move all MultimodalSpecialTokens initializations into processor init
-        self.mm_special_tokens = MultimodalSpecialTokens(
-            image_token=self.IMAGE_TOKEN,
-            image_token_regex=self.IMAGE_TOKEN_REGEX,
-            video_token=self.VIDEO_TOKEN_ID,
-        )
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<|vision_start|><|image_pad|><|vision_end|>",
+            image_token_id=hf_config.image_token_id,
+            image_token_regex=re.compile(
+                r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+            ),
+            video_token_id=hf_config.video_token_id,
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -241,7 +236,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             prompt=input_text,
             image_data=image_data,
             video_data=request_obj.video_data,
-            multimodal_tokens=self.mm_special_tokens,
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
         )
@@ -255,13 +250,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 await preprocess_video(video) for video in base_output.videos
             ]
-        mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
-            image_token_id=self.IM_TOKEN_ID,
-            video_token_id=self.VIDEO_TOKEN_ID,
+            image_token_id=self.mm_tokens.image_token_id,
+            video_token_id=self.mm_tokens.video_token_id,
             vision_start_token_id=self.vision_start_token_id,
             model_type=self.hf_config.model_type,
             tokens_per_second=getattr(
@@ -279,8 +276,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
-            "im_token_id": self.IM_TOKEN_ID,
-            "video_token_id": self.VIDEO_TOKEN_ID,
+            "im_token_id": self.mm_tokens.image_token_id,
+            "video_token_id": self.mm_tokens.video_token_id,
             "mrope_positions": mrope_positions,
             "mrope_position_delta": mrope_position_delta,
         }
-from typing import Any, Dict, List, Optional, Type, cast
+from typing import Any, Dict, List, Optional, Type

 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
@@ -10,7 +10,6 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     ImageDataInputItem,
 )
-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.vila import VILAForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -37,8 +36,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         _processor: VILAProcessor,
     ) -> None:
         super().__init__(hf_config, server_args, _processor)
-        self.IM_TOKEN_ID = hf_config.image_token_id
-        self.VIDEO_TOKEN_ID = hf_config.video_token_id
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self._processor.tokenizer.image_token,
+            image_token_id=hf_config.image_token_id,
+            video_token_id=hf_config.video_token_id,
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -50,18 +52,18 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
     ) -> Optional[Dict[str, Any]]:
         base_output = self.load_mm_data(
             prompt=input_text,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self._processor.tokenizer.image_token
-            ),
+            multimodal_tokens=self.mm_tokens,
             max_req_input_len=max_req_input_len,
             image_data=image_data,
         )
-        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
-            "im_token_id": self.IM_TOKEN_ID,
-            "video_token_id": self.VIDEO_TOKEN_ID,
+            "im_token_id": self.mm_tokens.image_token_id,
+            "video_token_id": self.mm_tokens.video_token_id,
         }