Unverified Commit 071a1f51 authored by Lianmin Zheng, committed by GitHub

[Minor] clean up multimodal processor and tokenizer manager (#7624)

parent 7c0db3a6
......@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
obj = GenerateReqInput(
input_embeds=input_embeds,
sampling_params={
"repetition_penalty": 1.2,
"temperature": 0.2,
"temperature": 0.0,
"max_new_tokens": 512,
},
)
......@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
return _create_error_response(e)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)
@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
"""Flush the radix cache."""
......@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
)
@app.api_route(
"/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
)
async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
"""Endpoint for reranking documents based on query relevance."""
return await raw_request.app.state.openai_serving_rerank.handle_request(
request, raw_request
)
def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
......
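For context, the /v1/rerank route shown above was only moved to sit after /v1/score; its behavior is unchanged. A minimal client-side sketch of calling it, assuming V1RerankReqInput carries a query string and a list of candidate documents (field names are assumptions, not confirmed by this diff):

import requests

# Hypothetical payload shape; adjust field names to the actual V1RerankReqInput schema.
payload = {
    "query": "What is the capital of France?",
    "documents": [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
    ],
}

# Both POST and PUT are registered for this route.
resp = requests.post("http://localhost:30000/v1/rerank", json=payload)
resp.raise_for_status()
print(resp.json())  # expected: one relevance score per document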
......@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
# handle serialization of Image for pydantic
# Handle serialization of Image for pydantic
if TYPE_CHECKING:
from PIL.Image import Image
else:
Image = Any
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@dataclass
class SessionParams:
......@@ -182,6 +181,7 @@ class GenerateReqInput:
# Determine parallel sample count
if self.sampling_params is None:
self.parallel_sample_num = 1
return
elif isinstance(self.sampling_params, dict):
self.parallel_sample_num = self.sampling_params.get("n", 1)
else: # isinstance(self.sampling_params, list):
......
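The early return added above only fires when no sampling_params are supplied; otherwise the parallel sample count is read from the "n" key. A hedged illustration (the import path and the text field are assumptions recalled from the sglang source, not shown in this hunk):

from sglang.srt.managers.io_struct import GenerateReqInput

# No sampling_params: normalize_batch_and_arguments() sets parallel_sample_num = 1 and returns early.
single = GenerateReqInput(text="Hello world")

# A dict with "n": parallel_sample_num is taken from it (here, 4 samples per prompt).
multi = GenerateReqInput(text="Hello world", sampling_params={"n": 4, "temperature": 0.7})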
......@@ -25,7 +25,6 @@ def get_dummy_processor():
return DummyMultimodalProcessor()
@lru_cache()
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
......
......@@ -180,46 +180,48 @@ class Modality(Enum):
@dataclasses.dataclass
class MultimodalDataItem:
"""
A single multimodal data, from a single image/video/audio or others
A single multimodal data, from a single image/video/audio or others.
We put the common fields first and the model-specific fields last.
"""
modality: Modality
hash: int = None
pad_value: int = None
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
# For qwen-vl
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# For deepseek-vl
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# For minicpmv
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# kimi-vl related
image_grid_hws: Optional[List[torch.Tensor]] = None
# For mllama
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
# For kimi-vl
image_grid_hws: Optional[List[torch.Tensor]] = None
# gemma3n related
# For gemma3n
input_features: Optional[torch.Tensor] = None
input_features_mask: Optional[torch.Tensor] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod
def is_empty_list(l):
if l is None:
......@@ -339,10 +341,6 @@ class MultimodalInputs:
image_pad_len: Optional[list] = None
num_image_tokens: Optional[int] = None
# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None
# image
im_token_id: Optional[int] = None
im_start_id: Optional[int] = None
......@@ -358,6 +356,10 @@ class MultimodalInputs:
audio_start_id: Optional[int] = None
audio_end_id: Optional[int] = None
# QWen2-VL related
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[torch.Tensor] = None
@staticmethod
def from_dict(obj: dict):
ret = MultimodalInputs(
......
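After the reordering, MultimodalDataItem keeps the shared payload fields (pixel_values or audio_features, hash, pad_value) up front and groups the model-specific extras by model. A minimal construction sketch, assuming the Modality enum exposes an IMAGE member (not shown in this hunk):

import torch

# Hypothetical single-image item, roughly as a multimodal processor would build it.
pixel_values = torch.randn(3, 336, 336)  # dummy preprocessed image tensor
item = MultimodalDataItem(
    modality=Modality.IMAGE,      # assumed enum member
    pixel_values=pixel_values,    # common field shared by all vision models
    image_sizes=(336, 336),
)
# Model-specific fields (e.g. image_grid_thw for Qwen-VL) stay None unless that model needs them.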
......@@ -150,7 +150,9 @@ class ReqState:
# For streaming output
last_output_offset: int = 0
# For incremental state update.
# TODO(lianmin): do not initialize some lists if not needed.
text: str = ""
output_ids: List[int] = dataclasses.field(default_factory=list)
input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
......@@ -199,7 +201,6 @@ class TokenizerManager:
self.model_path = server_args.model_path
self.served_model_name = server_args.served_model_name
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len
......@@ -251,19 +252,36 @@ class TokenizerManager:
self.dump_requests_threshold = 1000
self.dump_request_list: List[Tuple] = []
self.log_request_metadata = self.get_log_request_metadata()
self.asyncio_tasks = set()
self.session_futures = {} # session_id -> asyncio event
self.max_req_input_len = None
# The event to notify the weight sync is finished.
self.model_update_lock = RWLock()
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self.asyncio_tasks = set()
# For session info
self.session_futures = {} # session_id -> asyncio event
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Only start the bootstrap server on the prefill tokenizer manager
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
# Set after scheduler is initialized
self.max_req_input_len = None
# For load balancing
self.current_load = 0
self.current_load_lock = asyncio.Lock()
# Metrics
if self.enable_metrics:
......@@ -393,34 +411,14 @@ class TokenizerManager:
]
)
# For pd disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# Start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Only start the bootstrap server on the prefill tokenizer manager
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)
self.current_load = 0
self.current_load_lock = asyncio.Lock()
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
request: Optional[fastapi.Request] = None,
):
created_time = time.time()
self.auto_create_handle_loop()
obj.normalize_batch_and_arguments()
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
raise ValueError(
......@@ -428,22 +426,6 @@ class TokenizerManager:
"Please add `--is-embedding` when launching the server or try another model."
)
obj.normalize_batch_and_arguments()
if isinstance(obj, GenerateReqInput):
return_hidden_states = obj.return_hidden_states
has_return_hidden_states = return_hidden_states == True or (
isinstance(return_hidden_states, list) and any(return_hidden_states)
)
if (
not self.server_args.enable_return_hidden_states
and has_return_hidden_states
):
raise ValueError(
"return_hidden_states=True requires the server to be started "
"with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
)
if self.log_requests:
max_length, skip_names, _ = self.log_request_metadata
logger.info(
......@@ -451,8 +433,7 @@ class TokenizerManager:
)
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
if is_single:
if obj.is_single:
tokenized_obj = await self._tokenize_one_request(obj)
state = self._send_one_request(obj, tokenized_obj, created_time)
async for response in self._wait_one_response(obj, state, request):
......@@ -514,12 +495,12 @@ class TokenizerManager:
else:
image_inputs: Optional[Dict] = None
self._validate_token_len(obj, input_ids)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
)
def _validate_token_len(
def _validate_one_request(
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
) -> None:
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
......@@ -548,6 +529,24 @@ class TokenizerManager:
)
raise ValueError(error_msg)
if isinstance(obj, GenerateReqInput):
if (
obj.return_hidden_states
and not self.server_args.enable_return_hidden_states
):
raise ValueError(
"The server is not configured to return the hidden states. "
"Please set `--enable-return-hidden-states` to enable this feature."
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -558,24 +557,6 @@ class TokenizerManager:
token_type_ids: Optional[List[int]] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
if self.is_generation:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
token_ids_logprob = obj.token_ids_logprob
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
if (
obj.custom_logit_processor
and not self.server_args.enable_custom_logit_processor
):
raise ValueError(
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
# Parse sampling parameters
# Note: if there are preferred sampling params, we use them if they are not
# explicitly passed in sampling_params
......@@ -589,16 +570,20 @@ class TokenizerManager:
# Build return object
if isinstance(obj, GenerateReqInput):
session_params = (
SessionParams(**obj.session_params) if obj.session_params else None
)
tokenized_obj = TokenizedGenerateReqInput(
obj.rid,
input_text,
input_ids,
image_inputs,
sampling_params,
return_logprob,
logprob_start_len,
top_logprobs_num,
token_ids_logprob,
obj.return_logprob,
obj.logprob_start_len,
obj.top_logprobs_num,
obj.token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_port=obj.bootstrap_port,
......
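With the feature checks folded into _validate_one_request, a generate call that asks for hidden states on a server launched without --enable-return-hidden-states is now rejected during validation. A hedged client-side sketch against the HTTP /generate route (payload keys assumed from GenerateReqInput, not from this diff):

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 8},
    "return_hidden_states": True,  # rejected unless the server enables this feature
}
resp = requests.post("http://localhost:30000/generate", json=payload)
if resp.status_code != 200:
    # The response body carries the ValueError message raised in _validate_one_request.
    print(resp.text)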
......@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
......
......@@ -10,7 +10,6 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.utils import merge_bias_tensor
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
......@@ -345,3 +344,42 @@ class SamplingBatchInfo:
self.logit_bias = merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
)
def merge_bias_tensor(
lhs: Optional[torch.Tensor],
rhs: Optional[torch.Tensor],
bs1: int,
bs2: int,
device: str,
default: float,
):
"""Merge two bias tensors for batch merging.
Args:
lhs: Left-hand side tensor
rhs: Right-hand side tensor
bs1: Batch size of left-hand side tensor
bs2: Batch size of right-hand side tensor
device: Device to place the merged tensor on
default: Default value for missing tensor elements
Returns:
Merged tensor or None if both inputs are None
"""
if lhs is None and rhs is None:
return None
if lhs is not None and rhs is not None:
return torch.cat([lhs, rhs])
else:
if lhs is not None:
shape, dtype = lhs.shape[1:], lhs.dtype
else:
shape, dtype = rhs.shape[1:], rhs.dtype
if lhs is None:
lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
return torch.cat([lhs, rhs])
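Since merge_bias_tensor now lives next to its only caller in SamplingBatchInfo, here is a small standalone usage sketch of the padding behavior when one side has no bias (import path assumed from the file this hunk patches; values are illustrative):

import torch

from sglang.srt.sampling.sampling_batch_info import merge_bias_tensor

lhs = torch.tensor([[0.5, -1.0], [0.0, 2.0]])  # logit bias for a batch of 2 requests
merged = merge_bias_tensor(lhs, None, bs1=2, bs2=3, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([5, 2]); the 3 missing rows are filled with the default 0.0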
......@@ -4,7 +4,7 @@ from typing import List
import torch
from sglang.srt.utils import is_cuda, is_hip, rank0_print
from sglang.srt.utils import is_cuda, is_hip, rank0_log
if is_cuda() or is_hip():
from sgl_kernel import (
......@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
num_verify_tokens=num_draft_token,
)
rank0_print("=========== build tree kernel efficient ==========")
# rank0_print(f"{tree_mask=}", flush=True)
rank0_print(f"{position=}", flush=True)
rank0_print(f"{retrive_index=}", flush=True)
rank0_print(f"{retrive_next_token=}", flush=True)
rank0_print(f"{retrive_next_sibling=}", flush=True)
rank0_print(f"{draft_tokens=}", flush=True)
rank0_log("=========== build tree kernel efficient ==========")
# rank0_log(f"{tree_mask=}")
rank0_log(f"{position=}")
rank0_log(f"{retrive_index=}")
rank0_log(f"{retrive_next_token=}")
rank0_log(f"{retrive_next_sibling=}")
rank0_log(f"{draft_tokens=}")
assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
assert retrive_index.tolist() == [
[0, 1, 2, 3, 4, 5, 6, 7],
......
......@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
return port, host
def rank0_print(msg: str):
def rank0_log(msg: str):
from sglang.srt.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() == 0:
print(msg, flush=True)
rank0_log = rank0_print
logger.info(msg)
def get_cuda_version():
......@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
return server_args.enable_dp_attention or require_gathered_buffer(server_args)
def merge_bias_tensor(
lhs: Optional[torch.Tensor],
rhs: Optional[torch.Tensor],
bs1: int,
bs2: int,
device: str,
default: float,
):
"""Merge two bias tensors for batch merging.
Args:
lhs: Left-hand side tensor
rhs: Right-hand side tensor
bs1: Batch size of left-hand side tensor
bs2: Batch size of right-hand side tensor
device: Device to place the merged tensor on
default: Default value for missing tensor elements
Returns:
Merged tensor or None if both inputs are None
"""
if lhs is None and rhs is None:
return None
if lhs is not None and rhs is not None:
return torch.cat([lhs, rhs])
else:
if lhs is not None:
shape, dtype = lhs.shape[1:], lhs.dtype
else:
shape, dtype = rhs.shape[1:], rhs.dtype
if lhs is None:
lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
if rhs is None:
rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
return torch.cat([lhs, rhs])
def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
import huggingface_hub as hf
......