Unverified Commit 243e78c2 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Benchmark][Bugfix] Fix race condition when starting server for sweep benchmark (#32927)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent aac0b817
...@@ -29,6 +29,7 @@ def run_server( ...@@ -29,6 +29,7 @@ def run_server(
show_stdout: bool, show_stdout: bool,
serve_overrides: ParameterSweepItem, serve_overrides: ParameterSweepItem,
dry_run: bool, dry_run: bool,
server_ready_timeout: int = 300,
): ):
server_cmd = serve_overrides.apply_to_cmd(serve_cmd) server_cmd = serve_overrides.apply_to_cmd(serve_cmd)
...@@ -42,6 +43,7 @@ def run_server( ...@@ -42,6 +43,7 @@ def run_server(
return return
with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server: with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
server.wait_until_ready(timeout=server_ready_timeout)
yield server yield server
print("[END SERVER]") print("[END SERVER]")
...@@ -212,6 +214,7 @@ def run_combs( ...@@ -212,6 +214,7 @@ def run_combs(
num_runs: int, num_runs: int,
dry_run: bool, dry_run: bool,
links: list[tuple[str, str]], links: list[tuple[str, str]],
server_ready_timeout: int = 300,
): ):
all_data = list[dict[str, object]]() all_data = list[dict[str, object]]()
for serve_comb in serve_params: for serve_comb in serve_params:
...@@ -222,6 +225,7 @@ def run_combs( ...@@ -222,6 +225,7 @@ def run_combs(
show_stdout=show_stdout, show_stdout=show_stdout,
serve_overrides=serve_comb, serve_overrides=serve_comb,
dry_run=dry_run, dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) )
if _comb_needs_server(serve_comb, bench_params, output_dir) if _comb_needs_server(serve_comb, bench_params, output_dir)
else contextlib.nullcontext() else contextlib.nullcontext()
...@@ -272,6 +276,7 @@ class SweepServeArgs: ...@@ -272,6 +276,7 @@ class SweepServeArgs:
dry_run: bool dry_run: bool
resume: str | None resume: str | None
link_vars: list[tuple[str, str]] | None link_vars: list[tuple[str, str]] | None
server_ready_timeout: int
parser_name: ClassVar[str] = "serve" parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings." parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
...@@ -312,6 +317,7 @@ class SweepServeArgs: ...@@ -312,6 +317,7 @@ class SweepServeArgs:
dry_run=args.dry_run, dry_run=args.dry_run,
resume=args.resume, resume=args.resume,
link_vars=link_vars, link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
) )
@classmethod @classmethod
...@@ -341,6 +347,12 @@ class SweepServeArgs: ...@@ -341,6 +347,12 @@ class SweepServeArgs:
help="If set, logs the standard output of subcommands. " help="If set, logs the standard output of subcommands. "
"Useful for debugging but can be quite spammy.", "Useful for debugging but can be quite spammy.",
) )
parser.add_argument(
"--server-ready-timeout",
type=int,
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
parser.add_argument( parser.add_argument(
"--serve-params", "--serve-params",
type=str, type=str,
...@@ -431,6 +443,7 @@ def run_main(args: SweepServeArgs): ...@@ -431,6 +443,7 @@ def run_main(args: SweepServeArgs):
num_runs=args.num_runs, num_runs=args.num_runs,
dry_run=args.dry_run, dry_run=args.dry_run,
links=args.link_vars, links=args.link_vars,
server_ready_timeout=args.server_ready_timeout,
) )
except BaseException as exc: except BaseException as exc:
raise RuntimeError( raise RuntimeError(
......
...@@ -4,6 +4,7 @@ import contextlib ...@@ -4,6 +4,7 @@ import contextlib
import os import os
import signal import signal
import subprocess import subprocess
import time
from types import TracebackType from types import TracebackType
import requests import requests
...@@ -88,6 +89,29 @@ class ServerProcess: ...@@ -88,6 +89,29 @@ class ServerProcess:
return f"http://{host}:{port}" return f"http://{host}:{port}"
def is_server_ready(self, request_timeout: float = 5.0) -> bool:
    """Probe the server's ``/health`` endpoint once.

    Args:
        request_timeout: Seconds to wait for the HTTP request before
            giving up on this probe. Without a bound, a server that
            accepts the connection but never responds would block the
            caller forever and defeat any outer readiness deadline.

    Returns:
        True if the endpoint answered with HTTP 200; False if it answered
        with any other status or the request failed (connection refused,
        timed out, or any other ``requests.RequestException``).
    """
    server_address = self._get_vllm_server_address()
    try:
        response = requests.get(
            f"{server_address}/health",
            timeout=request_timeout,
        )
    except requests.RequestException:
        # Includes requests.Timeout: an unreachable or slow server simply
        # counts as "not ready yet".
        return False
    return response.status_code == 200
def wait_until_ready(
    self,
    timeout: int,
    poll_interval: float = 1.0,
) -> None:
    """Block until the server reports ready, crashes, or the deadline passes.

    Repeatedly calls ``self.is_server_ready()``. After each failed probe it
    checks whether the server process has exited, then whether the deadline
    has passed, and otherwise sleeps before probing again.

    Args:
        timeout: Maximum number of seconds to wait for readiness.
        poll_interval: Seconds to pause between probes. Each pause is
            clamped to the time remaining so the wait does not overshoot
            ``timeout`` by a full interval.

    Raises:
        RuntimeError: The server process exited before becoming ready.
        TimeoutError: The server was still not ready after ``timeout``
            seconds.
    """
    start_time = time.monotonic()
    while not self.is_server_ready():
        # Check if server process has crashed
        if self._server_process.poll() is not None:
            returncode = self._server_process.returncode
            raise RuntimeError(
                f"Server process crashed with return code {returncode}"
            )
        remaining = timeout - (time.monotonic() - start_time)
        if remaining < 0:
            raise TimeoutError(
                f"Server failed to become ready within {timeout} seconds."
            )
        # Never sleep past the deadline; guard against a negative interval,
        # which time.sleep() rejects.
        time.sleep(max(0.0, min(poll_interval, remaining)))
def reset_caches(self) -> None: def reset_caches(self) -> None:
server_cmd = self.server_cmd server_cmd = self.server_cmd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment