Unverified commit 3212c2ad, authored by Mick, committed by GitHub

vlm: optimize tensor transport (#6003)


Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 53475674
@@ -3,8 +3,9 @@ Multi-modality utils
 """

 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

 import numpy as np
 import torch
@@ -27,6 +28,130 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.

+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communication.
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialized metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                print_warning_once(
+                    f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
+                )
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+
+
 class MultiModalityDataPaddingPattern:
     """
...
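To illustrate the new class (a minimal sketch, not code from this commit): with the `"default"` transport the payload is pickled as ordinary tensor bytes and the attached metadata rides along; `"cuda_ipc"` changes only the wire format and requires the consumer to be a different process on the same node.

```python
import pickle

import torch

from sglang.srt.managers.mm_utils import TransportProxyTensor

proxy = TransportProxyTensor(
    data=torch.randn(2, 3),
    name="pixel_values",
    fields={"modality": "image"},
    transport_mode="default",  # CPU-safe path; no CUDA required
)

restored = pickle.loads(pickle.dumps(proxy))
assert torch.equal(restored, proxy)            # tensor payload round-trips
assert restored.name == "pixel_values"         # metadata round-trips
assert restored.fields["modality"] == "image"
```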
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}


-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@ def import_processors():
 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)

     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
...
@@ -209,10 +209,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None

     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None

-    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None

     # Model-specific data stored in a dictionary
...
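The invariant in the new comment can be read as the following hypothetical accessor (`active_payload` is illustrative, not part of the codebase): exactly one of the two payload fields is populated on any given item.

```python
def active_payload(item):
    # Exactly one of item.feature and item.precomputed_embeddings is set;
    # return whichever one the encoder should consume.
    if item.feature is not None:
        assert item.precomputed_embeddings is None
        return item.feature
    assert item.precomputed_embeddings is not None
    return item.precomputed_embeddings
```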
@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -166,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -216,12 +227,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)

             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )

         if server_args.skip_tokenizer_init:
...
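The decision keys off `dist_init_addr` alone: CUDA IPC handles only resolve within a single node, so any multi-node launch falls back to ordinary CPU pickling. A sketch of the expected behavior, using `SimpleNamespace` as a hypothetical stand-in for `ServerArgs`:

```python
from types import SimpleNamespace

from sglang.srt.managers.tokenizer_manager import _determine_tensor_transport_mode

# Single-node launch: no dist_init_addr, so intranode CUDA IPC is safe.
assert _determine_tensor_transport_mode(SimpleNamespace(dist_init_addr=None)) == "cuda_ipc"

# Multi-node launch: IPC handles cannot cross nodes, so fall back to default.
assert (
    _determine_tensor_transport_mode(SimpleNamespace(dist_init_addr="10.0.0.1:5000"))
    == "default"
)
```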
@@ -12,6 +12,7 @@ import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast

+from sglang.srt.managers.mm_utils import TransportProxyTensor
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import load_audio, load_image, load_video, logger
@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
 class BaseMultimodalProcessor(ABC):
     models = []

-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(
+        self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
+    ):
         self.hf_config = hf_config
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+        self.transport_mode = transport_mode

         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
-        if "pixel_values" in result and isinstance(
-            result["pixel_values"], torch.Tensor
-        ):
-            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result

     @abstractmethod
@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
         items: dict[Modality, MultimodalDataItem] = {}
-
         for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
                 continue
@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
             mm_token_id=mm_token_id,
         )

+        # post-process
+        for item in all_collected_items:
+            # replace the feature tensor with a proxy
+            if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
+                item.feature = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.feature
+                )
+            elif (
+                isinstance(item.precomputed_embeddings, torch.Tensor)
+                and item.precomputed_embeddings.is_cuda
+            ):
+                item.precomputed_embeddings = TransportProxyTensor(
+                    transport_mode=self.transport_mode, data=item.precomputed_embeddings
+                )
+
         return all_collected_items, input_ids, ret
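This post-process step is where the optimization pays off: features the HF processor produced on the GPU stay there (the old forced `.to("cpu")` copy is gone), and only an IPC handle crosses the process boundary. A sketch of the effect, assuming a CUDA device is available (the tensor size is illustrative):

```python
import pickle

import torch

from sglang.srt.managers.mm_utils import TransportProxyTensor

if torch.cuda.is_available():
    feature = torch.randn(16, 1024, device="cuda")  # e.g. pixel_values on GPU
    proxy = TransportProxyTensor(data=feature, transport_mode="cuda_ipc")

    payload = pickle.dumps(proxy)
    # `payload` carries a CUDA IPC handle plus shape/dtype/stride metadata, not
    # the 64 KiB of tensor bytes; a second process on the same node can call
    # pickle.loads(payload) and map the same device memory, copy-free.
```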
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class ClipImageProcessor(BaseMultimodalProcessor):
     models = [CLIPModel]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
             _processor
         )
...
@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     models = [DeepseekVL2ForCausalLM]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<image>", image_token_id=self._processor.image_token_id
         ).build(_processor)
...
@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     models = [Gemma3ForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index
         self.mm_tokens = MultimodalSpecialTokens(
...
@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
     models = [Gemma3nForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_START_TOKEN_ID = hf_config.boi_token_id
         self.IM_END_TOKEN_ID = hf_config.eoi_token_id
...
@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class InternVLImageProcessor(BaseMultimodalProcessor):
     models = [InternVLChatModel]

-    def __init__(self, hf_config, server_args, _image_processor):
-        super().__init__(hf_config, server_args, _image_processor)
+    def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
         image_size = hf_config.force_image_size or hf_config.vision_config.image_size
         patch_size = hf_config.vision_config.patch_size
...
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
...
@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
 class KimiVLImageProcessor(SGLangBaseProcessor):
     models = [KimiVLForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="<|media_pad|>",
             # TODO: could we convert in MultimodalSpecialTokens?
...
@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         LlavaMistralForCausalLM,
     ]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)

     @staticmethod
     def _process_single_image_task(
@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
             f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
         )

-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         assert hasattr(hf_config, "vision_config")
         assert hasattr(hf_config, "text_config")
         self.vision_config = hf_config.vision_config
@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
         if vision_type := getattr(self.vision_config, "model_type"):
             self.inner = self._get_sgl_processor_cls(vision_type)(
-                hf_config, server_args, _processor
+                hf_config, server_args, _processor, *args, **kwargs
             )
         else:
             raise ValueError(
...
@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     models = [MiniCPMV, MiniCPMO]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # Collect special token ids
         tokenizer = self._processor.tokenizer
         self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.im_start_id = getattr(tokenizer, "im_start_id", None)
         self.im_end_id = getattr(tokenizer, "im_end_id", None)
         self.im_token_id = getattr(tokenizer, "unk_id", None)
-
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="(<image>./</image>)",
             audio_token="(<audio>./</audio>)",
...
@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class MllamaImageProcessor(BaseMultimodalProcessor):
     models = [MllamaForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.mm_tokens = MultimodalSpecialTokens(
             image_token=self._processor.image_token,
             image_token_id=self._processor.image_token_id,
...
@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
     models = [Llama4ForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
         self.boi_token_index = hf_config.boi_token_index
...
@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
 class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
     models = [Phi4MMForCausalLM]

-    def __init__(self, hf_config, server_args, _processor):
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         self.processor = Phi4MMProcessorAdapter(_processor)
-        super().__init__(hf_config, server_args, self.processor)
+        super().__init__(hf_config, server_args, self.processor, *args, **kwargs)

     # the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
     # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
...
@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
         return ncols, nrows

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.IM_TOKEN_ID = getattr(
             hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
         )
...
@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
 class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
     models = [Qwen2AudioForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
         self.AUDIO_TOKEN_REGEX = re.compile(
             r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
...
@@ -201,8 +201,8 @@ async def preprocess_video(
 class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]

-    def __init__(self, hf_config, server_args, _processor):
-        super().__init__(hf_config, server_args, _processor)
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
         # The regex that matches expanded image tokens.
         self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
...