Unverified Commit fd7926e4 authored by yichuan~, committed by GitHub

Fix prompt len in parallel sampling (#928)

parent 399cad91
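
With parallel sampling (request.n > 1), ret contains n generations per prompt, and each of them repeats the same meta_info["prompt_tokens"] value, so summing over all items counted every prompt n times in the usage report. The diff below instead sums prompt tokens over only the first entry of each group of n, while completion tokens are still summed over every entry. A minimal sketch of that accounting, outside the server code (the helper name and sample data are made up; it assumes ret lists the n samples of each prompt contiguously, as the change relies on):

def count_usage(ret, n):
    # Each of the n samples of a prompt carries the same prompt_tokens,
    # so read it from only the first sample of every group of n.
    prompt_tokens = sum(
        ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n)
    )
    # Completion tokens differ per sample, so sum over every item.
    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
    return prompt_tokens, completion_tokens

# Example: two prompts (5 and 7 tokens), n=2 samples each.
ret = [
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 3}},
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 4}},
    {"meta_info": {"prompt_tokens": 7, "completion_tokens": 2}},
    {"meta_info": {"prompt_tokens": 7, "completion_tokens": 6}},
]
assert count_usage(ret, 2) == (12, 15)  # previously reported 24 prompt tokens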
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
             choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
 
     if to_file:
         responses = []
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
@@ -45,11 +45,6 @@ class TestOpenAIServer(unittest.TestCase):
             prompt_arg = prompt_input
             num_choices = 1
 
-        if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
-
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
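
The removed test workaround (multiplying num_prompt_tokens by parallel_sample_num) is no longer needed once the server stops counting the prompt once per sample. A rough manual check of the intended behavior (hypothetical script, not part of the test suite; base URL, API key, and model name are placeholders for a locally running server):

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

single = client.completions.create(model="default", prompt="Hello", max_tokens=8, n=1)
parallel = client.completions.create(model="default", prompt="Hello", max_tokens=8, n=2)

# Before this change, the parallel request reported n times the prompt tokens.
assert parallel.usage.prompt_tokens == single.usage.prompt_tokens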