"vscode:/vscode.git/clone" did not exist on "8d386f7990194172e40f6da651e00f92312cd35e"
Unverified Commit 4e2af03c authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Production] Drain requests before exit when receive SIGTERM (#1838)

parent 3184aa95
...@@ -20,6 +20,8 @@ import dataclasses ...@@ -20,6 +20,8 @@ import dataclasses
import json import json
import logging import logging
import os import os
import signal
import sys
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import fastapi import fastapi
...@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import ( ...@@ -58,7 +60,12 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model from sglang.srt.utils import (
get_zmq_socket,
is_generation_model,
is_multimodal_model,
kill_child_process,
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -142,6 +149,9 @@ class TokenizerManager: ...@@ -142,6 +149,9 @@ class TokenizerManager:
self.model_update_lock = asyncio.Lock() self.model_update_lock = asyncio.Lock()
self.model_update_result = None self.model_update_result = None
# Others
self.gracefully_exit = False
async def generate_request( async def generate_request(
self, self,
obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
...@@ -629,6 +639,28 @@ class TokenizerManager: ...@@ -629,6 +639,28 @@ class TokenizerManager:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop()) loop.create_task(self.handle_loop())
signal_handler = SignalHandler(self)
loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
loop.create_task(self.sigterm_watchdog())
async def sigterm_watchdog(self):
while not self.gracefully_exit:
await asyncio.sleep(60)
# drain requests
while True:
remain_num_req = len(self.rid_to_state)
logger.info(
f"gracefully exiting... remaining number of requests {remain_num_req}"
)
if remain_num_req > 0:
await asyncio.sleep(5)
else:
break
kill_child_process(include_self=True)
sys.exit(-1)
async def handle_loop(self): async def handle_loop(self):
"""The event loop that handles requests""" """The event loop that handles requests"""
...@@ -740,3 +772,14 @@ class TokenizerManager: ...@@ -740,3 +772,14 @@ class TokenizerManager:
token_top_logprobs, decode_to_text token_top_logprobs, decode_to_text
) )
return top_logprobs return top_logprobs
class SignalHandler:
def __init__(self, tokenizer_manager):
self.tokenizer_manager = tokenizer_manager
def signal_handler(self, signum=None, frame=None):
logger.warning(
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
)
self.tokenizer_manager.gracefully_exit = True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment