Unverified Commit fda6bb78 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

update bench_serving (#4958)

parent 23c764b1
......@@ -44,6 +44,12 @@ ASSISTANT_SUFFIX = "Assistant:"
global args
# don't want to import sglang package here
def _get_bool_env_var(name: str, default: str = "false") -> bool:
value = os.getenv(name, default)
return value.lower() in ("true", "1")
@dataclass
class RequestFuncInput:
prompt: str
......@@ -969,6 +975,7 @@ async def benchmark(
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
flush_cache: bool = False,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
......@@ -1012,7 +1019,7 @@ async def benchmark(
print("Initial test run completed. Starting main benchmark run...")
# Flush cache
if "sglang" in backend:
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
requests.post(base_url + "/flush_cache", headers=get_auth_headers())
time.sleep(1.0)
......@@ -1347,6 +1354,10 @@ def run_benchmark(args_: argparse.Namespace):
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
# compatible with SimpleNamespace
if not hasattr(args, "flush_cache"):
args.flush_cache = False
return asyncio.run(
benchmark(
backend=backend,
......@@ -1362,6 +1373,7 @@ def run_benchmark(args_: argparse.Namespace):
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
flush_cache=args.flush_cache,
)
)
......@@ -1543,6 +1555,11 @@ if __name__ == "__main__":
action="store_true",
help="Benchmark PD disaggregation server",
)
parser.add_argument(
"--flush-cache",
action="store_true",
help="Flush the cache before running the benchmark",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment