Commit fda6bb78 (unverified), authored by Yineng Zhang, committed by GitHub
Browse files

update bench_serving (#4958)

parent 23c764b1
...@@ -44,6 +44,12 @@ ASSISTANT_SUFFIX = "Assistant:" ...@@ -44,6 +44,12 @@ ASSISTANT_SUFFIX = "Assistant:"
global args global args
# don't want to import sglang package here
def _get_bool_env_var(name: str, default: str = "false") -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Args:
        name: Name of the environment variable to read.
        default: String used when the variable is not set.

    Returns:
        True when the value, ignoring case and surrounding whitespace,
        is "true" or "1"; False for any other value.
    """
    value = os.getenv(name, default)
    # Strip before comparing so values such as "true\n" or " 1 " (common
    # when a flag is exported from CI configuration) are still recognised.
    return value.strip().lower() in ("true", "1")
@dataclass @dataclass
class RequestFuncInput: class RequestFuncInput:
prompt: str prompt: str
...@@ -969,6 +975,7 @@ async def benchmark( ...@@ -969,6 +975,7 @@ async def benchmark(
extra_request_body: Dict[str, Any], extra_request_body: Dict[str, Any],
profile: bool, profile: bool,
pd_seperated: bool = False, pd_seperated: bool = False,
flush_cache: bool = False,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend] request_func = ASYNC_REQUEST_FUNCS[backend]
...@@ -1012,7 +1019,7 @@ async def benchmark( ...@@ -1012,7 +1019,7 @@ async def benchmark(
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
# Flush cache # Flush cache
if "sglang" in backend: if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
requests.post(base_url + "/flush_cache", headers=get_auth_headers()) requests.post(base_url + "/flush_cache", headers=get_auth_headers())
time.sleep(1.0) time.sleep(1.0)
...@@ -1347,6 +1354,10 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1347,6 +1354,10 @@ def run_benchmark(args_: argparse.Namespace):
tokenizer = get_tokenizer(tokenizer_id) tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer) input_requests = get_dataset(args, tokenizer)
# compatible with SimpleNamespace
if not hasattr(args, "flush_cache"):
args.flush_cache = False
return asyncio.run( return asyncio.run(
benchmark( benchmark(
backend=backend, backend=backend,
...@@ -1362,6 +1373,7 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1362,6 +1373,7 @@ def run_benchmark(args_: argparse.Namespace):
extra_request_body=extra_request_body, extra_request_body=extra_request_body,
profile=args.profile, profile=args.profile,
pd_seperated=args.pd_seperated, pd_seperated=args.pd_seperated,
flush_cache=args.flush_cache,
) )
) )
...@@ -1543,6 +1555,11 @@ if __name__ == "__main__": ...@@ -1543,6 +1555,11 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="Benchmark PD disaggregation server", help="Benchmark PD disaggregation server",
) )
parser.add_argument(
"--flush-cache",
action="store_true",
help="Flush the cache before running the benchmark",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments") group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument( group.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment