"sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py" did not exist on "fb11a4398158ce1838813b729d4ca188865b69f1"
Unverified Commit 4e2af03c authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Production] Drain requests before exit when receive SIGTERM (#1838)

parent 3184aa95
......@@ -20,6 +20,8 @@ import dataclasses
import json
import logging
import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union
import fastapi
......@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -142,6 +149,9 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock()
self.model_update_result = None
# Others
self.gracefully_exit = False
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
......@@ -629,6 +639,28 @@ class TokenizerManager:
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.create_task(self.sigterm_watchdog())
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(60)
# drain requests
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
f"gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
else:
break
kill_child_process(include_self=True)
sys.exit(-1)
async def handle_loop(self):
"""The event loop that handles requests"""
......@@ -740,3 +772,14 @@ class TokenizerManager:
token_top_logprobs, decode_to_text
)
return top_logprobs
class SignalHandler:
def __init__(self, tokenizer_manager):
self.tokenizer_manager = tokenizer_manager
def signal_handler(self, signum=None, frame=None):
logger.warning(
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
)
self.tokenizer_manager.gracefully_exit = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment