"""
Multi-modality utils
"""
import hashlib
import pickle
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger
_is_npu = is_npu()
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
# to ensure consistent logging behavior across the codebase. This prevents issues with log
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.
# TODO(mick): nccl
# cuda_ipc: for intranode tensor sharing
TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
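# Assumed semantics, inferred from the (de)serialization logic below: "cuda_ipc"
# ships a CUDA IPC handle so producer and consumer processes on the same node
# share device memory; any other mode (including "auto") currently falls through
# to plain tensor pickling in __getstate__, i.e. behaves like "default".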
class TransportProxyTensor(torch.Tensor):
"""
A convenient torch.Tensor subclass that carries extra metadata and supports
efficient inter-process communications
"""
@staticmethod
def __new__(
cls,
data: torch.Tensor,
name: Optional[str] = None,
fields: Optional[Dict[str, Any]] = None,
transport_mode: TensorTransportMode = "default",
*args,
**kwargs,
):
if not isinstance(data, torch.Tensor):
raise TypeError(
f"Input 'data' must be a torch.Tensor, but got {type(data)}"
)
instance = data.as_subclass(cls)
instance._metadata = {
"name": name,
"fields": fields if fields is not None else {},
"transport_mode": transport_mode,
}
return instance
def __getstate__(self):
"""
Called during pickling. Implements the serialization logic.
"""
        # Collect all serializable metadata from _metadata.
state = {
"metadata": self._metadata,
"tensor_data": None,
"ipc_extra": None,
}
transport_mode = self._metadata.get("transport_mode", "default")
if transport_mode == "cuda_ipc" and self.is_cuda:
try:
storage = self.untyped_storage()
handle = storage._share_cuda_()
state["ipc_extra"] = {
"handle": handle,
"shape": self.shape,
"dtype": self.dtype,
"stride": self.stride(),
"device_index": self.device.index,
}
state["tensor_data"] = None
except Exception as e:
# Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
state["metadata"]["transport_mode"] = "default"
state["tensor_data"] = self.as_subclass(torch.Tensor)
else:
state["metadata"]["transport_mode"] = "default"
state["tensor_data"] = self.as_subclass(torch.Tensor)
return state
def __setstate__(self, state: Dict[str, Any]):
"""
Called during unpickling. Implements the deserialization logic.
"""
self._metadata = state["metadata"]
transport_mode = self._metadata.get("transport_mode", "default")
if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
ipc_extra = state["ipc_extra"]
handle, shape, dtype, stride, source_device_index = (
ipc_extra["handle"],
ipc_extra["shape"],
ipc_extra["dtype"],
ipc_extra["stride"],
ipc_extra["device_index"],
)
try:
target_device = torch.device(f"cuda:{source_device_index}")
with torch.cuda.device(target_device):
storage = torch.UntypedStorage._new_shared_cuda(*handle)
reconstructed_tensor = torch.empty(
0, dtype=dtype, device=target_device
).set_(storage, storage_offset=0, size=shape, stride=stride)
self.set_(reconstructed_tensor)
            except Exception as e:
                logger.error(f"Failed to deserialize from CUDA IPC handle ({e}).")
                raise
elif state["tensor_data"] is not None:
self.set_(state["tensor_data"])
else:
raise pickle.UnpicklingError(
"Invalid state for TransportProxyTensor: no tensor data found."
)
@property
def name(self) -> Optional[str]:
return self._metadata.get("name")
@property
def fields(self) -> Dict[str, Any]:
return self._metadata.get("fields", {})
@property
def transport_mode(self) -> TensorTransportMode:
return self._metadata.get("transport_mode", "default")
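
# Usage sketch (illustrative only; `mp_queue` and the surrounding process setup
# are assumptions, not part of this module). A producer wraps a CUDA tensor and
# ships it through a multiprocessing queue; the consumer receives a tensor
# backed by the producer's device memory via CUDA IPC:
#
#     features = TransportProxyTensor(
#         data=vision_output,            # CUDA tensor, e.g. from a vision encoder
#         name="image_features",
#         fields={"modality": "image"},
#         transport_mode="cuda_ipc",
#     )
#     mp_queue.put(features)     # pickled via __getstate__ (sends an IPC handle)
#     received = mp_queue.get()  # rebuilt via __setstate__ on the consumer side
#     assert received.name == "image_features"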
class MultiModalityDataPaddingPattern:
"""
Data tokens (like image tokens) often need special handling during padding
to maintain model compatibility. This class provides the interface for
implementing different padding strategies for data tokens
"""
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
"""
pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs)
The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
"""
def __init__(
self,
data_token_pairs: Optional[List[Tuple[int, int]]],
data_start_token_ids: Optional[List[int]] = None,
) -> None:
"""
Args:
data_start_token_ids marks the start of a single multimodal data
See Minicpmo's slice_start_id for example
"""
self.data_token_id_pairs = data_token_pairs
self.data_start_token_ids = data_start_token_ids or [
s for s, _e in data_token_pairs
]
def pad_input_tokens(
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will replace the data-tokens in between with pad_values accordingly
"""
pad_values = [item.pad_value for item in mm_inputs.mm_items]
data_token_pairs = self.data_token_id_pairs
mm_inputs.data_offsets = []
        if data_token_pairs is None:
            if mm_inputs.im_start_id is not None and mm_inputs.im_end_id is not None:
                data_token_pairs = [(mm_inputs.im_start_id, mm_inputs.im_end_id)]
            else:
                print_warning_once(
                    "No data_token_pairs provided; RadixAttention may be affected."
                )
                return input_ids
        start_token_ids = {s for s, _e in data_token_pairs}
        end_token_ids = {e for _s, e in data_token_pairs}

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_token_ids]

        if len(start_indices) != len(end_indices):
            print_warning_once(
                "Mismatched multimodal start/end tokens; skipping data-token padding."
            )
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            # Copy everything up to and including the start token unchanged.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] in self.data_start_token_ids:
                data_idx += 1
                mm_inputs.data_offsets += [start_idx]

            # Clamp to the last pad value if there are more regions than items.
            if data_idx >= len(pad_values):
                data_idx = len(pad_values) - 1

            # Replace every token strictly between the pair with the item's pad value.
            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), (
            f"Length validation failed: {len(padded_ids)} != {len(input_ids)}"
        )
        return padded_ids
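
# Worked example (hypothetical token ids): with data_token_pairs=[(100, 101)]
# and a single mm_item whose pad_value is 7, pad_input_tokens turns
#     [5, 100, 42, 43, 44, 101, 6]
# into
#     [5, 100, 7, 7, 7, 101, 6]
# The start/end markers are kept; only the enclosed data tokens are replaced.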
class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
"""In this pattern, data tokens should be represented as repetitions of a single token
e.g. ...., or