Unverified Commit d53dcf9c, authored by yichuan~ and committed by GitHub

Support more OpenAI API test (#916)

parent bb66cc4c
@@ -92,7 +92,7 @@ class GenerateReqInput:
                 for element in parallel_sample_num_list
             )
             if parallel_sample_num > 1 and (not all_equal):
-                ## TODO cope with the case that the parallel_sample_num is different for different samples
+                # TODO: cope with the case where parallel_sample_num differs across samples
                 raise ValueError(
                     "The parallel_sample_num should be the same for all samples in sample params."
                 )
@@ -103,14 +103,19 @@ class GenerateReqInput:
             if parallel_sample_num != 1:
                 # parallel sampling +1 represents the original prefill stage
                 num = parallel_sample_num + 1
-                if isinstance(self.text, List):
-                    ## suppot batch operation
+                if isinstance(self.text, list):
+                    # support batch operation
                     self.batch_size = len(self.text)
                     num = num * len(self.text)
+                elif isinstance(self.input_ids, list) and isinstance(
+                    self.input_ids[0], list
+                ):
+                    self.batch_size = len(self.input_ids)
+                    num = num * len(self.input_ids)
                 else:
                     self.batch_size = 1
             else:
-                ## support select operation
+                # support select operation
                 num = len(self.text) if self.text is not None else len(self.input_ids)
                 self.batch_size = num
......
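For context, the hunk above derives how many internal requests a single API call fans out into. Below is a minimal sketch of that arithmetic with illustrative names only (it is not the project's actual class and assumes just text/input_ids inputs plus a fixed parallel_sample_num):

# Hypothetical helper mirroring the bookkeeping above; names are illustrative only.
from typing import List, Optional, Tuple, Union


def count_requests(
    text: Optional[Union[str, List[str]]],
    input_ids: Optional[Union[List[int], List[List[int]]]],
    parallel_sample_num: int,
) -> Tuple[int, int]:
    """Return (batch_size, number_of_internal_requests)."""
    if parallel_sample_num != 1:
        # +1 accounts for the extra prefill-only request that caches the shared prompt.
        num = parallel_sample_num + 1
        if isinstance(text, list):
            batch_size = len(text)
            num *= len(text)
        elif isinstance(input_ids, list) and isinstance(input_ids[0], list):
            batch_size = len(input_ids)
            num *= len(input_ids)
        else:
            batch_size = 1
    else:
        # "Select" style requests: one sub-request per candidate.
        num = len(text) if text is not None else len(input_ids)
        batch_size = num
    return batch_size, num


# Two text prompts with n=2 parallel samples -> 2 * (2 + 1) = 6 internal requests.
assert count_requests(["a", "b"], None, 2) == (2, 6)
# Two token-id prompts with n=3 -> 2 * (3 + 1) = 8 internal requests.
assert count_requests(None, [[1, 2], [3, 4]], 3) == (2, 8)

The extra "+1" request per prompt is the prefill-only request that caches the shared prefix before the parallel samples run, as the comments in the hunk above state.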
@@ -153,8 +153,9 @@ class TokenizerManager:
     async def _handle_single_request(
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
-        if not is_cache_for_prefill:
-            not_use_index = not (index is not None)
+        if not is_cache_for_prefill:  # The normal case with a single prompt
+            not_use_index = index is None
+
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
@@ -182,14 +183,27 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
            )
-        else:
-            if isinstance(obj.text, list):
-                input_text = obj.text[index]
-                rid = obj.rid[index]
-            else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
+        else:  # A prefill request to cache the common prompt for parallel sampling
+            if obj.text is not None:
+                if isinstance(obj.text, list):
+                    input_text = obj.text[index]
+                    rid = obj.rid[index]
+                else:
+                    input_text = obj.text
+                    rid = obj.rid[0]
+                input_ids = self.tokenizer.encode(input_text)
+            else:
+                input_text = None
+                if isinstance(obj.input_ids, list) and isinstance(
+                    obj.input_ids[0], list
+                ):
+                    # when obj.input_ids is List[List[int]]
+                    input_ids = obj.input_ids[index]
+                    rid = obj.rid[index]
+                else:
+                    input_ids = obj.input_ids
+                    rid = obj.rid[0]
         sampling_params = SamplingParams(**obj.sampling_params[0])
         sampling_params.max_new_tokens = 0
         pixel_values, image_hash, image_size = await self._get_pixel_values(
@@ -240,11 +254,11 @@ class TokenizerManager:
             ):
                 if input_id_result is not None:
                     input_id_result.append(input_id)
-                    pass
-            if len(input_id_result) > 1 and input_id_result is not None:
+            if input_id_result is not None and len(input_id_result) > 1:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
             # First send out all requests
             for i in range(batch_size):
                 for j in range(parallel_sample_num):
@@ -264,11 +278,12 @@ class TokenizerManager:
                         input_text = None
                         input_ids = obj.input_ids[i]
                     else:
+                        assert obj.input_ids is not None
                         if batch_size == 1:
-                            input_text = obj.text
+                            input_text = None
                             input_ids = obj.input_ids
                         else:
-                            input_text = obj.text[i]
+                            input_text = None
                             input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
                 pixel_values, image_hash, image_size = await self._get_pixel_values(
......
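The nested branching added to _handle_single_request is easier to read outside the diff. Here is a hedged, standalone sketch of the same prompt selection for the prefill-cache request, using hypothetical names and a stub tokenizer in place of self.tokenizer:

from typing import List, Optional, Tuple, Union


class StubTokenizer:
    """Placeholder for the real HF tokenizer; encodes by splitting on spaces."""

    def encode(self, text: str) -> List[int]:
        return list(range(len(text.split())))


def select_prefill_input(
    text: Optional[Union[str, List[str]]],
    input_ids: Optional[Union[List[int], List[List[int]]]],
    rid: List[str],
    index: int,
    tokenizer: StubTokenizer,
) -> Tuple[Optional[str], List[int], str]:
    """Pick the prompt (text or token ids) and request id for one prefill-cache request."""
    if text is not None:
        # Text input: either a list of prompts (batched) or one shared string.
        if isinstance(text, list):
            input_text, req_id = text[index], rid[index]
        else:
            input_text, req_id = text, rid[0]
        return input_text, tokenizer.encode(input_text), req_id
    # Token-id input: List[List[int]] means one prompt per batch element.
    if isinstance(input_ids, list) and isinstance(input_ids[0], list):
        return None, input_ids[index], rid[index]
    return None, input_ids, rid[0]


tok = StubTokenizer()
assert select_prefill_input(None, [[1, 2], [3, 4]], ["r0", "r1"], 1, tok) == (None, [3, 4], "r1")
assert select_prefill_input("hello world", None, ["r0"], 0, tok) == ("hello world", [0, 1], "r0")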
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             if end_point == "/v1/chat/completions":
                 responses = v1_chat_generate_response(request, ret, to_file=True)
             else:
-                responses = v1_generate_response(request, ret, to_file=True)
+                responses = v1_generate_response(
+                    request, ret, tokenizer_manager, to_file=True
+                )
     except Exception as e:
         error_json = {
@@ -339,6 +341,7 @@ def v1_generate_request(all_requests):
     return_logprobs = []
     top_logprobs_nums = []
+    first_prompt_type = type(all_requests[0].prompt)

     for request in all_requests:
         prompt = request.prompt
         assert (
@@ -364,7 +367,7 @@ def v1_generate_request(all_requests):
         )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "Batch operation is not supported for completions from files"
+                "Parallel sampling is not supported for completions from files"
             )

     if len(all_requests) == 1:
@@ -377,10 +380,11 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests):
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )

     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests


-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False

-    if (not isinstance(request, List)) and request.echo:
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True

     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]

-        if isinstance(request, List) and request[idx].echo:
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request, List)) and echo:
-            text = prompts[idx] + text
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text

         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
     return response
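Two ideas in the hunks above are worth spelling out: token-id prompts are decoded back to text before echoing, and with n parallel samples per prompt, choice idx belongs to prompt idx // request.n. A small sketch under those assumptions, with a stub tokenizer and a hypothetical helper name:

from typing import List


class StubTokenizer:
    """Stand-in for tokenizer_manager.tokenizer; decodes ids into letter 'words'."""

    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        return " ".join(chr(ord("a") + i) for i in ids)


def echo_prompts(prompt, tokenizer: StubTokenizer) -> List[str]:
    """Normalize request.prompt into a list of prompt strings, decoding token ids."""
    if isinstance(prompt, list) and isinstance(prompt[0], str):
        return prompt                                 # multiple string prompts
    if isinstance(prompt, list) and isinstance(prompt[0], list):
        return [tokenizer.decode(p) for p in prompt]  # multiple token-id prompts
    if isinstance(prompt, list) and isinstance(prompt[0], int):
        return [tokenizer.decode(prompt)]             # single token-id prompt
    return [prompt]                                   # single string prompt


tok = StubTokenizer()
prompts = echo_prompts([[0, 1], [2, 3]], tok)         # -> ["a b", "c d"]
n = 2                                                 # parallel samples per prompt
completions = ["!", "?", ".", ","]                    # ret items: 2 prompts * n samples
echoed = [prompts[idx // n] + completions[idx] for idx in range(len(completions))]
assert echoed == ["a b!", "a b?", "c d.", "c d,"]

The usage change follows the same reasoning: prompt_tokens is now summed over all returned items instead of taken from the first one, so batched and parallel-sampled requests report aggregate totals.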
@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 if not stream_buffer:  # The first chunk
                     if request.echo:
+                        if isinstance(request.prompt, str):
+                            # for the case of single str prompts
+                            prompts = request.prompt
+                        elif isinstance(request.prompt, list) and isinstance(
+                            request.prompt[0], int
+                        ):
+                            prompts = tokenizer_manager.tokenizer.decode(
+                                request.prompt, skip_special_tokens=True
+                            )
                         # Prepend prompt in response text.
-                        text = request.prompt + text
+                        text = prompts + text

                     if request.logprobs:
                         # The first chunk and echo is enabled.
@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                 "output_top_logprobs"
                             ][n_prev_token:],
                         )
-
                         n_prev_token = len(
                             content["meta_info"]["output_token_logprobs"]
                         )
@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)
     return response
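The streaming path keeps a running n_prev_token cursor; the slicing above implies the reported logprob arrays are cumulative, so the cursor extracts only the newly generated tokens per chunk. An illustrative, self-contained sketch of that pattern (hypothetical helper name, fake chunk payloads):

from typing import Dict, Iterable, List


def stream_new_logprobs(chunks: Iterable[Dict]) -> List[list]:
    """Return, per chunk, only the token logprobs that earlier chunks have not emitted."""
    emitted = []
    n_prev_token = 0
    for content in chunks:
        token_logprobs = content["meta_info"]["output_token_logprobs"]
        new = token_logprobs[n_prev_token:]   # only the delta for this chunk
        n_prev_token = len(token_logprobs)    # remember how much was already sent
        emitted.append(new)
    return emitted


# The arrays grow across chunks; the cursor turns them into per-chunk deltas.
chunks = [
    {"meta_info": {"output_token_logprobs": [(-0.1, 11)]}},
    {"meta_info": {"output_token_logprobs": [(-0.1, 11), (-0.5, 12), (-0.2, 13)]}},
]
assert stream_new_logprobs(chunks) == [[(-0.1, 11)], [(-0.5, 12), (-0.2, 13)]]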
@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-            prompt = request.messages
+            prompt_ids = request.messages
             stop = request.stop
             image_data = None
         input_ids.append(prompt_ids)
@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
+
     adapted_request = GenerateReqInput(
-        input_ids=input_ids,
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):
     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request, List) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request, List)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(
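The prompt_kwargs selection added to v1_chat_generate_request boils down to one question: did the chat template produce a string, or are the messages already token ids? A condensed sketch with illustrative names:

from typing import Dict, List, Union


def choose_prompt_kwargs(
    input_ids: Union[str, List[int], List[str], List[List[int]]],
    single_request: bool,
) -> Dict:
    """Map the collected prompts onto the GenerateReqInput keyword matching their type."""
    if single_request:
        # One request: input_ids is either a rendered prompt string or a token-id list.
        return {"text": input_ids} if isinstance(input_ids, str) else {"input_ids": input_ids}
    # Batched requests: inspect the first element to decide for the whole batch.
    return {"text": input_ids} if isinstance(input_ids[0], str) else {"input_ids": input_ids}


assert choose_prompt_kwargs("Hello!", True) == {"text": "Hello!"}
assert choose_prompt_kwargs([[1, 2], [3, 4]], False) == {"input_ids": [[1, 2], [3, 4]]}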
@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
             stream_buffer = ""
+            n_prev_token = 0

             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    if request.logprobs:
+                        logprobs = to_openai_style_logprobs(
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
+                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
+                            ][n_prev_token:],
+                        )
+                        n_prev_token = len(
+                            content["meta_info"]["output_token_logprobs"]
+                        )
+                        token_logprobs = []
+                        for token, logprob in zip(
+                            logprobs.tokens, logprobs.token_logprobs
+                        ):
+                            token_bytes = list(token.encode("utf-8"))
+                            top_logprobs = []
+                            if logprobs.top_logprobs:
+                                for top_token, top_logprob in logprobs.top_logprobs[
+                                    0
+                                ].items():
+                                    top_token_bytes = list(top_token.encode("utf-8"))
+                                    top_logprobs.append(
+                                        TopLogprob(
+                                            token=top_token,
+                                            bytes=top_token_bytes,
+                                            logprob=top_logprob,
+                                        )
+                                    )
+                            token_logprobs.append(
+                                ChatCompletionTokenLogprob(
+                                    token=token,
+                                    bytes=token_bytes,
+                                    logprob=logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+                    else:
+                        choice_logprobs = None

                     if is_first:
                         # First chunk with role
                         is_first = False
@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
+                            logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
+                            usage=UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"
@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
+                        logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
             except ValueError as e:
......
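The streaming chat handler above converts flat logprob arrays into the OpenAI-style per-token structure. Here is a self-contained sketch of that conversion using plain dictionaries in place of the project's TopLogprob/ChatCompletionTokenLogprob/ChoiceLogprobs models; note it indexes the top-logprob map per token, whereas the streamed code above indexes position 0 because each chunk typically carries a single new token:

from typing import Dict, List, Optional


def to_choice_logprobs(
    tokens: List[str],
    token_logprobs: List[float],
    top_logprobs: Optional[List[Dict[str, float]]] = None,
) -> Dict:
    """Convert flat logprob arrays into an OpenAI chat-style `logprobs.content` payload."""
    content = []
    for i, (token, logprob) in enumerate(zip(tokens, token_logprobs)):
        tops = []
        if top_logprobs:
            for top_token, top_logprob in top_logprobs[i].items():
                tops.append(
                    {
                        "token": top_token,
                        "bytes": list(top_token.encode("utf-8")),
                        "logprob": top_logprob,
                    }
                )
        content.append(
            {
                "token": token,
                "bytes": list(token.encode("utf-8")),
                "logprob": logprob,
                "top_logprobs": tops,
            }
        )
    return {"content": content}


out = to_choice_logprobs(["Hi"], [-0.2], [{"Hi": -0.2, "Hello": -1.3}])
assert out["content"][0]["top_logprobs"][1]["token"] == "Hello"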
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):

 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
......
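For reference, a minimal pydantic sketch of the widened field; the LogProbs and ChoiceLogprobs bodies here are placeholders, not the real protocol models:

from typing import List, Optional, Union

from pydantic import BaseModel


class LogProbs(BaseModel):
    # Placeholder for the completions-style logprobs model.
    tokens: List[str] = []
    token_logprobs: List[Optional[float]] = []


class ChoiceLogprobs(BaseModel):
    # Placeholder for the chat-style logprobs model (list of per-token entries).
    content: List[dict] = []


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    finish_reason: Optional[str] = None
    # The stream choice now accepts either logprobs flavor.
    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None


choice = ChatCompletionResponseStreamChoice(index=0, logprobs=ChoiceLogprobs(content=[]))
assert choice.index == 0 and choice.logprobs is not None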
@@ -3,6 +3,7 @@ import unittest

 import openai

+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

@@ -18,60 +19,85 @@ class TestOpenAIServer(unittest.TestCase):
             cls.model, cls.base_url, timeout=300, api_key=cls.api_key
         )
         cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
         kill_child_process(cls.process.pid)

-    def run_completion(self, echo, logprobs, use_list_input):
+    def run_completion(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
+        if token_input:
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
+        else:
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))

         if use_list_input:
-            prompt_arg = [prompt, prompt]
+            prompt_arg = [prompt_input, prompt_input]
             num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
         else:
-            prompt_arg = prompt
+            prompt_arg = prompt_input
             num_choices = 1

+        if parallel_sample_num:
+            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
+            # parallel sampling.
+            num_prompt_tokens *= parallel_sample_num

         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
-            temperature=0.1,
+            temperature=0,
             max_tokens=32,
             echo=echo,
             logprobs=logprobs,
+            n=parallel_sample_num,
         )

-        assert len(response.choices) == num_choices
+        assert len(response.choices) == num_choices * parallel_sample_num

         if echo:
             text = response.choices[0].text
             assert text.startswith(prompt)

         if logprobs:
             assert response.choices[0].logprobs
             assert isinstance(response.choices[0].logprobs.tokens[0], str)
             assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
             ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
-            # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+            # FIXME: Sometimes some top_logprobs are missing in the return value, because several output ids can map to the same output token and collide in the map.
             # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+            assert ret_num_top_logprobs > 0

             if echo:
                 assert response.choices[0].logprobs.token_logprobs[0] == None
             else:
                 assert response.choices[0].logprobs.token_logprobs[0] != None

         assert response.id
         assert response.created
-        assert response.usage.prompt_tokens > 0
+        assert (
+            response.usage.prompt_tokens == num_prompt_tokens
+        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs):
+    def run_completion_stream(self, echo, logprobs, token_input):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
+        if token_input:
+            prompt_arg = self.tokenizer.encode(prompt)
+        else:
+            prompt_arg = prompt
         generator = client.completions.create(
             model=self.model,
-            prompt=prompt,
-            temperature=0.1,
+            prompt=prompt_arg,
+            temperature=0,
             max_tokens=32,
             echo=echo,
             logprobs=logprobs,
@@ -90,12 +116,15 @@ class TestOpenAIServer(unittest.TestCase):
                     ret_num_top_logprobs = len(
                         response.choices[0].logprobs.top_logprobs[0]
                     )
-                    # FIXME: Fix this bug. Sometimes, some top_logprobs are missing in the return value.
+                    # FIXME: Sometimes some top_logprobs are missing in the return value, because several output ids can map to the same output token and collide in the map.
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
+                    assert ret_num_top_logprobs > 0

             if first:
                 if echo:
-                    assert response.choices[0].text.startswith(prompt)
+                    assert response.choices[0].text.startswith(
+                        prompt
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
                 first = False

         assert response.id
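The new prompt_tokens assertion in run_completion above rests on simple bookkeeping: the tokenized prompt length, doubled for list input, and (as the FIXME notes, arguably incorrectly) multiplied again by the number of parallel samples. A short sketch of that expectation with an illustrative helper:

def expected_prompt_tokens(prompt_len: int, use_list_input: bool, parallel_sample_num: int) -> int:
    """Mirror the test's bookkeeping for response.usage.prompt_tokens."""
    num = prompt_len
    if use_list_input:
        num *= 2  # the prompt is sent twice
    if parallel_sample_num:
        # Matches the current server behavior flagged by the FIXME above:
        # prompt tokens are counted once per parallel sample.
        num *= parallel_sample_num
    return num


# 6 prompt tokens, list input, n=2 -> 6 * 2 * 2 = 24 reported prompt tokens.
assert expected_prompt_tokens(6, True, 2) == 24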
@@ -104,7 +133,7 @@
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion(self, logprobs):
+    def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
@@ -116,6 +145,7 @@
             max_tokens=32,
             logprobs=logprobs is not None and logprobs > 0,
             top_logprobs=logprobs,
+            n=parallel_sample_num,
         )

         if logprobs:
             assert isinstance(
@@ -128,7 +158,7 @@
             assert (
                 ret_num_top_logprobs == logprobs
             ), f"{ret_num_top_logprobs} vs {logprobs}"
-
+        assert len(response.choices) == parallel_sample_num
         assert response.choices[0].message.role == "assistant"
         assert isinstance(response.choices[0].message.content, str)
         assert response.id
@@ -161,11 +191,21 @@
                     continue

                 if logprobs:
-                    # FIXME: Fix this bug. Return top_logprobs in the streaming mode.
-                    pass
+                    assert response.choices[0].logprobs
+                    assert isinstance(
+                        response.choices[0].logprobs.content[0].top_logprobs[0].token, str
+                    )
+                    assert isinstance(
+                        response.choices[0].logprobs.content[0].top_logprobs, list
+                    )
+                    ret_num_top_logprobs = len(
+                        response.choices[0].logprobs.content[0].top_logprobs
+                    )
+                    assert (
+                        ret_num_top_logprobs == logprobs
+                    ), f"{ret_num_top_logprobs} vs {logprobs}"

                 assert isinstance(data.content, str)
             assert response.id
             assert response.created
@@ -173,16 +213,27 @@
         for echo in [False, True]:
             for logprobs in [None, 5]:
                 for use_list_input in [True, False]:
-                    self.run_completion(echo, logprobs, use_list_input)
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_completion_stream(self):
+        # Parallel sampling and list input are not supported in streaming mode.
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                self.run_completion_stream(echo, logprobs)
+                for token_input in [False, True]:
+                    self.run_completion_stream(echo, logprobs, token_input)

     def test_chat_completion(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion(logprobs, parallel_sample_num)

     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
@@ -224,5 +275,5 @@ if __name__ == "__main__":
     # t = TestOpenAIServer()
     # t.setUpClass()
-    # t.test_chat_completion_stream()
+    # t.test_completion()
     # t.tearDownClass()