"vscode:/vscode.git/clone" did not exist on "a4d71e75a9ec73c7713d061612894efd25f29073"
Unverified commit 86fc0d79, authored by Lianmin Zheng, committed by GitHub

Add a watch dog thread (#1816)
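This adds a watchdog thread to the Scheduler that aborts the server when a forward batch stops making progress for longer than a configurable timeout (new --watchdog-timeout server argument, default 600 seconds), instead of letting the server hang. It also refactors kill_child_process: the target pid is now optional (defaulting to the current process) and including_parent is renamed to the clearer include_self. All call sites and tests are updated accordingly.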

parent 1be853ee
@@ -550,4 +550,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
@@ -15,7 +15,6 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
-import os
 import time
 from typing import Tuple
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()


 def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid)
+            kill_child_process(proc.pid, include_self=True)

     print(f"\nResults are saved to {bench_args.result_filename}")
@@ -15,4 +15,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
@@ -18,6 +18,7 @@ limitations under the License.
 import json
 import logging
 import os
+import threading
 import time
 import warnings
 from collections import deque
@@ -222,10 +223,11 @@ class Scheduler:
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
         self.cur_batch: Optional[ScheduleBatch] = None
-        self.decode_forward_ct = 0
-        self.stream_interval = server_args.stream_interval
+        self.forward_ct = 0
+        self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        self.stream_interval = server_args.stream_interval

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -272,6 +274,11 @@ class Scheduler:
         self.batch_is_full = False

+        # Init watchdog thread
+        self.watchdog_timeout = server_args.watchdog_timeout
+        t = threading.Thread(target=self.watchdog_thread, daemon=True)
+        t.start()
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
@@ -289,6 +296,23 @@
                 with_stack=True,
             )

+    def watchdog_thread(self):
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = time.time()
+            time.sleep(self.watchdog_timeout / 2)
+
+        kill_parent_process()
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
@@ -299,6 +323,7 @@
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
+            self.cur_batch = batch

             if batch:
                 result = self.run_batch(batch)
@@ -746,6 +771,8 @@
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
+        self.forward_ct += 1
+
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -778,6 +805,7 @@
             self.process_batch_result_prefill(batch, result)

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -890,8 +918,8 @@
             self.token_to_kv_pool.free_group_end()

-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
             self.print_decode_stats()

     def add_logprob_return_values(
@@ -984,7 +1012,7 @@
         else:  # embedding or reward model
             output_embeddings = []

-        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
+        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

         for req in reqs:
             if req.finished() or (
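The renamed forward_ct_decode counter drives two periodic behaviors: decode-stats logging every 40 decode steps (on TP rank 0) and output streaming every stream_interval steps. A toy illustration of the gating arithmetic (values are assumed, not from the commit):

forward_ct_decode = 0
stream_interval = 8  # assumed; the real value comes from ServerArgs.stream_interval

for step in range(1, 101):
    # Wrap the counter so it never grows without bound.
    forward_ct_decode = (forward_ct_decode + 1) % (1 << 30)
    if forward_ct_decode % 40 == 0:
        print(f"step {step}: print_decode_stats()")
    is_stream_iter = forward_ct_decode % stream_interval == 0
    if is_stream_iter:
        print(f"step {step}: stream partial outputs")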
@@ -441,7 +441,7 @@ def launch_server(
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
     )
     t.start()
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _wait_and_warmup(server_args, pipe_finish_writer, pid):
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(pid, including_parent=False)
+        kill_child_process(include_self=True)
         return

     model_info = res.json()
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(pid, including_parent=False)
+        kill_child_process(include_self=True)
         return

     # logger.info(f"{res.json()=}")
@@ -617,7 +617,7 @@ class Runtime:
     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid)
+            kill_child_process(self.pid, include_self=True)
             self.pid = None

     def cache_prefix(self, prefix: str):
@@ -834,7 +834,7 @@ class Engine:
         return ret

     def shutdown(self):
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process(include_self=True)

     def get_tokenizer(self):
         global tokenizer_manager
@@ -74,6 +74,7 @@ class ServerArgs:
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
+    watchdog_timeout: float = 600

     # Data parallelism
     dp_size: int = 1
@@ -429,6 +430,12 @@
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
        )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )

         # Data parallelism
         parser.add_argument(
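An illustrative launch command using the new flag (the model path is a placeholder; this assumes the standard sglang launch entry point):

python -m sglang.launch_server --model-path <model> --watchdog-timeout 300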
@@ -398,17 +398,26 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
+    kill_child_process(
+        parent_process.pid, include_self=True, skip_pid=current_process.pid
+    )
+    try:
+        current_process.kill()
+    except psutil.NoSuchProcess:
+        pass


-def kill_child_process(pid, including_parent=True, skip_pid=None):
+def kill_child_process(pid=None, include_self=False, skip_pid=None):
     """Kill the process and all its children process."""
+    if pid is None:
+        pid = os.getpid()
+
     try:
-        parent = psutil.Process(pid)
+        itself = psutil.Process(pid)
     except psutil.NoSuchProcess:
         return
-    children = parent.children(recursive=True)
+
+    children = itself.children(recursive=True)
     for child in children:
         if child.pid == skip_pid:
             continue
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
         except psutil.NoSuchProcess:
             pass

-    if including_parent:
+    if include_self:
         try:
-            parent.kill()
+            itself.kill()
         except psutil.NoSuchProcess:
             pass
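In short, the refactor makes the target pid optional and renames the flag: kill_child_process() now kills only the children of the current process, while include_self=True also kills the target itself. Hypothetical call sites (proc and worker_pid are assumed names, not from this commit):

kill_child_process()                             # children of this process only
kill_child_process(include_self=True)            # children, then this process too
kill_child_process(proc.pid, include_self=True)  # a subprocess and its children
kill_child_process(skip_pid=worker_pid)          # spare one child while killing the rest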
@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
             )
             assert ret_code == 0
         except TimeoutError:
-            kill_child_process(process.pid)
+            kill_child_process(process.pid, include_self=True)
             time.sleep(5)
             print(
                 f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
@@ -563,7 +563,7 @@ def run_bench_serving(
     try:
         res = run_benchmark(args)
     finally:
-        kill_child_process(process.pid)
+        kill_child_process(process.pid, include_self=True)

     assert res["completed"] == num_prompts
     return res
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
         lastline = output.split("\n")[-3]
         output_throughput = float(lastline.split(" ")[-2])
     finally:
-        kill_child_process(process.pid)
+        kill_child_process(process.pid, include_self=True)

     return output_throughput
@@ -707,8 +707,8 @@ def run_mmlu_test(
         pass

     # Clean up everything
-    kill_child_process(process.pid)
-    kill_child_process(process.pid)
+    kill_child_process(process.pid, include_self=True)
+    kill_child_process(process.pid, include_self=True)
     stdout.close()
     stderr.close()
     if os.path.exists(STDOUT_FILENAME):
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def run_decode(
         self,
@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
         response = requests.post(
@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def run_embedding(self, use_list_input, token_input):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
         response = requests.post(
@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)
         cls.stdout.close()
         cls.stderr.close()
         os.remove("stdout.txt")
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid)
+        kill_child_process(cls.process.pid, include_self=True)

     def run_completions_generation(
         self,