""" Multi-modality utils """ import hashlib import pickle from abc import abstractmethod from typing import Any, Callable, Dict, List, Literal, Optional, Tuple import numpy as np import torch from torch import nn from sglang.srt.layers.multimodal import gpu_tensor_hash from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, global_server_args_dict, ) from sglang.srt.mem_cache.multimodal_cache import MultiModalCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once from sglang.utils import logger _is_npu = is_npu() # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger # to ensure consistent logging behavior across the codebase. This prevents issues with log # propagation that can cause some log messages (like 'server is fired up') to not appear # in the console when multimodal support is enabled. # TODO(mick): nccl # cuda_ipc: for intranode tensor sharing TensorTransportMode = Literal["cuda_ipc", "auto", "default"] class TransportProxyTensor(torch.Tensor): """ A convenient torch.Tensor subclass that carries extra metadata and supports efficient inter-process communications """ @staticmethod def __new__( cls, data: torch.Tensor, name: Optional[str] = None, fields: Optional[Dict[str, Any]] = None, transport_mode: TensorTransportMode = "default", *args, **kwargs, ): if not isinstance(data, torch.Tensor): raise TypeError( f"Input 'data' must be a torch.Tensor, but got {type(data)}" ) instance = data.as_subclass(cls) instance._metadata = { "name": name, "fields": fields if fields is not None else {}, "transport_mode": transport_mode, } return instance def __getstate__(self): """ Called during pickling. Implements the serialization logic. """ # acquire all serialize metadata from _metadata state = { "metadata": self._metadata, "tensor_data": None, "ipc_extra": None, } transport_mode = self._metadata.get("transport_mode", "default") if transport_mode == "cuda_ipc" and self.is_cuda: try: storage = self.untyped_storage() handle = storage._share_cuda_() state["ipc_extra"] = { "handle": handle, "shape": self.shape, "dtype": self.dtype, "stride": self.stride(), "device_index": self.device.index, } state["tensor_data"] = None except Exception as e: # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport. state["metadata"]["transport_mode"] = "default" state["tensor_data"] = self.as_subclass(torch.Tensor) else: state["metadata"]["transport_mode"] = "default" state["tensor_data"] = self.as_subclass(torch.Tensor) return state def __setstate__(self, state: Dict[str, Any]): """ Called during unpickling. Implements the deserialization logic. """ self._metadata = state["metadata"] transport_mode = self._metadata.get("transport_mode", "default") if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None: ipc_extra = state["ipc_extra"] handle, shape, dtype, stride, source_device_index = ( ipc_extra["handle"], ipc_extra["shape"], ipc_extra["dtype"], ipc_extra["stride"], ipc_extra["device_index"], ) try: target_device = torch.device(f"cuda:{source_device_index}") with torch.cuda.device(target_device): storage = torch.UntypedStorage._new_shared_cuda(*handle) reconstructed_tensor = torch.empty( 0, dtype=dtype, device=target_device ).set_(storage, storage_offset=0, size=shape, stride=stride) self.set_(reconstructed_tensor) except Exception as e: print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).") raise e elif state["tensor_data"] is not None: self.set_(state["tensor_data"]) else: raise pickle.UnpicklingError( "Invalid state for TransportProxyTensor: no tensor data found." ) @property def name(self) -> Optional[str]: return self._metadata.get("name") @property def fields(self) -> Dict[str, Any]: return self._metadata.get("fields", {}) @property def transport_mode(self) -> TensorTransportMode: return self._metadata.get("transport_mode", "default") class MultiModalityDataPaddingPattern: """ Data tokens (like image tokens) often need special handling during padding to maintain model compatibility. This class provides the interface for implementing different padding strategies for data tokens """ @abstractmethod def pad_input_tokens( self, input_ids: List[int], mm_inputs: MultimodalInputs ) -> List[int]: """ Pad the input ids sequence containing data tokens, and replace them with pad_values """ pass class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern): """In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs) The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value This strategy should be applied when data content is marked by start/end token pairs in the input sequence. """ def __init__( self, data_token_pairs: Optional[List[Tuple[int, int]]], data_start_token_ids: Optional[List[int]] = None, ) -> None: """ Args: data_start_token_ids marks the start of a single multimodal data See Minicpmo's slice_start_id for example """ self.data_token_id_pairs = data_token_pairs self.data_start_token_ids = data_start_token_ids or [ s for s, _e in data_token_pairs ] def pad_input_tokens( self, input_ids: List[int], mm_inputs: MultimodalInputs ) -> List[int]: """ This function will replace the data-tokens in between with pad_values accordingly """ pad_values = [item.pad_value for item in mm_inputs.mm_items] data_token_pairs = self.data_token_id_pairs mm_inputs.data_offsets = [] if data_token_pairs is None: data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id] if data_token_pairs is None: print_warning_once( "No data_token_pairs provided, RadixAttention might be influenced." ) return input_ids start_token_ids = {s for s, _e in data_token_pairs} end_tokens_ids = {e for _s, e in data_token_pairs} padded_ids = [] last_idx = 0 data_idx = -1 start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids] end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids] if len(start_indices) != len(end_indices): return input_ids for start_idx, end_idx in zip(start_indices, end_indices): padded_ids.extend(input_ids[last_idx : start_idx + 1]) if input_ids[start_idx] in self.data_start_token_ids: data_idx += 1 mm_inputs.data_offsets += [start_idx] if data_idx >= len(pad_values): data_idx = len(pad_values) - 1 num_tokens = end_idx - start_idx - 1 pad_value = pad_values[data_idx] padded_ids.extend([pad_value] * num_tokens) last_idx = end_idx padded_ids.extend(input_ids[last_idx:]) assert len(input_ids) == len(padded_ids), "Length validation fails" return padded_ids class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern): """In this pattern, data tokens should be represented as repetitions of a single token e.g. ...., or