Unverified Commit 91f93f14 authored by Lianmin Zheng, committed by GitHub

Crash the server when error or OOM happens (#514)

parent f70f7258
@@ -12,6 +12,7 @@ from sglang.global_config import global_config
 from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.managers.io_struct import BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")
@@ -58,6 +59,10 @@ class DataParallelWorkerThread(threading.Thread):
                     f"{get_exception_traceback()}"
                 )
                 self.liveness = False
+                # Crash the whole server when there are any errors.
+                # TODO(lianmin): make this an option.
+                kill_parent_process()
                 return

         for obj in out_pyobjs:
             self.send_to_detokenizer.send_pyobj(obj)
...
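For context, here is a standalone sketch of the pattern this hunk introduces: when the worker loop raises, the thread no longer just flips its liveness flag, it also force-kills the parent process so the whole server exits instead of hanging with a dead worker. All names below (WorkerThread, forward_step) are illustrative, not the commit's code, and running it will really SIGKILL the parent of the process.

# Illustrative sketch only; running this will SIGKILL the parent process.
import os
import signal
import threading
import traceback

import psutil


def kill_parent_process():
    # Same idea as sglang.srt.utils.kill_parent_process: SIGKILL the parent.
    parent = psutil.Process().parent()
    os.kill(parent.pid, signal.SIGKILL)


class WorkerThread(threading.Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self.liveness = True

    def forward_step(self):
        # Hypothetical work; pretend it hits a CUDA OOM.
        raise RuntimeError("simulated OOM")

    def run(self):
        while self.liveness:
            try:
                self.forward_step()
            except Exception:
                traceback.print_exc()
                self.liveness = False
                # Crash the whole server instead of leaving a zombie worker.
                kill_parent_process()
                return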
"""A controller that manages a group of tensor parallel workers.""" """A controller that manages a group of tensor parallel workers."""
import asyncio import asyncio
import logging import logging
import time
import uvloop import uvloop
import zmq import zmq
...@@ -9,10 +10,13 @@ import zmq.asyncio ...@@ -9,10 +10,13 @@ import zmq.asyncio
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = logging.getLogger("srt.controller")
class ControllerSingle: class ControllerSingle:
def __init__(self, model_client: ModelTpClient, port_args: PortArgs): def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
@@ -85,4 +89,9 @@ def start_controller_process(
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
\ No newline at end of file
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
\ No newline at end of file
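The single-controller variant applies the same idea around its asyncio event loop: log the traceback on any uncaught exception, then kill the launcher in a finally block so it runs even when the loop exits for another reason. A minimal, safe-to-run sketch with the kill stubbed out and illustrative names:

# Minimal sketch of the controller guard; kill_parent_process is stubbed here.
import asyncio
import logging
import traceback

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("srt.controller")


def kill_parent_process():
    logger.info("would SIGKILL the parent process here")


async def loop_for_forward():
    # Stand-in for ControllerSingle.loop_for_forward.
    raise RuntimeError("simulated scheduler error")


def start_controller_process():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(loop_for_forward())
    except Exception:
        logger.error("Exception in ControllerSingle:\n" + traceback.format_exc())
    finally:
        # Runs on error and on clean exit alike, so the launcher never outlives the controller.
        kill_parent_process()


if __name__ == "__main__":
    start_controller_process()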
@@ -18,7 +18,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check

 logger = logging.getLogger("srt.model_runner")
@@ -240,10 +240,12 @@
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+        monkey_patch_vllm_p2p_access_check()
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
+            local_rank=self.gpu_id,
             distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
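Note the ordering in this hunk: monkey_patch_vllm_p2p_access_check() is called before init_distributed_environment, so the stubbed check is already in place when vllm later consults it during distributed setup. A self-contained sketch of that patch-before-use idea (the `comm` namespace below is a stand-in for vllm's module, not the real thing):

# Standalone illustration of "patch first, then init"; `comm` mimics
# vllm.distributed.device_communicators.custom_all_reduce_utils.
import time
import types

comm = types.SimpleNamespace()


def slow_gpu_p2p_access_check(src: int, dst: int) -> bool:
    # Imagine a probe that spawns subprocesses and takes seconds per GPU pair.
    time.sleep(5)
    return False


comm.gpu_p2p_access_check = slow_gpu_p2p_access_check


def monkey_patch_p2p_access_check():
    # Assume P2P access is always allowed, skipping the slow probe entirely.
    setattr(comm, "gpu_p2p_access_check", lambda *args, **kwargs: True)


def init_distributed_environment_stub():
    # Whatever initializes after the patch now sees the fast stub.
    assert comm.gpu_p2p_access_check(0, 1) is True


monkey_patch_p2p_access_check()
init_distributed_environment_stub()
print("p2p check patched and consulted without the slow probe")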
@@ -265,7 +267,7 @@
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

         device_config = DeviceConfig()
@@ -295,8 +297,8 @@
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
-            f"Type={type(self.model).__name__}. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"type={type(self.model).__name__}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

     def profile_max_num_token(self, total_gpu_memory):
@@ -333,7 +335,7 @@
         )
         logger.info(
             f"[gpu_id={self.gpu_id}] Memory pool end. "
-            f"Avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

     @torch.inference_mode()
...
@@ -34,7 +34,7 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback

-logger = logging.getLogger("srt.model_tp")
+logger = logging.getLogger("srt.tp_worker")


 class ModelTpServer:
@@ -187,7 +187,8 @@ class ModelTpServer:
             # Forward
             self.forward_step()
         except Exception:
-            logger.error("Exception in ModelTpClient:\n" + get_exception_traceback())
+            logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+            raise

         # Return results
         ret = self.out_pyobjs
...
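The added `raise` is the important part of this hunk: the exception is still logged, but it now propagates to the caller instead of being swallowed, which is what lets the layers above (the worker thread and controller shown earlier) notice the failure and bring the server down. A small sketch of the difference, with illustrative names:

# Illustrative comparison of log-and-swallow vs. log-and-reraise.
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("srt.tp_worker")


def step(reraise: bool):
    try:
        raise RuntimeError("simulated forward error")
    except Exception:
        logger.error("Exception in ModelTpServer", exc_info=True)
        if reraise:
            raise  # caller sees the failure and can crash the server


step(reraise=False)  # failure is hidden; the server would limp along
try:
    step(reraise=True)
except RuntimeError:
    print("caller observed the failure and can now kill the server")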
@@ -87,7 +87,7 @@ def start_detokenizer_process(
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
...
@@ -228,20 +228,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     # Send a warmup request
     try:
-        res = requests.post(
-            url + "/generate",
-            json={
-                "text": "The capital city of France is",
-                "sampling_params": {
-                    "temperature": 0,
-                    "max_new_tokens": 16,
-                },
-            },
-            headers=headers,
-            timeout=600,
-        )
-        assert res.status_code == 200
-    except Exception as e:
+        for _ in range(server_args.dp_size):
+            res = requests.post(
+                url + "/generate",
+                json={
+                    "text": "The capital city of France is",
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 16,
+                    },
+                },
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200
+    except Exception:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(get_exception_traceback())
         print(f"Initialization failed. warmup error: {e}")
...
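With data parallelism, a single warmup request only exercises whichever replica the router picks, so the warmup now sends dp_size requests, presumably so every replica handles a request before the server is reported ready. A hedged sketch of that loop as a standalone helper (it assumes a compatible /generate endpoint is already listening at `url`; the helper name and example port are not from the commit):

# Sketch of per-replica warmup; assumes an sglang-compatible /generate
# endpoint is already running at `url`.
from typing import Optional

import requests


def warmup(url: str, dp_size: int, api_key: Optional[str] = None) -> None:
    headers = {"X-API-Key": api_key} if api_key else {}
    for _ in range(dp_size):  # one request per data-parallel replica
        res = requests.post(
            url + "/generate",
            json={
                "text": "The capital city of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 16},
            },
            headers=headers,
            timeout=600,
        )
        assert res.status_code == 200, res.text


# Example: warmup("http://127.0.0.1:30000", dp_size=2)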
@@ -12,6 +12,7 @@ from io import BytesIO
 from typing import List, Optional

 import numpy as np
+import psutil
 import requests
 import rpyc
 import torch
@@ -441,6 +442,27 @@ def assert_pkg_version(pkg: str, min_version: str):
         )


+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check():
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
@@ -459,3 +481,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response
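Finally, kill_parent_process leans on psutil for the process bookkeeping: walk the child processes recursively, signal-9 each of them, then signal-9 the parent. The sketch below exercises the same psutil calls safely on a disposable child process instead of the real parent tree:

# Safe, standalone exercise of the psutil calls used above: spawn a throwaway
# child, list the current process's children recursively, and SIGKILL them.
import os
import signal
import subprocess
import sys
import time

import psutil

child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
time.sleep(0.5)  # give the child a moment to start

current_process = psutil.Process()
for proc in current_process.children(recursive=True):
    print(f"killing child pid={proc.pid}")
    os.kill(proc.pid, signal.SIGKILL)  # signal 9, as in kill_parent_process

child.wait()
print("child exit code:", child.returncode)  # negative signal number on POSIX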