"examples/vscode:/vscode.git/clone" did not exist on "3efb5d8ecf7d748655e2199d120a40888ece2282"
Unverified commit c45e49d8 authored by Xinyuan Tong, committed by GitHub

oai: Adds support for OpenAI chat completions API in bench_serving (#7036)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: yhyang201 <47235274+yhyang201@users.noreply.github.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
parent d8053929
@@ -265,6 +265,138 @@ async def async_request_openai_completions(
    return output

async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Makes a request to the OpenAI Chat Completions API.

    Handles both streaming and non-streaming responses, including support
    for image data in messages. Calculates and returns various performance
    metrics.

    Args:
        request_func_input: Input parameters for the request.
        pbar: Optional tqdm progress bar to update.

    Returns:
        RequestFuncOutput: Output of the request, including generated text,
            latency, TTFT, ITL, and success status.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    if request_func_input.image_data:
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": request_func_input.image_data},
                    },
                    {"type": "text", "text": request_func_input.prompt},
                ],
            },
        ]
    else:
        messages = [{"role": "user", "content": request_func_input.prompt}]

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st

        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    if args.disable_stream:
                        # Non-streaming response
                        response_json = await response.json()
                        output.generated_text = response_json["choices"][0]["message"][
                            "content"
                        ]
                        output.success = True
                        output.latency = time.perf_counter() - st
                        output.ttft = (
                            output.latency
                        )  # For non-streaming, TTFT = total latency
                        output.output_len = response_json.get("usage", {}).get(
                            "completion_tokens", output_len
                        )
                    else:
                        # Streaming response
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                            latency = time.perf_counter() - st

                            if chunk == "[DONE]":
                                pass
                            else:
                                data = json.loads(chunk)

                                # Check if this chunk contains content
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")

                                if content:
                                    timestamp = time.perf_counter()
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft
                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    most_recent_timestamp = timestamp
                                    generated_text += content

                                # Check for usage info in final chunk
                                output_len = (data.get("usage") or {}).get(
                                    "completion_tokens", output_len
                                )

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output
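For context, the streaming branch above measures TTFT from the first content delta and ITL between subsequent deltas. Below is a minimal standalone sketch of that same measurement pattern, not part of the diff: BASE_URL, the placeholder model name, and the helper name measure_chat_request are illustrative, and it assumes an OpenAI-compatible server is already running.

```python
# Standalone sketch of the streaming TTFT/ITL pattern (assumptions noted above).
import asyncio
import json
import time

import aiohttp

BASE_URL = "http://localhost:30000/v1/chat/completions"  # assumed endpoint


async def measure_chat_request(prompt: str, max_tokens: int = 64):
    payload = {
        "model": "default",  # placeholder model name
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.0,
        "max_tokens": max_tokens,
        "stream": True,
    }
    ttft, itl, text = 0.0, [], ""
    st = last = time.perf_counter()
    async with aiohttp.ClientSession() as session:
        async with session.post(BASE_URL, json=payload) as resp:
            resp.raise_for_status()
            async for raw in resp.content:
                line = raw.strip().decode("utf-8")
                if not line or not line.startswith("data: "):
                    continue
                chunk = line[len("data: "):]
                if chunk == "[DONE]":
                    break
                delta = json.loads(chunk)["choices"][0].get("delta", {})
                content = delta.get("content", "")
                if content:
                    now = time.perf_counter()
                    if ttft == 0.0:
                        ttft = now - st  # time to first token
                    else:
                        itl.append(now - last)  # inter-token latency
                    last = now
                    text += content
    return ttft, itl, text


# asyncio.run(measure_chat_request("Hello"))
```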
async def async_request_truss(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,

@@ -544,6 +676,7 @@ def get_dataset(args, tokenizer):
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.random_output_len,
            apply_chat_template=args.apply_chat_template,
            random_sample=True,
        )
    else:

@@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_sglang_generate,
    "sglang-native": async_request_sglang_generate,
    "sglang-oai": async_request_openai_completions,
    "sglang-oai-chat": async_request_openai_chat_completions,
    "vllm": async_request_openai_completions,
    "vllm-chat": async_request_openai_chat_completions,
    "lmdeploy": async_request_openai_completions,
    "lmdeploy-chat": async_request_openai_chat_completions,
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
    "truss": async_request_truss,
@@ -661,6 +797,7 @@ def sample_mmmu_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    apply_chat_template: bool = True,
    random_sample: bool = True,
) -> List[DatasetRow]:
    """

@@ -670,6 +807,7 @@
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer to use for token counting.
        fixed_output_len: If provided, use this fixed output length for all requests.
        apply_chat_template: Whether to apply the chat template to the prompt.
        random_sample: Whether to randomly sample or take the first N.

    Returns:

@@ -739,28 +877,30 @@
            # Construct the prompt
            prompt = f"Question: {question}\n\nAnswer: "
            if apply_chat_template:
                try:
                    prompt = tokenizer.apply_chat_template(
                        [
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "image_url",
                                        "image_url": {"url": image_data},
                                    },
                                    {"type": "text", "text": prompt},
                                ],
                            }
                        ],
                        add_generation_prompt=True,
                        tokenize=False,
                    )
                except Exception as e:
                    # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
                    print(
                        f"Error applying chat template: {e}, fallback to <image> tag"
                    )
                    prompt = f"<image>{prompt}"

            # Calculate token lengths for text only (without image data)
            prompt_token_ids = tokenizer.encode(prompt)
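The template-with-`<image>`-fallback logic above can also be exercised on its own. Below is a small sketch of that pattern factored into a standalone helper; the helper name build_mm_prompt and the model name in the usage comment are illustrative, and any Hugging Face tokenizer with a chat template is assumed to work.

```python
# Sketch of the "apply chat template, fall back to <image> tag" pattern.
from transformers import AutoTokenizer


def build_mm_prompt(tokenizer, image_url: str, question: str) -> str:
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": question},
            ],
        }
    ]
    try:
        # Templates that accept list-form content render the multimodal message directly.
        return tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
    except Exception:
        # Some templates only accept string content (e.g. InternVL); fall back to an <image> tag.
        return f"<image>{question}"


# tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed model
# prompt = build_mm_prompt(tok, "https://example.com/cat.png", "Question: ...\n\nAnswer: ")
```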
@@ -1538,12 +1678,19 @@ def run_benchmark(args_: argparse.Namespace):
            if args.base_url
            else f"http://{args.host}:{args.port}/generate"
        )
        args.apply_chat_template = True
    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
        api_url = (
            f"{args.base_url}/v1/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/completions"
        )
    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
        api_url = (
            f"{args.base_url}/v1/chat/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/chat/completions"
        )
    elif args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
...
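Putting the routing together: the backend name now selects both the endpoint path and the request coroutine. Below is a hedged sketch of that selection; build_api_url is an illustrative helper (run_benchmark does this inline), and the host/port defaults are placeholders.

```python
# Sketch of how a "*-chat" backend is wired end to end (assumptions noted above).
from typing import Optional

CHAT_BACKENDS = {"sglang-oai-chat", "vllm-chat", "lmdeploy-chat"}


def build_api_url(
    backend: str,
    host: str = "127.0.0.1",
    port: int = 30000,
    base_url: Optional[str] = None,
) -> str:
    # Chat backends target /v1/chat/completions, plain OAI backends /v1/completions.
    suffix = "/v1/chat/completions" if backend in CHAT_BACKENDS else "/v1/completions"
    return f"{base_url}{suffix}" if base_url else f"http://{host}:{port}{suffix}"


# Example:
# build_api_url("vllm-chat")  -> "http://127.0.0.1:30000/v1/chat/completions"
# request_func = ASYNC_REQUEST_FUNCS["vllm-chat"]  # -> async_request_openai_chat_completions
```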