"src/vscode:/vscode.git/clone" did not exist on "d769d8a13b74684aa1725761ce25bf26ebc4fb5a"
Unverified Commit d53dcf9c authored by yichuan~'s avatar yichuan~ Committed by GitHub
Browse files

Support more OpenAI API test (#916)

parent bb66cc4c
......@@ -92,7 +92,7 @@ class GenerateReqInput:
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO cope with the case that the parallel_sample_num is different for different samples
# TODO cope with the case that the parallel_sample_num is different for different samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
......@@ -103,14 +103,19 @@ class GenerateReqInput:
if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
num = parallel_sample_num + 1
if isinstance(self.text, List):
## suppot batch operation
if isinstance(self.text, list):
# suppot batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
elif isinstance(self.input_ids, list) and isinstance(
self.input_ids[0], list
):
self.batch_size = len(self.input_ids)
num = num * len(self.input_ids)
else:
self.batch_size = 1
else:
## support select operation
# support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num
......
......@@ -153,8 +153,9 @@ class TokenizerManager:
async def _handle_single_request(
self, obj, request, index=None, is_cache_for_prefill=False
):
if not is_cache_for_prefill:
not_use_index = not (index is not None)
if not is_cache_for_prefill: # The normal case with a single prompt
not_use_index = index is None
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
input_ids = (
......@@ -182,14 +183,27 @@ class TokenizerManager:
top_logprobs_num = (
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
)
else:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else: # A prefill request to cache the common prompt for parallel sampling
if obj.text is not None:
if isinstance(obj.text, list):
input_text = obj.text[index]
rid = obj.rid[index]
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
input_text = None
if isinstance(obj.input_ids, list) and isinstance(
obj.input_ids[0], list
):
# when obj["input_ids"] is List[List[int]]
input_ids = obj.input_ids[index]
rid = obj.rid[index]
else:
input_ids = obj.input_ids
rid = obj.rid[0]
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
pixel_values, image_hash, image_size = await self._get_pixel_values(
......@@ -240,11 +254,11 @@ class TokenizerManager:
):
if input_id_result is not None:
input_id_result.append(input_id)
pass
if len(input_id_result) > 1 and input_id_result is not None:
if input_id_result is not None and len(input_id_result) > 1:
obj.input_ids = input_id_result
elif input_id_result is not None:
obj.input_ids = input_id_result[0]
# First send out all requests
for i in range(batch_size):
for j in range(parallel_sample_num):
......@@ -264,11 +278,12 @@ class TokenizerManager:
input_text = None
input_ids = obj.input_ids[i]
else:
assert obj.input_ids is not None
if batch_size == 1:
input_text = obj.text
input_text = None
input_ids = obj.input_ids
else:
input_text = obj.text[i]
input_text = None
input_ids = obj.input_ids[i]
sampling_params = self._get_sampling_params(obj.sampling_params[index])
pixel_values, image_hash, image_size = await self._get_pixel_values(
......
......@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
if end_point == "/v1/chat/completions":
responses = v1_chat_generate_response(request, ret, to_file=True)
else:
responses = v1_generate_response(request, ret, to_file=True)
responses = v1_generate_response(
request, ret, tokenizer_manager, to_file=True
)
except Exception as e:
error_json = {
......@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
return_logprobs = []
top_logprobs_nums = []
first_prompt_type = type(all_requests[0].prompt)
for request in all_requests:
prompt = request.prompt
assert (
......@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
)
if len(all_requests) > 1 and request.n > 1:
raise ValueError(
"Batch operation is not supported for completions from files"
"Parallel sampling is not supported for completions from files"
)
if len(all_requests) == 1:
......@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
else:
prompt_kwargs = {"input_ids": prompt}
else:
if isinstance(prompts[0], str):
if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": prompts}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params=sampling_params_list,
......@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
return_text_in_logprobs=True,
stream=all_requests[0].stream,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
return adapted_request, all_requests
def v1_generate_response(request, ret, to_file=False):
def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
choices = []
echo = False
if (not isinstance(request, List)) and request.echo:
if (not isinstance(request, list)) and request.echo:
# TODO: handle the case propmt is token ids
if isinstance(request.prompt, list):
if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
# for the case of multiple str prompts
prompts = request.prompt
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
# for the case of multiple token ids prompts
prompts = [
tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
for prompt in request.prompt
]
elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = [
tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
]
else:
# for the case of single str prompt
prompts = [request.prompt]
echo = True
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
if isinstance(request, List) and request[idx].echo:
if isinstance(request, list) and request[idx].echo:
echo = True
text = request[idx].prompt + text
if (not isinstance(request, List)) and echo:
text = prompts[idx] + text
if (not isinstance(request, list)) and echo:
prompt_index = idx // request.n
text = prompts[prompt_index] + text
logprobs = False
if isinstance(request, List) and request[idx].logprobs:
if isinstance(request, list) and request[idx].logprobs:
logprobs = True
elif (not isinstance(request, List)) and request.logprobs:
elif (not isinstance(request, list)) and request.logprobs:
logprobs = True
if logprobs:
if echo:
......@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
responses.append(response)
return responses
else:
prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
......@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if not stream_buffer: # The first chunk
if request.echo:
if isinstance(request.prompt, str):
# for the case of single str prompts
prompts = request.prompt
elif isinstance(request.prompt, list) and isinstance(
request.prompt[0], int
):
prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
# Prepend prompt in response text.
text = request.prompt + text
text = prompts + text
if request.logprobs:
# The first chunk and echo is enabled.
......@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
......@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if not isinstance(ret, list):
ret = [ret]
response = v1_generate_response(request, ret)
response = v1_generate_response(request, ret, tokenizer_manager)
return response
......@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt = request.messages
prompt_ids = request.messages
stop = request.stop
image_data = None
input_ids.append(prompt_ids)
......@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
image_data_list.append(image_data)
if len(all_requests) == 1:
input_ids = input_ids[0]
if isinstance(input_ids, str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
sampling_params_list = sampling_params_list[0]
image_data = image_data_list[0]
return_logprobs = return_logprobs[0]
top_logprobs_nums = top_logprobs_nums[0]
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput(
input_ids=input_ids,
**prompt_kwargs,
image_data=image_data,
sampling_params=sampling_params_list,
return_logprob=return_logprobs,
......@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
for idx, ret_item in enumerate(ret):
logprobs = False
if isinstance(request, List) and request[idx].logprobs:
if isinstance(request, list) and request[idx].logprobs:
logprobs = True
elif (not isinstance(request, List)) and request.logprobs:
elif (not isinstance(request, list)) and request.logprobs:
logprobs = True
if logprobs:
logprobs = to_openai_style_logprobs(
......@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
is_first = True
stream_buffer = ""
n_prev_token = 0
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if request.logprobs:
logprobs = to_openai_style_logprobs(
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
token_logprobs = []
for token, logprob in zip(
logprobs.tokens, logprobs.token_logprobs
):
token_bytes = list(token.encode("utf-8"))
top_logprobs = []
if logprobs.top_logprobs:
for top_token, top_logprob in logprobs.top_logprobs[
0
].items():
top_token_bytes = list(top_token.encode("utf-8"))
top_logprobs.append(
TopLogprob(
token=top_token,
bytes=top_token_bytes,
logprob=top_logprob,
)
)
token_logprobs.append(
ChatCompletionTokenLogprob(
token=token,
bytes=token_bytes,
logprob=logprob,
top_logprobs=top_logprobs,
)
)
choice_logprobs = ChoiceLogprobs(content=token_logprobs)
else:
choice_logprobs = None
if is_first:
# First chunk with role
is_first = False
......@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=content["meta_info"]["finish_reason"],
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.model_dump_json()}\n\n"
......@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
index=0,
delta=DeltaMessage(content=delta),
finish_reason=content["meta_info"]["finish_reason"],
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.model_dump_json()}\n\n"
except ValueError as e:
......
......@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[str] = None
......
......@@ -3,6 +3,7 @@ import unittest
import openai
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
......@@ -18,60 +19,85 @@ class TestOpenAIServer(unittest.TestCase):
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(MODEL_NAME_FOR_TEST)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_completion(self, echo, logprobs, use_list_input):
def run_completion(
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
if token_input:
prompt_input = self.tokenizer.encode(prompt)
num_prompt_tokens = len(prompt_input)
else:
prompt_input = prompt
num_prompt_tokens = len(self.tokenizer.encode(prompt))
if use_list_input:
prompt_arg = [prompt, prompt]
prompt_arg = [prompt_input, prompt_input]
num_choices = len(prompt_arg)
num_prompt_tokens *= 2
else:
prompt_arg = prompt
prompt_arg = prompt_input
num_choices = 1
if parallel_sample_num:
# FIXME: This is wrong. We should not count the prompt tokens multiple times for
# parallel sampling.
num_prompt_tokens *= parallel_sample_num
response = client.completions.create(
model=self.model,
prompt=prompt_arg,
temperature=0.1,
temperature=0,
max_tokens=32,
echo=echo,
logprobs=logprobs,
n=parallel_sample_num,
)
assert len(response.choices) == num_choices
assert len(response.choices) == num_choices * parallel_sample_num
if echo:
text = response.choices[0].text
assert text.startswith(prompt)
if logprobs:
assert response.choices[0].logprobs
assert isinstance(response.choices[0].logprobs.tokens[0], str)
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
# FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
if echo:
assert response.choices[0].logprobs.token_logprobs[0] == None
else:
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
assert (
response.usage.prompt_tokens == num_prompt_tokens
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def run_completion_stream(self, echo, logprobs):
def run_completion_stream(self, echo, logprobs, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
if token_input:
prompt_arg = self.tokenizer.encode(prompt)
else:
prompt_arg = prompt
generator = client.completions.create(
model=self.model,
prompt=prompt,
temperature=0.1,
prompt=prompt_arg,
temperature=0,
max_tokens=32,
echo=echo,
logprobs=logprobs,
......@@ -90,12 +116,15 @@ class TestOpenAIServer(unittest.TestCase):
ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0]
)
# FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
if first:
if echo:
assert response.choices[0].text.startswith(prompt)
assert response.choices[0].text.startswith(
prompt
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
first = False
assert response.id
......@@ -104,7 +133,7 @@ class TestOpenAIServer(unittest.TestCase):
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
def run_chat_completion(self, logprobs):
def run_chat_completion(self, logprobs, parallel_sample_num):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
......@@ -116,6 +145,7 @@ class TestOpenAIServer(unittest.TestCase):
max_tokens=32,
logprobs=logprobs is not None and logprobs > 0,
top_logprobs=logprobs,
n=parallel_sample_num,
)
if logprobs:
assert isinstance(
......@@ -128,7 +158,7 @@ class TestOpenAIServer(unittest.TestCase):
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert len(response.choices) == parallel_sample_num
assert response.choices[0].message.role == "assistant"
assert isinstance(response.choices[0].message.content, str)
assert response.id
......@@ -161,11 +191,21 @@ class TestOpenAIServer(unittest.TestCase):
continue
if logprobs:
# FIXME: Fix this bug. Return top_logprobs in the streaming mode.
pass
assert response.choices[0].logprobs
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
)
assert isinstance(
response.choices[0].logprobs.content[0].top_logprobs, list
)
ret_num_top_logprobs = len(
response.choices[0].logprobs.content[0].top_logprobs
)
assert (
ret_num_top_logprobs == logprobs
), f"{ret_num_top_logprobs} vs {logprobs}"
assert isinstance(data.content, str)
assert response.id
assert response.created
......@@ -173,16 +213,27 @@ class TestOpenAIServer(unittest.TestCase):
for echo in [False, True]:
for logprobs in [None, 5]:
for use_list_input in [True, False]:
self.run_completion(echo, logprobs, use_list_input)
for parallel_sample_num in [1, 2]:
for token_input in [False, True]:
self.run_completion(
echo,
logprobs,
use_list_input,
parallel_sample_num,
token_input,
)
def test_completion_stream(self):
# parallel sampling adn list input are not supported in streaming mode
for echo in [False, True]:
for logprobs in [None, 5]:
self.run_completion_stream(echo, logprobs)
for token_input in [False, True]:
self.run_completion_stream(echo, logprobs, token_input)
def test_chat_completion(self):
for logprobs in [None, 5]:
self.run_chat_completion(logprobs)
for parallel_sample_num in [1, 2]:
self.run_chat_completion(logprobs, parallel_sample_num)
def test_chat_completion_stream(self):
for logprobs in [None, 5]:
......@@ -224,5 +275,5 @@ if __name__ == "__main__":
# t = TestOpenAIServer()
# t.setUpClass()
# t.test_chat_completion_stream()
# t.test_completion()
# t.tearDownClass()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment