Unverified Commit fd7926e4 authored by yichuan~ and committed by GitHub

Fix prompt len in parallel sampling (#928)

parent 399cad91
@@ -500,7 +500,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             responses.append(response)
         return responses
     else:
-        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
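With parallel sampling (`n > 1`), the backend returns `n` generations per prompt, and each of them repeats that prompt's `prompt_tokens` in its `meta_info`. Summing over every item therefore inflates the prompt count by a factor of `n`; striding through `ret` with `request.n` counts each prompt once, while completion tokens are still summed over all samples. A minimal sketch of the accounting change, using hypothetical numbers (the `ret` layout follows the diff above; the values are illustrative only):

```python
# Hypothetical return data: two prompts, n=2 samples each -> 4 items in `ret`.
# Both samples of a prompt carry the same prompt_tokens value.
n = 2
ret = [
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 7}},
    {"meta_info": {"prompt_tokens": 5, "completion_tokens": 9}},
    {"meta_info": {"prompt_tokens": 8, "completion_tokens": 4}},
    {"meta_info": {"prompt_tokens": 8, "completion_tokens": 6}},
]

# Old accounting: every sample contributes its duplicated prompt count.
old_prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
assert old_prompt_tokens == 26  # (5 + 8) * 2 -> overcounted

# New accounting: take one sample per prompt group by striding with n.
prompt_tokens = sum(
    ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n)
)
assert prompt_tokens == 13  # 5 + 8, each prompt counted once

# Completions are distinct generations, so they are still summed over all items.
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
assert completion_tokens == 26
```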
@@ -707,8 +709,6 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
     for idx, ret_item in enumerate(ret):
         logprobs = False
@@ -747,8 +747,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
         if to_file:
             # to make the choice data json serializable
@@ -767,8 +765,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
             )
         choices.append(choice_data)
-        total_prompt_tokens += prompt_tokens
-        total_completion_tokens += completion_tokens
     if to_file:
         responses = []
@@ -795,14 +792,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
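The chat endpoint gets the same treatment: the per-choice accumulators are gone and usage is derived once from `ret` at the end. From a client's point of view, a request with parallel samples should now report the prompt only once in `usage.prompt_tokens`. A hedged example against an OpenAI-compatible endpoint; the base URL, API key, and model name are placeholders, not values taken from this commit:

```python
import openai

# Placeholder connection details for a locally running OpenAI-compatible server.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Say hi"}],
    n=2,  # two parallel samples for one prompt
)

# Two choices come back, but the prompt is tokenized only once, so
# usage.prompt_tokens should reflect a single copy of the prompt.
assert len(resp.choices) == 2
assert resp.usage.total_tokens == resp.usage.prompt_tokens + resp.usage.completion_tokens
```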
@@ -45,11 +45,6 @@ class TestOpenAIServer(unittest.TestCase):
             prompt_arg = prompt_input
             num_choices = 1
         if parallel_sample_num:
-            # FIXME: This is wrong. We should not count the prompt tokens multiple times for
-            # parallel sampling.
-            num_prompt_tokens *= parallel_sample_num
         response = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
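On the test side, the FIXME workaround that multiplied `num_prompt_tokens` by `parallel_sample_num` is no longer needed, since the server now reports the prompt once regardless of how many samples are requested. A rough sketch of the kind of check this enables; the variable names follow the test above, but the concrete assertions in the repository may differ:

```python
# Sketch only: `response`, `num_prompt_tokens`, and `parallel_sample_num` are
# assumed to be set up as in the test above.
num_choices = parallel_sample_num if parallel_sample_num else 1
assert len(response.choices) == num_choices

# Prompt tokens are no longer scaled by the number of parallel samples.
assert response.usage.prompt_tokens == num_prompt_tokens
assert response.usage.total_tokens == (
    response.usage.prompt_tokens + response.usage.completion_tokens
)
```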