Unverified Commit fbd6b94d authored by Ying Sheng, committed by GitHub

Fix the double BOS problem in the HF chat template (#888)

parent 4c8093c8
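
Background on the bug: when apply_chat_template is called with tokenize=False, the returned prompt string already contains the BOS token for templates that emit one; re-encoding that string with tokenizer.encode() then prepends a second BOS. The sketch below reproduces the issue with a standalone HF tokenizer; the model name is only illustrative, and the exact token ids depend on the tokenizer used.

# Minimal sketch of the double-BOS issue this commit fixes
# (model name is illustrative; any tokenizer whose chat template emits BOS shows it).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [{"role": "user", "content": "Hello"}]

# Old path: render to a string (which already starts with the BOS text "<s>"),
# then encode it. encode() adds BOS again, so the prompt begins with two BOS tokens.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
double_bos_ids = tokenizer.encode(prompt)

# New path: tokenize inside apply_chat_template, so BOS appears exactly once.
prompt_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)

print(double_bos_ids[:2])  # e.g. [1, 1] -> duplicated BOS
print(prompt_ids[:1])      # e.g. [1]    -> single BOS
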
@@ -594,7 +594,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 def v1_chat_generate_request(all_requests, tokenizer_manager):
-    texts = []
+    input_ids = []
     sampling_params_list = []
     image_data_list = []
     return_logprobs = []
@@ -608,8 +608,8 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
             if chat_template_name is None:
-                prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                    request.messages, tokenize=False, add_generation_prompt=True
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=True, add_generation_prompt=True
                 )
                 stop = request.stop
                 image_data = None
@@ -623,12 +623,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
                         stop.append(request.stop)
                     else:
                         stop.extend(request.stop)
+                prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
             prompt = request.messages
             stop = request.stop
             image_data = None
-        texts.append(prompt)
+        input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
         top_logprobs_nums.append(request.top_logprobs)
         sampling_params_list.append(
@@ -645,13 +646,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         )
         image_data_list.append(image_data)
     if len(all_requests) == 1:
-        texts = texts[0]
+        input_ids = input_ids[0]
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
     adapted_request = GenerateReqInput(
-        text=texts,
+        input_ids=input_ids,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,