Commit 0d7db16a (unverified), authored by Abirdcfly, committed by GitHub
Browse files

[PD] add test for chat completions endpoint (#21925)


Signed-off-by: Abirdcfly <fp544037857@gmail.com>
parent 845420ac
...@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool: ...@@ -51,20 +51,31 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
return False return False
def run_simple_prompt(base_url: str, model_name: str, input_prompt: str,
                      use_chat_endpoint: bool = False) -> str:
    """Send a single prompt to an OpenAI-compatible server and return the text.

    Both endpoints are called with ``temperature=0.0``, ``seed=42`` and an
    output budget of ``MAX_OUTPUT_LEN`` tokens.

    Args:
        base_url: Base URL passed to the OpenAI client.
        model_name: Model identifier sent with the request.
        input_prompt: Prompt text. For the chat endpoint it is wrapped in a
            single ``user`` message with one ``text`` content part.
        use_chat_endpoint: When true, use the chat completions API; otherwise
            use the plain completions API. Defaults to ``False`` so callers
            written against the earlier three-argument signature keep working.

    Returns:
        The text of the first returned choice. For the chat endpoint an
        absent message content is returned as an empty string.
    """
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)

    if use_chat_endpoint:
        completion = client.chat.completions.create(
            model=model_name,
            messages=[{
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": input_prompt
                }]
            }],
            max_completion_tokens=MAX_OUTPUT_LEN,
            temperature=0.0,
            seed=42)
        # The chat message content is optional in the client types; coerce a
        # missing value to "" so the declared ``str`` return type holds.
        return completion.choices[0].message.content or ""

    completion = client.completions.create(model=model_name,
                                           prompt=input_prompt,
                                           max_tokens=MAX_OUTPUT_LEN,
                                           temperature=0.0,
                                           seed=42)
    return completion.choices[0].text
def main(): def main():
...@@ -125,10 +136,12 @@ def main(): ...@@ -125,10 +136,12 @@ def main():
f"vllm server: {args.service_url} is not ready yet!") f"vllm server: {args.service_url} is not ready yet!")
output_strs = dict() output_strs = dict()
for prompt in SAMPLE_PROMPTS: for i, prompt in enumerate(SAMPLE_PROMPTS):
use_chat_endpoint = (i % 2 == 1)
output_str = run_simple_prompt(base_url=service_url, output_str = run_simple_prompt(base_url=service_url,
model_name=args.model_name, model_name=args.model_name,
input_prompt=prompt) input_prompt=prompt,
use_chat_endpoint=use_chat_endpoint)
print(f"Prompt: {prompt}, output: {output_str}") print(f"Prompt: {prompt}, output: {output_str}")
output_strs[prompt] = output_str output_strs[prompt] = output_str
......
...@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str, ...@@ -162,6 +162,8 @@ async def send_request_to_service(client_info: dict, endpoint: str,
} }
req_data["stream"] = False req_data["stream"] = False
req_data["max_tokens"] = 1 req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data: if "stream_options" in req_data:
del req_data["stream_options"] del req_data["stream_options"]
headers = { headers = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment