Commit 127d4b0d (unverified), authored by Chanh Nguyen, committed by GitHub
Browse files

Support GC Freezing to improve latency & throughput (#9241)


Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
parent 7e880286
...@@ -536,6 +536,22 @@ class Engine(EngineBase): ...@@ -536,6 +536,22 @@ class Engine(EngineBase):
self.tokenizer_manager.resume_memory_occupation(obj, None) self.tokenizer_manager.resume_memory_occupation(obj, None)
) )
def freeze_gc(self):
    """
    Freeze the garbage collector via the tokenizer manager.

    Garbage-collector passes that scan a large number of objects stall the
    server and hurt latency. Many long-lived objects are only created once the
    server has started and handled some real warmup requests, and those objects
    never need to be collected. Calling this function after sufficient warmup
    freezes the garbage collector so that everything created before this point
    is treated as out of scope for garbage collection.

    This call blocks: it runs the tokenizer manager's `freeze_gc` coroutine to
    completion on the event loop returned by `asyncio.get_event_loop()`.
    """
    event_loop = asyncio.get_event_loop()
    pending_freeze = self.tokenizer_manager.freeze_gc()
    event_loop.run_until_complete(pending_freeze)
""" """
Execute an RPC call on all scheduler processes. Execute an RPC call on all scheduler processes.
""" """
......
...@@ -511,6 +511,18 @@ async def stop_profile_async(): ...@@ -511,6 +511,18 @@ async def stop_profile_async():
) )
@app.api_route("/freeze_gc", methods=["GET", "POST"])
async def freeze_gc_async():
    """
    HTTP endpoint that freezes the garbage collector.

    Awaits the global tokenizer manager's `freeze_gc` and then answers with a
    plain 200 response. See engine.freeze_gc for more details.
    """
    tokenizer_manager = _global_state.tokenizer_manager
    await tokenizer_manager.freeze_gc()
    return Response(content="Garbage collection frozen.\n", status_code=200)
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"]) @app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async(): async def start_expert_distribution_record_async():
"""Start recording the expert distribution. Clear the previous record if any.""" """Start recording the expert distribution. Clear the previous record if any."""
......
...@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import ( ...@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut, BatchMultimodalOut,
BatchStrOut, BatchStrOut,
BatchTokenIDOut, BatchTokenIDOut,
FreezeGCReq,
) )
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
configure_logger, configure_logger,
freeze_gc,
get_zmq_socket, get_zmq_socket,
kill_itself_when_parent_died, kill_itself_when_parent_died,
) )
...@@ -100,6 +102,7 @@ class DetokenizerManager: ...@@ -100,6 +102,7 @@ class DetokenizerManager:
(BatchEmbeddingOut, self.handle_batch_embedding_out), (BatchEmbeddingOut, self.handle_batch_embedding_out),
(BatchTokenIDOut, self.handle_batch_token_id_out), (BatchTokenIDOut, self.handle_batch_token_id_out),
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req), (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
(FreezeGCReq, self.handle_freeze_gc_req),
] ]
) )
...@@ -108,7 +111,8 @@ class DetokenizerManager: ...@@ -108,7 +111,8 @@ class DetokenizerManager:
while True: while True:
recv_obj = self.recv_from_scheduler.recv_pyobj() recv_obj = self.recv_from_scheduler.recv_pyobj()
output = self._request_dispatcher(recv_obj) output = self._request_dispatcher(recv_obj)
self.send_to_tokenizer.send_pyobj(output) if output is not None:
self.send_to_tokenizer.send_pyobj(output)
def trim_matched_stop( def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
...@@ -247,6 +251,10 @@ class DetokenizerManager: ...@@ -247,6 +251,10 @@ class DetokenizerManager:
cached_tokens=recv_obj.cached_tokens, cached_tokens=recv_obj.cached_tokens,
) )
def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
    """
    Freeze the garbage collector in the detokenizer manager.

    `recv_req` is only used for dispatch and is not read here. The handler
    yields None, so the event loop's `output is not None` check sends nothing
    back to the tokenizer for this request.
    """
    freeze_gc("Detokenizer Manager")
    # Falling off the end returns None, which the event loop treats as "no reply".
class LimitedCapacityDict(OrderedDict): class LimitedCapacityDict(OrderedDict):
def __init__(self, capacity: int, *args, **kwargs): def __init__(self, capacity: int, *args, **kwargs):
......
...@@ -1005,6 +1005,11 @@ class ProfileReqOutput: ...@@ -1005,6 +1005,11 @@ class ProfileReqOutput:
message: str message: str
@dataclass
class FreezeGCReq:
    """Field-less request asking the receiving process to freeze its garbage collector."""
@dataclass @dataclass
class ConfigureLoggingReq: class ConfigureLoggingReq:
log_requests: Optional[bool] = None log_requests: Optional[bool] = None
......
...@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
ExpertDistributionReqOutput, ExpertDistributionReqOutput,
FlushCacheReqInput, FlushCacheReqInput,
FlushCacheReqOutput, FlushCacheReqOutput,
FreezeGCReq,
GetInternalStateReq, GetInternalStateReq,
GetInternalStateReqOutput, GetInternalStateReqOutput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
...@@ -145,6 +146,7 @@ from sglang.srt.utils import ( ...@@ -145,6 +146,7 @@ from sglang.srt.utils import (
configure_gc_logger, configure_gc_logger,
configure_logger, configure_logger,
disable_request_logging, disable_request_logging,
freeze_gc,
get_available_gpu_memory, get_available_gpu_memory,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
...@@ -524,6 +526,7 @@ class Scheduler( ...@@ -524,6 +526,7 @@ class Scheduler(
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(SlowDownReqInput, self.slow_down), (SlowDownReqInput, self.slow_down),
(ProfileReq, self.profile), (ProfileReq, self.profile),
(FreezeGCReq, self.handle_freeze_gc),
(GetInternalStateReq, self.get_internal_state), (GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state), (SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request), (RpcReqInput, self.handle_rpc_request),
...@@ -2469,6 +2472,12 @@ class Scheduler( ...@@ -2469,6 +2472,12 @@ class Scheduler(
if self.idle_sleeper is not None: if self.idle_sleeper is not None:
self.idle_sleeper.maybe_sleep() self.idle_sleeper.maybe_sleep()
def handle_freeze_gc(self, recv_req: FreezeGCReq):
    """
    Handle a freeze_gc request in the scheduler.

    Freezes the scheduler's own garbage collector first, then relays the same
    request object to the detokenizer. Returns None.
    """
    freeze_gc("Scheduler")
    # Relay the request unchanged so the next stage receives the same message.
    detokenizer_socket = self.send_to_detokenizer
    detokenizer_socket.send_pyobj(recv_req)
    return None
class IdleSleeper: class IdleSleeper:
""" """
......
...@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
ExpertDistributionReqOutput, ExpertDistributionReqOutput,
FlushCacheReqInput, FlushCacheReqInput,
FlushCacheReqOutput, FlushCacheReqOutput,
FreezeGCReq,
GenerateReqInput, GenerateReqInput,
GetInternalStateReq, GetInternalStateReq,
GetInternalStateReqOutput, GetInternalStateReqOutput,
...@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector ...@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
configure_gc_warning,
dataclass_to_string_truncated, dataclass_to_string_truncated,
freeze_gc,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_process_tree, kill_process_tree,
...@@ -352,6 +355,10 @@ class TokenizerManager: ...@@ -352,6 +355,10 @@ class TokenizerManager:
collect_tokens_histogram=self.server_args.collect_tokens_histogram, collect_tokens_histogram=self.server_args.collect_tokens_histogram,
) )
# Configure GC warning
if self.server_args.gc_warning_threshold_secs > 0.0:
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
# Communicators # Communicators
self.init_weights_update_group_communicator = _Communicator( self.init_weights_update_group_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
...@@ -446,6 +453,10 @@ class TokenizerManager: ...@@ -446,6 +453,10 @@ class TokenizerManager:
ProfileReqOutput, ProfileReqOutput,
self.profile_communicator.handle_recv, self.profile_communicator.handle_recv,
), ),
(
FreezeGCReq,
lambda x: None,
), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
( (
GetInternalStateReqOutput, GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv, self.get_internal_state_communicator.handle_recv,
...@@ -1359,6 +1370,12 @@ class TokenizerManager: ...@@ -1359,6 +1370,12 @@ class TokenizerManager:
logging.info(f"Config logging: {obj=}") logging.info(f"Config logging: {obj=}")
self.log_request_metadata = self.get_log_request_metadata() self.log_request_metadata = self.get_log_request_metadata()
async def freeze_gc(self):
    """
    Freeze garbage collection, starting with the scheduler.

    A `FreezeGCReq` is sent to the scheduler before the tokenizer manager
    freezes its own garbage collector. No reply is awaited; returns None.
    """
    request = FreezeGCReq()
    self.send_to_scheduler.send_pyobj(request)
    # Local freeze happens only after the request has been handed to the socket.
    freeze_gc("Tokenizer Manager")
def create_abort_task(self, obj: GenerateReqInput): def create_abort_task(self, obj: GenerateReqInput):
# Abort the request if the client is disconnected. # Abort the request if the client is disconnected.
async def abort_request(): async def abort_request():
......
...@@ -123,6 +123,7 @@ class ServerArgs: ...@@ -123,6 +123,7 @@ class ServerArgs:
decode_log_interval: int = 40 decode_log_interval: int = 40
enable_request_time_stats_logging: bool = False enable_request_time_stats_logging: bool = False
kv_events_config: Optional[str] = None kv_events_config: Optional[str] = None
gc_warning_threshold_secs: float = 0.0
# API related # API related
api_key: Optional[str] = None api_key: Optional[str] = None
...@@ -1172,6 +1173,12 @@ class ServerArgs: ...@@ -1172,6 +1173,12 @@ class ServerArgs:
default=ServerArgs.collect_tokens_histogram, default=ServerArgs.collect_tokens_histogram,
help="Collect prompt/generation tokens histogram.", help="Collect prompt/generation tokens histogram.",
) )
parser.add_argument(
"--gc-warning-threshold-secs",
type=float,
default=ServerArgs.gc_warning_threshold_secs,
help="The threshold for long GC warning. If a GC takes longer than this, a warning will be logged. Set to 0 to disable.",
)
parser.add_argument( parser.add_argument(
"--decode-log-interval", "--decode-log-interval",
type=int, type=int,
......
...@@ -2541,6 +2541,50 @@ def dynamic_import(func_path: str): ...@@ -2541,6 +2541,50 @@ def dynamic_import(func_path: str):
return func return func
def gc_object_counts():
    """
    Count the objects currently tracked by the garbage collector per generation.

    Returns:
        A 3-tuple ``(gen0, gen1, gen2)`` holding the length of
        ``gc.get_objects(generation)`` for generations 0, 1 and 2.
    """
    import gc

    return tuple(len(gc.get_objects(generation)) for generation in range(3))
def configure_gc_warning(warn_threshold_secs):
    """
    Install a ``gc.callbacks`` hook that logs a warning for slow collections.

    The hook times each collection between its "start" and "stop" phases and
    logs a warning, including the per-generation object counts, whenever the
    measured duration is strictly greater than the threshold.

    Args:
        warn_threshold_secs: Duration in seconds above which a single garbage
            collection is reported.

    Returns:
        None. The hook is appended to ``gc.callbacks`` and stays registered.
    """
    import gc

    # Start timestamps keyed by generation, so each "stop" pairs with its own "start".
    gc_start_time = {}

    def gc_callback(phase, info):
        gen = info.get("generation", "?")
        if phase == "start":
            # perf_counter is monotonic; wall-clock time can jump and skew the duration.
            gc_start_time[gen] = time.perf_counter()
        elif phase == "stop":
            now = time.perf_counter()
            # A "stop" without a recorded "start" yields a zero duration.
            duration = now - gc_start_time.pop(gen, now)
            if duration > warn_threshold_secs:
                g0, g1, g2 = gc_object_counts()
                # Logger.warn is a deprecated alias; use Logger.warning.
                logger.warning(
                    f"LONG GARBAGE COLLECTION DETECTED | Generation {gen} | Duration: {duration:.4f}s | # Objects: gen0={g0}, gen1={g1}, gen2={g2} | "
                    "This may cause latency jitter. Consider calling the freeze_gc API after sending a few warmup requests."
                )

    gc.callbacks.append(gc_callback)
def freeze_gc(context: str):
    """
    Freeze the garbage collector of the current process and log the effect.

    Per-generation object counts are sampled immediately before and after
    ``gc.freeze()`` and written to the log as ``before->after`` pairs.

    Args:
        context: Label for the calling process, used only in the log message.
    """
    import gc

    counts_before = gc_object_counts()
    gc.freeze()
    counts_after = gc_object_counts()

    # One "genN: before->after" entry per generation, in generation order.
    transitions = ", ".join(
        f"gen{generation}: {before}->{after}"
        for generation, (before, after) in enumerate(
            zip(counts_before, counts_after)
        )
    )
    logger.info(f"Freezing GC in {context} process. {transitions}")
def configure_gc_logger(): def configure_gc_logger():
logger.info("Enable GC Logger") logger.info("Enable GC Logger")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment