Unverified Commit 91f93f14 authored by Lianmin Zheng, committed by GitHub

Crash the server when error or OOM happens (#514)

parent f70f7258
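
What this commit does, in one pattern: when an SRT worker loop hits an unrecoverable error (including CUDA OOM), it now logs the traceback and, in addition to marking itself dead (`self.liveness = False`), calls a new `kill_parent_process()` helper. Killing the parent takes the whole server down, so a failure is visible immediately instead of leaving a half-dead deployment. A minimal sketch of the pattern; the imports are the ones added in this commit, while `worker_loop` and its `step` argument are illustrative:

```python
import logging

from sglang.srt.utils import kill_parent_process  # helper added in this commit
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")


def worker_loop(step):
    """Illustrative worker loop: any unhandled error now crashes the whole server."""
    while True:
        try:
            step()
        except Exception:
            logger.error("Exception in worker:\n" + get_exception_traceback())
            # Crash the whole server when there are any errors.
            # TODO(lianmin): make this an option.
            kill_parent_process()
            return
```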
......@@ -12,6 +12,7 @@ from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.controller")
......@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
f"{get_exception_traceback()}"
)
self.liveness = False
+# Crash the whole server when there are any errors.
+# TODO(lianmin): make this an option.
+kill_parent_process()
return
for obj in out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
......
"""A controller that manages a group of tensor parallel workers."""
import asyncio
import logging
import time
import uvloop
import zmq
......@@ -9,10 +10,13 @@ import zmq.asyncio
from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger("srt.controller")
class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
......@@ -85,4 +89,9 @@ def start_controller_process(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.create_task(controller.loop_for_recv_requests())
-loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
+try:
+    loop.run_until_complete(controller.loop_for_forward())
+except Exception:
+    logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+finally:
+    kill_parent_process()
\ No newline at end of file
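
The `finally:` here means the parent process (the one that launched the controller, normally the HTTP server) is killed whether `loop_for_forward` raises or returns, so a silently exiting controller can no longer leave the server up but unable to make progress. The same pattern as a standalone sketch; `run_forever` and its `coro` argument are illustrative, the imports are the ones added in this commit:

```python
import asyncio
import logging

from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")


def run_forever(coro):
    """Drive a controller coroutine; crash the whole server when it stops for any reason."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(coro)
    except Exception:
        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
    finally:
        # Reached on normal return, on error, and on cancellation.
        kill_parent_process()
```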
......@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
logger = logging.getLogger("srt.model_runner")
......@@ -240,10 +240,12 @@ class ModelRunner:
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+monkey_patch_vllm_p2p_access_check()
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
rank=self.tp_rank,
+local_rank=self.gpu_id,
distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
......@@ -265,7 +267,7 @@ class ModelRunner:
def load_model(self):
logger.info(
f"[gpu_id={self.gpu_id}] Load weight begin. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
device_config = DeviceConfig()
......@@ -295,8 +297,8 @@ class ModelRunner:
)
logger.info(
f"[gpu_id={self.gpu_id}] Load weight end. "
f"Type={type(self.model).__name__}. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"type={type(self.model).__name__}, "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def profile_max_num_token(self, total_gpu_memory):
......@@ -333,7 +335,7 @@ class ModelRunner:
)
logger.info(
f"[gpu_id={self.gpu_id}] Memory pool end. "
f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
@torch.inference_mode()
......
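
Note the ordering in the init hunk above: `monkey_patch_vllm_p2p_access_check()` runs before `init_distributed_environment`, which appears to be when vllm's peer-to-peer probe fires (the check lives in its custom all-reduce utilities), so patching any later would be too late. A condensed sketch of that ordering; the `vllm.distributed` import path is assumed from what `model_runner.py` already uses, and `init_tp` with its parameters is illustrative:

```python
from vllm.distributed import init_distributed_environment, initialize_model_parallel

from sglang.srt.utils import monkey_patch_vllm_p2p_access_check


def init_tp(gpu_id: int, tp_rank: int, tp_size: int, nccl_port: int) -> None:
    """Illustrative: the initialization order used by ModelRunner in the hunk above."""
    # Skip vllm's slow pairwise P2P probe; must happen before the distributed setup below.
    monkey_patch_vllm_p2p_access_check()
    init_distributed_environment(
        backend="nccl",
        world_size=tp_size,
        rank=tp_rank,
        local_rank=gpu_id,  # newly passed in this commit
        distributed_init_method=f"tcp://127.0.0.1:{nccl_port}",
    )
    initialize_model_parallel(tensor_model_parallel_size=tp_size)
```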
......@@ -34,7 +34,7 @@ from sglang.srt.utils import (
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger("srt.model_tp")
logger = logging.getLogger("srt.tp_worker")
class ModelTpServer:
......@@ -187,7 +187,8 @@ class ModelTpServer:
# Forward
self.forward_step()
except Exception:
logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
raise
# Return results
ret = self.out_pyobjs
......
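
Besides the logger rename, the important line is the re-added `raise`: the tensor-parallel server still logs the traceback locally, but the exception now propagates back through `ModelTpClient` to whichever loop drives it (the controller or the data-parallel worker thread), and that loop is what calls `kill_parent_process()`. A minimal sketch of the log-and-reraise idiom; `step_with_escalation` and its argument are illustrative:

```python
import logging

from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.tp_worker")


def step_with_escalation(forward_step):
    """Illustrative: run one forward step, log failures locally, and re-raise for the caller."""
    try:
        return forward_step()
    except Exception:
        # Log the local traceback, then let the caller (via ModelTpClient) see the failure too.
        logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
        raise
```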
......@@ -87,7 +87,7 @@ def start_detokenizer_process(
try:
manager = DetokenizerManager(server_args, port_args)
-except Exception as e:
+except Exception:
pipe_writer.send(get_exception_traceback())
raise
pipe_writer.send("init ok")
......
......@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 # Send a warmup request
 try:
-    res = requests.post(
-        url + "/generate",
-        json={
-            "text": "The capital city of France is",
-            "sampling_params": {
-                "temperature": 0,
-                "max_new_tokens": 16,
-            },
-        },
-        headers=headers,
-        timeout=600,
-    )
-    assert res.status_code == 200
-except Exception as e:
+    for _ in range(server_args.dp_size):
+        res = requests.post(
+            url + "/generate",
+            json={
+                "text": "The capital city of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                },
+            },
+            headers=headers,
+            timeout=600,
+        )
+        assert res.status_code == 200
+except Exception:
     if pipe_finish_writer is not None:
         pipe_finish_writer.send(get_exception_traceback())
     print(f"Initialization failed. warmup error: {e}")
......
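
With data parallelism, a single warmup request only exercises one replica, so the warmup now loops `server_args.dp_size` times. Note that the unchanged `print` still references `e`, which the new bare `except Exception:` no longer binds; the sketch below captures the traceback explicitly instead. The `warmup` helper and its arguments are illustrative:

```python
import requests

from sglang.utils import get_exception_traceback


def warmup(url, headers, dp_size, pipe_finish_writer=None):
    """Illustrative: send one generate request per data-parallel worker before going live."""
    try:
        for _ in range(dp_size):
            res = requests.post(
                url + "/generate",
                json={
                    "text": "The capital city of France is",
                    "sampling_params": {"temperature": 0, "max_new_tokens": 16},
                },
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200
    except Exception:
        traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(traceback)
        print(f"Initialization failed. warmup error: {traceback}")
```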
......@@ -12,6 +12,7 @@ from io import BytesIO
from typing import List, Optional
import numpy as np
import psutil
+import requests
import rpyc
import torch
......@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
)
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
API_KEY_HEADER_NAME = "X-API-Key"
......@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
)
response = await call_next(request)
return response
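
The p2p monkey patch is a startup-time/safety trade-off: rather than letting vllm probe every GPU pair for peer-to-peer access, `gpu_p2p_access_check` is replaced with a stub that always answers True, which the docstring itself warns can be wrong on hosts where P2P is not actually available. A quick way to see the effect (pure illustration; requires vllm to be installed):

```python
import vllm.distributed.device_communicators.custom_all_reduce_utils as p2p_utils

from sglang.srt.utils import monkey_patch_vllm_p2p_access_check

monkey_patch_vllm_p2p_access_check()

# After the patch, every pairwise query returns True without touching the GPUs.
assert p2p_utils.gpu_p2p_access_check(0, 1) is True
```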