Unverified Commit 3212c2ad authored by Mick's avatar Mick Committed by GitHub
Browse files

vlm: optimize tensor transport (#6003)


Co-authored-by: default avatarXinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 53475674
......@@ -3,8 +3,9 @@ Multi-modality utils
"""
import hashlib
import pickle
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import numpy as np
import torch
......@@ -27,6 +28,130 @@ from sglang.utils import logger
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.
# TODO(mick): nccl
# cuda_ipc: for intranode tensor sharing
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
class TransportProxyTensor(torch.Tensor):
"""
A convenient torch.Tensor subclass that carries extra metadata and supports
efficient inter-process communications
"""
@staticmethod
def __new__(
cls,
data: torch.Tensor,
name: Optional[str] = None,
fields: Optional[Dict[str, Any]] = None,
transport_mode: TensorTransportMode = "default",
*args,
**kwargs,
):
if not isinstance(data, torch.Tensor):
raise TypeError(
f"Input 'data' must be a torch.Tensor, but got {type(data)}"
)
instance = data.as_subclass(cls)
instance._metadata = {
"name": name,
"fields": fields if fields is not None else {},
"transport_mode": transport_mode,
}
return instance
def __getstate__(self):
"""
Called during pickling. Implements the serialization logic.
"""
# acquire all serialize metadata from _metadata
state = {
"metadata": self._metadata,
"tensor_data": None,
"ipc_extra": None,
}
transport_mode = self._metadata.get("transport_mode", "default")
if transport_mode == "cuda_ipc" and self.is_cuda:
try:
storage = self.untyped_storage()
handle = storage._share_cuda_()
state["ipc_extra"] = {
"handle": handle,
"shape": self.shape,
"dtype": self.dtype,
"stride": self.stride(),
"device_index": self.device.index,
}
state["tensor_data"] = None
except Exception as e:
print_warning_once(
f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
)
state["metadata"]["transport_mode"] = "default"
state["tensor_data"] = self.as_subclass(torch.Tensor)
else:
state["metadata"]["transport_mode"] = "default"
state["tensor_data"] = self.as_subclass(torch.Tensor)
return state
def __setstate__(self, state: Dict[str, Any]):
"""
Called during unpickling. Implements the deserialization logic.
"""
self._metadata = state["metadata"]
transport_mode = self._metadata.get("transport_mode", "default")
if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
ipc_extra = state["ipc_extra"]
handle, shape, dtype, stride, source_device_index = (
ipc_extra["handle"],
ipc_extra["shape"],
ipc_extra["dtype"],
ipc_extra["stride"],
ipc_extra["device_index"],
)
try:
target_device = torch.device(f"cuda:{source_device_index}")
with torch.cuda.device(target_device):
storage = torch.UntypedStorage._new_shared_cuda(*handle)
reconstructed_tensor = torch.empty(
0, dtype=dtype, device=target_device
).set_(storage, storage_offset=0, size=shape, stride=stride)
self.set_(reconstructed_tensor)
except Exception as e:
print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
raise e
elif state["tensor_data"] is not None:
self.set_(state["tensor_data"])
else:
raise pickle.UnpicklingError(
"Invalid state for TransportProxyTensor: no tensor data found."
)
@property
def name(self) -> Optional[str]:
return self._metadata.get("name")
@property
def fields(self) -> Dict[str, Any]:
return self._metadata.get("fields", {})
@property
def transport_mode(self) -> TensorTransportMode:
return self._metadata.get("transport_mode", "default")
class MultiModalityDataPaddingPattern:
"""
......
......@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
class DummyMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self):
pass
async def process_mm_data_async(self, *args, **kwargs):
return None
def get_dummy_processor():
return DummyMultimodalProcessor()
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
......@@ -49,11 +37,12 @@ def import_processors():
def get_mm_processor(
hf_config, server_args: ServerArgs, processor
hf_config, server_args: ServerArgs, processor, transport_mode
) -> BaseMultimodalProcessor:
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
return processor_cls(hf_config, server_args, processor, transport_mode)
raise ValueError(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
......
......@@ -209,10 +209,11 @@ class MultimodalDataItem:
hash: int = None
pad_value: int = None
offsets: Optional[list] = None
# the raw features returned by processor, e.g. pixel_values or audio_features
feature: Union[torch.Tensor, np.ndarray] = None
# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
# the precomputed embeddings, passed as final encoder embeddings
# One and only one of the feature and precomputed_embeddings will be empty
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
# Model-specific data stored in a dictionary
......
......@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -166,6 +167,16 @@ class ReqState:
output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
is_cross_node = server_args.dist_init_addr
if is_cross_node:
# Fallback to default CPU transport for multi-node
return "default"
else:
return "cuda_ipc"
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
......@@ -216,12 +227,13 @@ class TokenizerManager:
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
transport_mode = _determine_tensor_transport_mode(self.server_args)
# We want to parallelize the image pre-processing so we create an executor for it
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.mm_processor = get_mm_processor(
self.model_config.hf_config, server_args, _processor
self.model_config.hf_config, server_args, _processor, transport_mode
)
if server_args.skip_tokenizer_init:
......
......@@ -12,6 +12,7 @@ import torch
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.managers.mm_utils import TransportProxyTensor
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import load_audio, load_image, load_video, logger
......@@ -142,11 +143,14 @@ class MultimodalSpecialTokens:
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
def __init__(
self, hf_config, server_args, _processor, transport_mode, *args, **kwargs
):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
self.transport_mode = transport_mode
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
......@@ -217,10 +221,6 @@ class BaseMultimodalProcessor(ABC):
return_tensors="pt",
**kwargs,
)
if "pixel_values" in result and isinstance(
result["pixel_values"], torch.Tensor
):
result["pixel_values"] = result["pixel_values"].to("cpu")
return result
@abstractmethod
......@@ -500,7 +500,6 @@ class BaseMultimodalProcessor(ABC):
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items: dict[Modality, MultimodalDataItem] = {}
for attr_name, value in data_dict.items():
if attr_name == "input_ids":
continue
......@@ -624,4 +623,19 @@ class BaseMultimodalProcessor(ABC):
mm_token_id=mm_token_id,
)
# post-process
for item in all_collected_items:
# replace the feature tensor with a proxy
if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda:
item.feature = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.feature
)
elif (
isinstance(item.precomputed_embeddings, torch.Tensor)
and item.precomputed_embeddings.is_cuda
):
item.precomputed_embeddings = TransportProxyTensor(
transport_mode=self.transport_mode, data=item.precomputed_embeddings
)
return all_collected_items, input_ids, ret
......@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class ClipImageProcessor(BaseMultimodalProcessor):
models = [CLIPModel]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
......
......@@ -31,8 +31,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
......
......@@ -14,8 +14,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
self.mm_tokens = MultimodalSpecialTokens(
......
......@@ -27,8 +27,8 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
models = [Gemma3nForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
......
......@@ -16,8 +16,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class InternVLImageProcessor(BaseMultimodalProcessor):
models = [InternVLChatModel]
def __init__(self, hf_config, server_args, _image_processor):
super().__init__(hf_config, server_args, _image_processor)
def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
......
......@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
......
......@@ -12,8 +12,8 @@ from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTok
class KimiVLImageProcessor(SGLangBaseProcessor):
models = [KimiVLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<|media_pad|>",
# TODO: could we convert in MultimodalSpecialTokens?
......
......@@ -30,8 +30,8 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
LlavaMistralForCausalLM,
]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
@staticmethod
def _process_single_image_task(
......@@ -187,7 +187,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
)
def __init__(self, hf_config, server_args, _processor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
assert hasattr(hf_config, "vision_config")
assert hasattr(hf_config, "text_config")
self.vision_config = hf_config.vision_config
......@@ -196,7 +196,7 @@ class LlavaMultimodalProcessor(BaseMultimodalProcessor):
if vision_type := getattr(self.vision_config, "model_type"):
self.inner = self._get_sgl_processor_cls(vision_type)(
hf_config, server_args, _processor
hf_config, server_args, _processor, *args, **kwargs
)
else:
raise ValueError(
......
......@@ -15,8 +15,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# Collect special token ids
tokenizer = self._processor.tokenizer
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
......@@ -26,7 +26,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.im_start_id = getattr(tokenizer, "im_start_id", None)
self.im_end_id = getattr(tokenizer, "im_end_id", None)
self.im_token_id = getattr(tokenizer, "unk_id", None)
self.mm_tokens = MultimodalSpecialTokens(
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
......
......@@ -10,8 +10,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class MllamaImageProcessor(BaseMultimodalProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.image_token,
image_token_id=self._processor.image_token_id,
......
......@@ -18,8 +18,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class Mllama4ImageProcessor(BaseMultimodalProcessor):
models = [Llama4ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.vision_config = hf_config.vision_config
self.text_config = hf_config.text_config
self.boi_token_index = hf_config.boi_token_index
......
......@@ -47,9 +47,9 @@ class Phi4MMProcessorAdapter(ProcessorMixin):
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
models = [Phi4MMForCausalLM]
def __init__(self, hf_config, server_args, _processor):
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
self.processor = Phi4MMProcessorAdapter(_processor)
super().__init__(hf_config, server_args, self.processor)
super().__init__(hf_config, server_args, self.processor, *args, **kwargs)
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
......
......@@ -42,8 +42,8 @@ class PixtralProcessor(BaseMultimodalProcessor):
return ncols, nrows
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.IM_TOKEN_ID = getattr(
hf_config, "image_token_index", PixtralVisionModel.DEFAULT_IMAGE_TOKEN_ID
)
......
......@@ -11,8 +11,8 @@ from sglang.srt.multimodal.processors.base_processor import (
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
models = [Qwen2AudioForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
......
......@@ -201,8 +201,8 @@ async def preprocess_video(
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
super().__init__(hf_config, server_args, _processor, *args, **kwargs)
# The regex that matches expanded image tokens.
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment