Unverified Commit ec0a72c2 authored by Kebe's avatar Kebe Committed by GitHub
Browse files

Fix bench_serving not recognizing OPENAI_API_KEY (#3870)


Signed-off-by: Kebe <mail@kebe7jun.com>
parent 1c96fa86
...@@ -71,6 +71,14 @@ def remove_prefix(text: str, prefix: str) -> str: ...@@ -71,6 +71,14 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text return text[len(prefix) :] if text.startswith(prefix) else text
def get_auth_headers() -> Dict[str, str]:
    """Build HTTP Authorization headers from the OPENAI_API_KEY env var.

    Returns a Bearer-token header dict when the variable is set, otherwise
    an empty dict so unauthenticated endpoints keep working.
    """
    key = os.environ.get("OPENAI_API_KEY")
    return {"Authorization": f"Bearer {key}"} if key else {}
# trt llm not support ignore_eos # trt llm not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm( async def async_request_trt_llm(
...@@ -165,7 +173,7 @@ async def async_request_openai_completions( ...@@ -165,7 +173,7 @@ async def async_request_openai_completions(
"ignore_eos": not args.disable_ignore_eos, "ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body, **request_func_input.extra_request_body,
} }
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = get_auth_headers()
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -244,7 +252,7 @@ async def async_request_truss( ...@@ -244,7 +252,7 @@ async def async_request_truss(
"ignore_eos": not args.disable_ignore_eos, "ignore_eos": not args.disable_ignore_eos,
**request_func_input.extra_request_body, **request_func_input.extra_request_body,
} }
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} headers = get_auth_headers()
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -325,7 +333,7 @@ async def async_request_sglang_generate( ...@@ -325,7 +333,7 @@ async def async_request_sglang_generate(
"logprob_start_len": -1, "logprob_start_len": -1,
**request_func_input.extra_request_body, **request_func_input.extra_request_body,
} }
headers = {} headers = get_auth_headers()
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
...@@ -1238,7 +1246,7 @@ def run_benchmark(args_: argparse.Namespace): ...@@ -1238,7 +1246,7 @@ def run_benchmark(args_: argparse.Namespace):
) )
sys.exit(1) sys.exit(1)
try: try:
response = requests.get(model_url) response = requests.get(model_url, headers=get_auth_headers())
model_list = response.json().get("data", []) model_list = response.json().get("data", [])
args.model = model_list[0]["id"] if model_list else None args.model = model_list[0]["id"] if model_list else None
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment