Unverified Commit bc636f21 authored by Clayton Coleman's avatar Clayton Coleman Committed by GitHub
Browse files

[Benchmark] Allow arbitrary headers to be passed to benchmarked endpoints (#23937)


Signed-off-by: default avatarClayton Coleman <smarterclayton@gmail.com>
parent 017354c0
......@@ -68,6 +68,7 @@ class RequestFuncInput:
model: str
model_name: Optional[str] = None
logprobs: Optional[int] = None
extra_headers: Optional[dict] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[Union[dict, list[dict]]] = None
ignore_eos: bool = False
......@@ -129,6 +130,8 @@ async def async_request_openai_completions(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
......@@ -258,6 +261,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
......@@ -364,6 +369,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
......
......@@ -389,6 +389,7 @@ async def benchmark(
goodput_config_dict: dict[str, float],
max_concurrency: Optional[int],
lora_modules: Optional[Iterable[str]],
extra_headers: Optional[dict],
extra_body: Optional[dict],
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
ramp_up_start_rps: Optional[int] = None,
......@@ -452,6 +453,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
)
......@@ -484,6 +486,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body)
profile_output = await request_func(
request_func_input=profile_input, session=session)
......@@ -568,6 +571,7 @@ async def benchmark(
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,)
tasks.append(
......@@ -815,6 +819,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
default="/v1/completions",
help="API endpoint.",
)
parser.add_argument(
"--header",
metavar="KEY=VALUE",
nargs="*",
help="Key-value pairs (e.g, --header x-additional-info=0.3.3) "
"for headers to be passed with each request. These headers override " \
"per backend constants and values set via environment variable, and " \
"will be overriden by other arguments (such as request ids)."
)
parser.add_argument(
"--max-concurrency",
type=int,
......@@ -1104,6 +1117,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
# Headers
headers = None
if args.header:
headers = {}
for item in args.header:
if "=" in item:
kvstring = item.split("=", 1)
headers[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
"Invalid header format. Please use KEY=VALUE format."
)
tokenizer = get_tokenizer(tokenizer_id,
tokenizer_mode=tokenizer_mode,
trust_remote_code=args.trust_remote_code)
......@@ -1161,6 +1187,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
goodput_config_dict=goodput_config_dict,
max_concurrency=args.max_concurrency,
lora_modules=args.lora_modules,
extra_headers=headers,
extra_body=sampling_params,
ramp_up_strategy=args.ramp_up_strategy,
ramp_up_start_rps=args.ramp_up_start_rps,
......@@ -1184,7 +1211,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if args.metadata:
for item in args.metadata:
if "=" in item:
kvstring = item.split("=")
kvstring = item.split("=", 1)
result_json[kvstring[0].strip()] = kvstring[1].strip()
else:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment