"docs/vscode:/vscode.git/clone" did not exist on "75d53cc83966b4046e5a329ddf7baa6aa24f52e2"
Unverified Commit 86fc0d79 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Add a watch dog thread (#1816)

parent 1be853ee
...@@ -550,4 +550,4 @@ if __name__ == "__main__": ...@@ -550,4 +550,4 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
kill_child_process(os.getpid(), including_parent=False) kill_child_process()
...@@ -15,7 +15,6 @@ import dataclasses ...@@ -15,7 +15,6 @@ import dataclasses
import itertools import itertools
import json import json
import multiprocessing import multiprocessing
import os
import time import time
from typing import Tuple from typing import Tuple
...@@ -70,7 +69,7 @@ def launch_server_internal(server_args): ...@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
kill_child_process(os.getpid(), including_parent=False) kill_child_process()
def launch_server_process(server_args: ServerArgs): def launch_server_process(server_args: ServerArgs):
...@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ...@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
) )
finally: finally:
if proc: if proc:
kill_child_process(proc.pid) kill_child_process(proc.pid, include_self=True)
print(f"\nResults are saved to {bench_args.result_filename}") print(f"\nResults are saved to {bench_args.result_filename}")
......
...@@ -15,4 +15,4 @@ if __name__ == "__main__": ...@@ -15,4 +15,4 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
raise e raise e
finally: finally:
kill_child_process(os.getpid(), including_parent=False) kill_child_process()
...@@ -18,6 +18,7 @@ limitations under the License. ...@@ -18,6 +18,7 @@ limitations under the License.
import json import json
import logging import logging
import os import os
import threading
import time import time
import warnings import warnings
from collections import deque from collections import deque
...@@ -222,10 +223,11 @@ class Scheduler: ...@@ -222,10 +223,11 @@ class Scheduler:
self.waiting_queue: List[Req] = [] self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0 self.forward_ct = 0
self.stream_interval = server_args.stream_interval self.forward_ct_decode = 0
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_stats_tic = time.time() self.last_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
# Init chunked prefill # Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size self.chunked_prefill_size = server_args.chunked_prefill_size
...@@ -272,6 +274,11 @@ class Scheduler: ...@@ -272,6 +274,11 @@ class Scheduler:
self.batch_is_full = False self.batch_is_full = False
# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
# Init profiler # Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None self.profiler = None
...@@ -289,6 +296,23 @@ class Scheduler: ...@@ -289,6 +296,23 @@ class Scheduler:
with_stack=True, with_stack=True,
) )
def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
while True:
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
kill_parent_process()
@torch.inference_mode() @torch.inference_mode()
def event_loop_normal(self): def event_loop_normal(self):
"""A normal blocking scheduler loop.""" """A normal blocking scheduler loop."""
...@@ -299,6 +323,7 @@ class Scheduler: ...@@ -299,6 +323,7 @@ class Scheduler:
self.process_input_requests(recv_reqs) self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run() batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch: if batch:
result = self.run_batch(batch) result = self.run_batch(batch)
...@@ -746,6 +771,8 @@ class Scheduler: ...@@ -746,6 +771,8 @@ class Scheduler:
def run_batch(self, batch: ScheduleBatch): def run_batch(self, batch: ScheduleBatch):
"""Run a batch.""" """Run a batch."""
self.forward_ct += 1
if self.is_generation: if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
...@@ -778,6 +805,7 @@ class Scheduler: ...@@ -778,6 +805,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result) self.process_batch_result_prefill(batch, result)
def process_batch_result_prefill(self, batch: ScheduleBatch, result): def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation: if self.is_generation:
logits_output, next_token_ids, bid = result logits_output, next_token_ids, bid = result
...@@ -890,8 +918,8 @@ class Scheduler: ...@@ -890,8 +918,8 @@ class Scheduler:
self.token_to_kv_pool.free_group_end() self.token_to_kv_pool.free_group_end()
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
self.print_decode_stats() self.print_decode_stats()
def add_logprob_return_values( def add_logprob_return_values(
...@@ -984,7 +1012,7 @@ class Scheduler: ...@@ -984,7 +1012,7 @@ class Scheduler:
else: # embedding or reward model else: # embedding or reward model
output_embeddings = [] output_embeddings = []
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0 is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
for req in reqs: for req in reqs:
if req.finished() or ( if req.finished() or (
......
...@@ -441,7 +441,7 @@ def launch_server( ...@@ -441,7 +441,7 @@ def launch_server(
# Send a warmup request # Send a warmup request
t = threading.Thread( t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid()) target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
) )
t.start() t.start()
...@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs): ...@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
def _wait_and_warmup(server_args, pipe_finish_writer, pid): def _wait_and_warmup(server_args, pipe_finish_writer):
headers = {} headers = {}
url = server_args.url() url = server_args.url()
if server_args.api_key: if server_args.api_key:
...@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): ...@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False) kill_child_process(include_self=True)
return return
model_info = res.json() model_info = res.json()
...@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): ...@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None: if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback) pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}") logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False) kill_child_process(include_self=True)
return return
# logger.info(f"{res.json()=}") # logger.info(f"{res.json()=}")
...@@ -617,7 +617,7 @@ class Runtime: ...@@ -617,7 +617,7 @@ class Runtime:
def shutdown(self): def shutdown(self):
if self.pid is not None: if self.pid is not None:
kill_child_process(self.pid) kill_child_process(self.pid, include_self=True)
self.pid = None self.pid = None
def cache_prefix(self, prefix: str): def cache_prefix(self, prefix: str):
...@@ -834,7 +834,7 @@ class Engine: ...@@ -834,7 +834,7 @@ class Engine:
return ret return ret
def shutdown(self): def shutdown(self):
kill_child_process(os.getpid(), including_parent=False) kill_child_process(include_self=True)
def get_tokenizer(self): def get_tokenizer(self):
global tokenizer_manager global tokenizer_manager
......
...@@ -74,6 +74,7 @@ class ServerArgs: ...@@ -74,6 +74,7 @@ class ServerArgs:
api_key: Optional[str] = None api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage" file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False enable_cache_report: bool = False
watchdog_timeout: float = 600
# Data parallelism # Data parallelism
dp_size: int = 1 dp_size: int = 1
...@@ -429,6 +430,12 @@ class ServerArgs: ...@@ -429,6 +430,12 @@ class ServerArgs:
action="store_true", action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.", help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
) )
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
# Data parallelism # Data parallelism
parser.add_argument( parser.add_argument(
......
...@@ -398,17 +398,26 @@ def kill_parent_process(): ...@@ -398,17 +398,26 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process.""" """Kill the parent process and all children of the parent process."""
current_process = psutil.Process() current_process = psutil.Process()
parent_process = current_process.parent() parent_process = current_process.parent()
kill_child_process(parent_process.pid, skip_pid=current_process.pid) kill_child_process(
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
current_process.kill()
except psutil.NoSuchProcess:
pass
def kill_child_process(pid, including_parent=True, skip_pid=None): def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process.""" """Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
try: try:
parent = psutil.Process(pid) itself = psutil.Process(pid)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
return return
children = parent.children(recursive=True) children = itself.children(recursive=True)
for child in children: for child in children:
if child.pid == skip_pid: if child.pid == skip_pid:
continue continue
...@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None): ...@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
if including_parent: if include_self:
try: try:
parent.kill() itself.kill()
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
pass pass
......
...@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float): ...@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
) )
assert ret_code == 0 assert ret_code == 0
except TimeoutError: except TimeoutError:
kill_child_process(process.pid) kill_child_process(process.pid, include_self=True)
time.sleep(5) time.sleep(5)
print( print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
...@@ -563,7 +563,7 @@ def run_bench_serving( ...@@ -563,7 +563,7 @@ def run_bench_serving(
try: try:
res = run_benchmark(args) res = run_benchmark(args)
finally: finally:
kill_child_process(process.pid) kill_child_process(process.pid, include_self=True)
assert res["completed"] == num_prompts assert res["completed"] == num_prompts
return res return res
...@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args): ...@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
lastline = output.split("\n")[-3] lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2]) output_throughput = float(lastline.split(" ")[-2])
finally: finally:
kill_child_process(process.pid) kill_child_process(process.pid, include_self=True)
return output_throughput return output_throughput
...@@ -707,8 +707,8 @@ def run_mmlu_test( ...@@ -707,8 +707,8 @@ def run_mmlu_test(
pass pass
# Clean up everything # Clean up everything
kill_child_process(process.pid) kill_child_process(process.pid, include_self=True)
kill_child_process(process.pid) kill_child_process(process.pid, include_self=True)
stdout.close() stdout.close()
stderr.close() stderr.close()
if os.path.exists(STDOUT_FILENAME): if os.path.exists(STDOUT_FILENAME):
......
...@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def run_decode( def run_decode(
self, self,
......
...@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(
......
...@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase): ...@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase): ...@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def run_embedding(self, use_list_input, token_input): def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......
...@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): ...@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self): def test_mmlu(self):
args = SimpleNamespace( args = SimpleNamespace(
......
...@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase): ...@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1): def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post( response = requests.post(
......
...@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
cls.stdout.close() cls.stdout.close()
cls.stderr.close() cls.stderr.close()
os.remove("stdout.txt") os.remove("stdout.txt")
......
...@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase): ...@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
kill_child_process(cls.process.pid) kill_child_process(cls.process.pid, include_self=True)
def run_completions_generation( def run_completions_generation(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment