Unverified Commit e07a6977 authored by Lianmin Zheng, committed by GitHub

Minor improvements of TokenizerManager / health check (#6327)

parent cd8d4b9d
@@ -4,14 +4,16 @@ on:
   push:
     branches: [ main ]
     paths:
-      - "python/sglang/**"
+      - "python/**"
+      - "scripts/**"
       - "test/**"
       - "sgl-kernel/**"
       - ".github/workflows/pr-test-amd.yml"
   pull_request:
     branches: [ main ]
     paths:
-      - "python/sglang/**"
+      - "python/**"
+      - "scripts/**"
       - "test/**"
       - "sgl-kernel/**"
       - ".github/workflows/pr-test-amd.yml"
......
@@ -96,12 +96,14 @@ anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
 torch_memory_saver = ["torch_memory_saver>=0.0.4"]
 test = [
+    "accelerate",
+    "torchaudio",
     "jsonlines",
     "matplotlib",
     "pandas",
-    "sentence_transformers",
-    "accelerate",
     "peft",
+    "timm",
+    "sentence_transformers",
 ]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]", "sglang[torch_memory_saver]"]
 all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
......
@@ -13,6 +13,8 @@ import torch.distributed as dist
 from sglang.srt.utils import get_ip

+FakeBootstrapHost = "2.2.2.2"
+

 class DisaggregationMode(Enum):
     NULL = "null"
@@ -20,9 +22,6 @@ class DisaggregationMode(Enum):
     DECODE = "decode"

-FakeBootstrapHost = "2.2.2.2"
-
-
 def poll_and_all_reduce(pollers, gloo_group):
     polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
......
@@ -189,6 +189,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+            _global_state.tokenizer_manager.health_check_failed = False
             return Response(status_code=200)

     task.cancel()
@@ -202,6 +203,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+    _global_state.tokenizer_manager.health_check_failed = True
     return Response(status_code=503)
......
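The new flag ties the HTTP health endpoint to shutdown behavior: a failed /health_generate marks the tokenizer manager as unhealthy, so a later SIGTERM exits immediately instead of draining requests. A minimal probe sketch (host, port, and the requests dependency are assumptions for illustration, not part of this change):

import requests

# Probe the generation health endpoint of a locally running server.
resp = requests.get("http://127.0.0.1:30000/health_generate", timeout=60)
if resp.status_code == 200:
    print("scheduler responded within the health-check timeout")
else:  # 503: no heartbeat observed; health_check_failed is now set on the server
    print(f"health check failed with status {resp.status_code}")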
@@ -19,7 +19,6 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union

-import transformers
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
......
@@ -129,7 +129,6 @@ from sglang.srt.utils import (
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
-    crash_on_warnings,
     disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
......
@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -90,6 +91,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SetInternalStateReq,
+    SetInternalStateReqOutput,
     SlowDownReqInput,
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
@@ -169,6 +172,11 @@ class TokenizerManager:
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
+        self.preferred_sampling_params = (
+            json.loads(server_args.preferred_sampling_params)
+            if server_args.preferred_sampling_params
+            else None
+        )

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -228,6 +236,7 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
@@ -255,6 +264,10 @@ class TokenizerManager:
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
                 },
+                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
+                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
+                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
+                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
             )

         # Communicators
@@ -285,9 +298,13 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.set_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -349,6 +366,10 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    SetInternalStateReqOutput,
+                    self.set_internal_state_communicator.handle_recv,
+                ),
                 (
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
@@ -508,7 +529,14 @@ class TokenizerManager:
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )

-        sampling_params = SamplingParams(**obj.sampling_params)
+        # Parse sampling parameters
+        # Note: if there are preferred sampling params, we use them if they are not
+        # explicitly passed in sampling_params
+        if self.preferred_sampling_params:
+            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
+        else:
+            sampling_kwargs = obj.sampling_params
+        sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
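The merge above gives per-request parameters precedence over the server-level defaults supplied via --preferred-sampling-params. A minimal standalone sketch of that precedence (the literal values are illustrative, not defaults shipped by the server):

import json

# Server-level defaults, as parsed once from --preferred-sampling-params.
preferred = json.loads('{"temperature": 0.7, "max_new_tokens": 256}')
# Parameters the client passed explicitly in the request.
request_params = {"temperature": 0.0}

# Later keys win: request values override, preferred values only fill gaps.
sampling_kwargs = {**preferred, **request_params} if preferred else request_params
assert sampling_kwargs == {"temperature": 0.0, "max_new_tokens": 256}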
@@ -667,7 +695,6 @@ class TokenizerManager:
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-
             if self.server_args.enable_tokenizer_batch_encode:
                 # Validate batch tokenization constraints
@@ -857,7 +884,7 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+        ), "dp_size must be 1 for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
@@ -946,6 +973,14 @@ class TokenizerManager:
             # Many DP ranks
             return [res.internal_state for res in responses]

+    async def set_internal_state(
+        self, obj: SetInternalStateReq
+    ) -> SetInternalStateReqOutput:
+        responses: List[SetInternalStateReqOutput] = (
+            await self.set_internal_state_communicator(obj)
+        )
+        return [res.internal_state for res in responses]
+
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1015,11 +1050,17 @@ class TokenizerManager:
             loop.create_task(print_exception_wrapper(self.handle_loop))
         )

+        self.event_loop = loop
+
         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
         if threading.current_thread() is threading.main_thread():
             signal_handler = SignalHandler(self)
-            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
+            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
+            loop.add_signal_handler(
+                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
+            )
         else:
             logger.warning(
                 "Signal handler is not added because the tokenizer manager is "
@@ -1037,6 +1078,15 @@ class TokenizerManager:
         # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
+
+            if self.health_check_failed:
+                # if health check failed, we should exit immediately
+                logger.error(
+                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
@@ -1120,7 +1170,16 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchMultimodalOut):
-                raise NotImplementedError()
+                if isinstance(recv_obj.outputs[i], str):
+                    out_dict = {
+                        "text": recv_obj.outputs[i],
+                        "meta_info": meta_info,
+                    }
+                else:
+                    out_dict = {
+                        "outputs": json.dumps(recv_obj.outputs[i]),
+                        "meta_info": meta_info,
+                    }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -1366,12 +1425,18 @@ class SignalHandler:
     def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

-    def signal_handler(self, signum=None, frame=None):
+    def sigterm_handler(self, signum=None, frame=None):
         logger.warning(
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True

+    def running_phase_sigquit_handler(self, signum=None, frame=None):
+        logger.error(
+            "Received sigquit from a child process. It usually means the child failed."
+        )
+        kill_process_tree(os.getpid())
+

 T = TypeVar("T")
......
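For context, the running-phase SIGQUIT handler above is the parent-side half of a crash-propagation pattern: a child process that hits a fatal error signals its parent, and the parent tears down the whole process tree rather than hanging on a dead subprocess. A rough, hypothetical child-side sketch (the helper name is made up for illustration and is not part of this commit):

import os
import signal

def notify_parent_of_fatal_error():
    # Hypothetical helper: a failed child (e.g., a scheduler subprocess) sends
    # SIGQUIT to its parent; the tokenizer manager's running_phase_sigquit_handler
    # then logs the failure and calls kill_process_tree(os.getpid()).
    os.kill(os.getppid(), signal.SIGQUIT)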
@@ -46,7 +46,6 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
-    enable_tokenizer_batch_encode: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = False
     dtype: str = "auto"
@@ -59,6 +58,7 @@ class ServerArgs:
     chat_template: Optional[str] = None
     completion_template: Optional[str] = None
     is_embedding: bool = False
+    enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None

     # Port for the HTTP server
@@ -97,6 +97,10 @@ class ServerArgs:
     log_requests_level: int = 0
     show_time_cost: bool = False
     enable_metrics: bool = False
+    bucket_time_to_first_token: Optional[List[float]] = None
+    bucket_e2e_request_latency: Optional[List[float]] = None
+    bucket_inter_token_latency: Optional[List[float]] = None
+    collect_tokens_histogram: bool = False
     decode_log_interval: int = 40
     enable_request_time_stats_logging: bool = False
@@ -120,6 +124,7 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
+    preferred_sampling_params: Optional[str] = None

     # LoRA
     lora_paths: Optional[List[str]] = None
@@ -154,9 +159,9 @@ class ServerArgs:
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     enable_nccl_nvls: bool = False
+    enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
-    enable_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -474,11 +479,6 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
-        parser.add_argument(
-            "--enable-tokenizer-batch-encode",
-            action="store_true",
-            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
-        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -603,6 +603,12 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--enable-multimodal",
+            default=ServerArgs.enable_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
+        )
         parser.add_argument(
             "--revision",
             type=str,
@@ -780,6 +786,33 @@ class ServerArgs:
             action="store_true",
             help="Enable log prometheus metrics.",
         )
+        parser.add_argument(
+            "--bucket-time-to-first-token",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_time_to_first_token,
+            help="The buckets of time to first token, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--bucket-inter-token-latency",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_inter_token_latency,
+            help="The buckets of inter-token latency, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--bucket-e2e-request-latency",
+            type=float,
+            nargs="+",
+            default=ServerArgs.bucket_e2e_request_latency,
+            help="The buckets of end-to-end request latency, specified as a list of floats.",
+        )
+        parser.add_argument(
+            "--collect-tokens-histogram",
+            action="store_true",
+            default=ServerArgs.collect_tokens_histogram,
+            help="Collect prompt/generation tokens histogram.",
+        )
         parser.add_argument(
             "--decode-log-interval",
             type=int,
@@ -868,6 +901,11 @@ class ServerArgs:
             help="A dictionary in JSON string format used to override default model configurations.",
             default=ServerArgs.json_model_override_args,
         )
+        parser.add_argument(
+            "--preferred-sampling-params",
+            type=str,
+            help="json-formatted sampling settings that will be returned in /get_model_info",
+        )

         # LoRA
         parser.add_argument(
@@ -1043,6 +1081,11 @@ class ServerArgs:
             action="store_true",
             help="Enable NCCL NVLS for prefill heavy requests when available.",
         )
+        parser.add_argument(
+            "--enable-tokenizer-batch-encode",
+            action="store_true",
+            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
+        )
         parser.add_argument(
             "--disable-outlines-disk-cache",
             action="store_true",
@@ -1053,12 +1096,6 @@ class ServerArgs:
             action="store_true",
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
-        parser.add_argument(
-            "--enable-multimodal",
-            default=ServerArgs.enable_multimodal,
-            action="store_true",
-            help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
-        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
......
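The new observability and sampling options above are plain CLI flags on the launch command; a hedged launch sketch combining them (the model path and bucket boundaries are illustrative values, not defaults):

import subprocess

# Launch the server with Prometheus metrics plus custom histogram buckets
# (each bucket flag takes a list of floats, nargs="+") and server-level
# preferred sampling defaults.
cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--enable-metrics",
    "--bucket-time-to-first-token", "0.1", "0.25", "0.5", "1.0", "2.0", "5.0",
    "--bucket-inter-token-latency", "0.002", "0.005", "0.01", "0.02", "0.05",
    "--bucket-e2e-request-latency", "0.5", "1.0", "2.0", "5.0", "10.0",
    "--collect-tokens-histogram",
    "--preferred-sampling-params", '{"temperature": 0.7, "max_new_tokens": 256}',
]
subprocess.run(cmd, check=True)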
@@ -2,6 +2,7 @@
 # Install the dependency in CI.
 set -euxo pipefail

+# Kill existing processes
 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 bash "${SCRIPT_DIR}/killall_sglang.sh"
@@ -16,13 +17,10 @@ rm -rf /usr/local/lib/python3.10/dist-packages/flashinfer*
 rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*

 # Install the main package
-pip install -e "python[all]"
+pip install -e "python[dev]"

 # Install additional dependencies
-pip install transformers==4.51.0 timm torchaudio==2.6.0 sentence_transformers accelerate peft pandas datasets mooncake-transfer-engine==0.3.0
-
-# For compiling xgrammar kernels
-pip install cuda-python nvidia-cuda-nvrtc-cu12
+pip install mooncake-transfer-engine==0.3.0 nvidia-cuda-nvrtc-cu12

 # For lmms_evals evaluating MMMU
 git clone --branch v0.3.3 --depth 1 https://github.com/EvolvingLMMs-Lab/lmms-eval.git
......