Unverified commit d4fc1a70, authored by Lianmin Zheng and committed by GitHub
Browse files

Crash the server correctly during error (#2231)

parent db674e3d
...@@ -47,6 +47,7 @@ import itertools ...@@ -47,6 +47,7 @@ import itertools
import json import json
import logging import logging
import multiprocessing import multiprocessing
import os
import time import time
from typing import Tuple from typing import Tuple
...@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner ...@@ -62,11 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
configure_logger,
kill_child_process,
suppress_other_loggers,
)
@dataclasses.dataclass @dataclasses.dataclass
...@@ -468,4 +465,4 @@ if __name__ == "__main__": ...@@ -468,4 +465,4 @@ if __name__ == "__main__":
main(server_args, bench_args) main(server_args, bench_args)
finally: finally:
if server_args.tp_size != 1: if server_args.tp_size != 1:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
...@@ -15,6 +15,7 @@ import dataclasses ...@@ -15,6 +15,7 @@ import dataclasses
import itertools import itertools
import json import json
import multiprocessing import multiprocessing
import os
import time import time
from typing import Tuple from typing import Tuple
...@@ -23,7 +24,7 @@ import requests ...@@ -23,7 +24,7 @@ import requests
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
@dataclasses.dataclass @dataclasses.dataclass
...@@ -69,7 +70,7 @@ def launch_server_internal(server_args): ...@@ -69,7 +70,7 @@ def launch_server_internal(server_args):
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
def launch_server_process(server_args: ServerArgs): def launch_server_process(server_args: ServerArgs):
...@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
) )
finally: finally:
if proc: if proc:
kill_child_process(proc.pid, include_self=True) kill_process_tree(proc.pid)
print(f"\nResults are saved to {bench_args.result_filename}") print(f"\nResults are saved to {bench_args.result_filename}")
......
...@@ -4,7 +4,7 @@ import sys ...@@ -4,7 +4,7 @@ import sys
from sglang.srt.server import launch_server from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
if __name__ == "__main__": if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
...@@ -12,4 +12,4 @@ if __name__ == "__main__": ...@@ -12,4 +12,4 @@ if __name__ == "__main__":
try: try:
launch_server(server_args) launch_server(server_args)
finally: finally:
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import signal
import threading import threading
from enum import Enum, auto from enum import Enum, auto
import psutil
import zmq import zmq
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
...@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
bind_port,
configure_logger,
get_zmq_socket,
kill_parent_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -235,7 +231,7 @@ def run_data_parallel_controller_process( ...@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
pipe_writer, pipe_writer,
): ):
configure_logger(server_args) configure_logger(server_args)
suppress_other_loggers() parent_process = psutil.Process().parent()
try: try:
controller = DataParallelController(server_args, port_args) controller = DataParallelController(server_args, port_args)
...@@ -244,6 +240,6 @@ def run_data_parallel_controller_process( ...@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
) )
controller.event_loop() controller.event_loop()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"DataParallelController hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
import dataclasses import dataclasses
import logging import logging
import signal
from collections import OrderedDict from collections import OrderedDict
from typing import List, Union from typing import List, Union
import psutil
import zmq import zmq
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
...@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.utils import find_printable_text, get_exception_traceback from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -193,11 +195,12 @@ def run_detokenizer_process( ...@@ -193,11 +195,12 @@ def run_detokenizer_process(
port_args: PortArgs, port_args: PortArgs,
): ):
configure_logger(server_args) configure_logger(server_args)
parent_process = psutil.Process().parent()
try: try:
manager = DetokenizerManager(server_args, port_args) manager = DetokenizerManager(server_args, port_args)
manager.event_loop() manager.event_loop()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"DetokenizerManager hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import logging import logging
import os import os
import signal
import threading import threading
import time import time
import warnings import warnings
...@@ -23,6 +24,7 @@ from concurrent import futures ...@@ -23,6 +24,7 @@ from concurrent import futures
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional from typing import List, Optional
import psutil
import torch import torch
import zmq import zmq
...@@ -73,7 +75,6 @@ from sglang.srt.utils import ( ...@@ -73,7 +75,6 @@ from sglang.srt.utils import (
crash_on_warnings, crash_on_warnings,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_parent_process,
set_gpu_proc_affinity, set_gpu_proc_affinity,
set_random_seed, set_random_seed,
suppress_other_loggers, suppress_other_loggers,
...@@ -316,6 +317,7 @@ class Scheduler: ...@@ -316,6 +317,7 @@ class Scheduler:
self.watchdog_timeout = server_args.watchdog_timeout self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True) t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start() t.start()
self.parent_process = psutil.Process().parent()
# Init profiler # Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
...@@ -359,7 +361,7 @@ class Scheduler: ...@@ -359,7 +361,7 @@ class Scheduler:
self.watchdog_last_time = time.time() self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2) time.sleep(self.watchdog_timeout / 2)
kill_parent_process() self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad() @torch.no_grad()
def event_loop_normal(self): def event_loop_normal(self):
...@@ -1423,6 +1425,7 @@ def run_scheduler_process( ...@@ -1423,6 +1425,7 @@ def run_scheduler_process(
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
suppress_other_loggers() suppress_other_loggers()
parent_process = psutil.Process().parent()
try: try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
...@@ -1434,6 +1437,6 @@ def run_scheduler_process( ...@@ -1434,6 +1437,6 @@ def run_scheduler_process(
else: else:
scheduler.event_loop_normal() scheduler.event_loop_normal()
except Exception: except Exception:
msg = get_exception_traceback() traceback = get_exception_traceback()
logger.error(msg) logger.error(f"Scheduler hit an exception: {traceback}")
kill_parent_process() parent_process.send_signal(signal.SIGQUIT)
...@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process from sglang.srt.utils import get_zmq_socket, kill_process_tree
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
...@@ -532,7 +532,7 @@ class TokenizerManager: ...@@ -532,7 +532,7 @@ class TokenizerManager:
else: else:
break break
kill_child_process(include_self=True) kill_process_tree(os.getpid(), include_parent=True)
sys.exit(0) sys.exit(0)
async def handle_loop(self): async def handle_loop(self):
......
...@@ -15,16 +15,19 @@ ...@@ -15,16 +15,19 @@
import dataclasses import dataclasses
import logging import logging
import signal
import threading import threading
from queue import Queue from queue import Queue
from typing import Optional from typing import Optional
import psutil
import torch import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -70,6 +73,7 @@ class TpModelWorkerClient: ...@@ -70,6 +73,7 @@ class TpModelWorkerClient:
target=self.forward_thread_func, target=self.forward_thread_func,
) )
self.forward_thread.start() self.forward_thread.start()
self.parent_process = psutil.Process().parent()
def get_worker_info(self): def get_worker_info(self):
return self.worker.get_worker_info() return self.worker.get_worker_info()
...@@ -87,8 +91,13 @@ class TpModelWorkerClient: ...@@ -87,8 +91,13 @@ class TpModelWorkerClient:
) )
def forward_thread_func(self): def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream): try:
self.forward_thread_func_() with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
self.parent_process.send_signal(signal.SIGQUIT)
@torch.no_grad() @torch.no_grad()
def forward_thread_func_(self): def forward_thread_func_(self):
......
...@@ -23,6 +23,8 @@ import json ...@@ -23,6 +23,8 @@ import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import signal
import sys
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
...@@ -79,7 +81,7 @@ from sglang.srt.utils import ( ...@@ -79,7 +81,7 @@ from sglang.srt.utils import (
configure_logger, configure_logger,
delete_directory, delete_directory,
is_port_available, is_port_available,
kill_child_process, kill_process_tree,
maybe_set_triton_cache_manager, maybe_set_triton_cache_manager,
prepare_model_and_tokenizer, prepare_model_and_tokenizer,
set_prometheus_multiproc_dir, set_prometheus_multiproc_dir,
...@@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",
) )
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then cleans up the whole process tree.
def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
# Set mp start method
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
...@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer): ...@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True) kill_process_tree(os.getpid())
return return
model_info = res.json() model_info = res.json()
...@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer): ...@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True) kill_process_tree(os.getpid())
return return
# logger.info(f"{res.json()=}") # logger.info(f"{res.json()=}")
...@@ -700,7 +711,7 @@ class Runtime: ...@@ -700,7 +711,7 @@ class Runtime:
def shutdown(self): def shutdown(self):
if self.pid is not None: if self.pid is not None:
kill_child_process(self.pid, include_self=True) kill_process_tree(self.pid)
self.pid = None self.pid = None
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):
...@@ -924,7 +935,7 @@ class Engine: ...@@ -924,7 +935,7 @@ class Engine:
return ret return ret
def shutdown(self): def shutdown(self):
kill_child_process() kill_process_tree(os.getpid(), include_parent=False)
def get_tokenizer(self): def get_tokenizer(self):
global tokenizer_manager global tokenizer_manager
......
...@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str): ...@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
) )
def kill_parent_process(): def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the parent process and all children of the parent process.""" """Kill the process and all its child processes."""
current_process = psutil.Process() if parent_pid is None:
parent_process = current_process.parent() parent_pid = os.getpid()
kill_child_process( include_parent = False
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
current_process.kill()
except psutil.NoSuchProcess:
pass
def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
try: try:
itself = psutil.Process(pid) itself = psutil.Process(parent_pid)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
return return
...@@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None): ...@@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
if include_self: if include_parent:
try: try:
itself.kill() itself.kill()
# Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes), # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by Kubernetes),
# so we send an additional signal to kill them. # so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT) itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
......
...@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark ...@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import get_bool_env_var, kill_child_process from sglang.srt.utils import get_bool_env_var, kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
...@@ -504,7 +504,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float): ...@@ -504,7 +504,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
) )
assert ret_code == 0 assert ret_code == 0
except TimeoutError: except TimeoutError:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
time.sleep(5) time.sleep(5)
print( print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
...@@ -578,7 +578,7 @@ def run_bench_serving( ...@@ -578,7 +578,7 @@ def run_bench_serving(
run_benchmark(warmup_args) run_benchmark(warmup_args)
res = run_benchmark(args) res = run_benchmark(args)
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
assert res["completed"] == num_prompts assert res["completed"] == num_prompts
return res return res
...@@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args): ...@@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args):
lastline = output.split("\n")[-3] lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2]) output_throughput = float(lastline.split(" ")[-2])
finally: finally:
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
return output_throughput return output_throughput
...@@ -710,8 +710,8 @@ def run_and_check_memory_leak( ...@@ -710,8 +710,8 @@ def run_and_check_memory_leak(
workload_func(base_url, model) workload_func(base_url, model)
# Clean up everything # Clean up everything
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
stdout.close() stdout.close()
stderr.close() stderr.close()
if os.path.exists(STDOUT_FILENAME): if os.path.exists(STDOUT_FILENAME):
......
...@@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None: ...@@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
def terminate_process(process): def terminate_process(process):
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
kill_child_process(process.pid, include_self=True) kill_process_tree(process.pid)
def print_highlight(html_content: str): def print_highlight(html_content: str):
......
...@@ -5,7 +5,7 @@ from types import SimpleNamespace ...@@ -5,7 +5,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -79,7 +79,7 @@ class TestEvalAccuracyMini(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -4,7 +4,7 @@ from multiprocessing import Process ...@@ -4,7 +4,7 @@ from multiprocessing import Process
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
...@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode( def run_decode(
self, self,
......
...@@ -4,7 +4,7 @@ import unittest ...@@ -4,7 +4,7 @@ import unittest
import openai import openai
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
...@@ -44,7 +44,7 @@ class TestCacheReport(unittest.TestCase): ...@@ -44,7 +44,7 @@ class TestCacheReport(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(
......
...@@ -4,7 +4,7 @@ from types import SimpleNamespace ...@@ -4,7 +4,7 @@ from types import SimpleNamespace
import requests import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -28,7 +28,7 @@ class TestDataParallelism(unittest.TestCase): ...@@ -28,7 +28,7 @@ class TestDataParallelism(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -45,7 +45,7 @@ class TestDoubleSparsity(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestDoubleSparsity(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MLA_MODEL_NAME_FOR_TEST,
...@@ -30,7 +30,7 @@ class TestDPAttention(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestDPAttention(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -3,7 +3,7 @@ import unittest ...@@ -3,7 +3,7 @@ import unittest
import openai import openai
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
...@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def run_embedding(self, use_list_input, token_input): def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......
...@@ -6,7 +6,7 @@ python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu ...@@ -6,7 +6,7 @@ python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
...@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid, include_self=True) kill_process_tree(cls.process.pid)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment