Unverified Commit 48761171 authored by Ying Sheng, committed by GitHub

[Fix] fix eos trim inconsistency (#1650)

parent c3f2fc5a
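
This change adds a per-request `no_eos_trim` option and applies stop-sequence trimming consistently to both the token-id stream and the decoded text (previously only the string output was trimmed). A minimal usage sketch against the OpenAI-compatible completions endpoint follows; the URL, port, and model name are placeholders for a locally launched server, and `no_eos_trim` accepts either a single bool or one bool per prompt (see the `CompletionRequest` change below):

```python
import requests

# Assumes a local sglang server, e.g. launched with
# `python -m sglang.launch_server --model-path <model> --port 30000`.
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",          # placeholder; use your served model name
        "prompt": ["Say hi.", "Say bye."],
        "max_tokens": 16,
        # One flag per prompt: keep the stop/EOS text on the first
        # completion, trim it (the default behavior) on the second.
        "no_eos_trim": [True, False],
    },
)
print(resp.json())
```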
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union

 import zmq
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
         self.decode_status = LimitedCapacityDict()

+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
+        if no_eos_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
@@ -122,7 +137,13 @@ class DetokenizerManager:
                 s = self.decode_status[rid]
                 s.decode_ids = recv_obj.decode_ids[i]

-            read_ids.append(s.decode_ids[s.surr_offset :])
+            read_ids.append(
+                self.trim_eos(
+                    s.decode_ids[s.surr_offset :],
+                    recv_obj.finished_reason[i],
+                    recv_obj.no_eos_trim[i],
+                )
+            )
             surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

         # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)

-            output_strs.append(s.decoded_text + new_text)
-
-            # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-            if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                if pos != -1:
-                    output_strs[i] = output_strs[i][:pos]
+            output_strs.append(
+                self.trim_eos(
+                    s.decoded_text + new_text,
+                    recv_obj.finished_reason[i],
+                    recv_obj.no_eos_trim[i],
+                )
+            )

         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
...
@@ -295,6 +295,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_eos_trim: List[bool]


 @dataclass
...
@@ -883,6 +883,7 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
+            output_no_eos_trim = []
         else:  # embedding or reward model
             output_embeddings = []
         unfinished_indices = []
@@ -914,6 +915,7 @@ class Scheduler:
                 output_spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
                 )
+                output_no_eos_trim.append(req.sampling_params.no_eos_trim)

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -961,6 +963,7 @@ class Scheduler:
                         output_spaces_between_special_tokens,
                         output_meta_info,
                         output_finished_reason,
+                        output_no_eos_trim,
                     )
                 )
         else:  # embedding or reward model
...
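
The scheduler-side changes follow the existing pattern for per-request options: the flag is read from each request's sampling params and shipped to the detokenizer as one more parallel list on `BatchTokenIDOut`, indexed by the same position `i` as the other per-request fields. A minimal illustration of that pattern, with a hypothetical stand-in dataclass:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class BatchOut:            # stand-in for BatchTokenIDOut
    rids: List[str]
    no_eos_trim: List[bool]

reqs = [("req-0", False), ("req-1", True)]
out = BatchOut(
    rids=[rid for rid, _ in reqs],
    no_eos_trim=[flag for _, flag in reqs],
)
assert out.no_eos_trim[1] is True  # the detokenizer reads the flag by index
```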
@@ -493,23 +493,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params_list.append(
-            {
-                "temperature": request.temperature,
-                "max_new_tokens": request.max_tokens,
-                "min_new_tokens": request.min_tokens,
-                "stop": request.stop,
-                "stop_token_ids": request.stop_token_ids,
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "repetition_penalty": request.repetition_penalty,
-                "regex": request.regex,
-                "json_schema": request.json_schema,
-                "n": request.n,
-                "ignore_eos": request.ignore_eos,
-            }
-        )
+        sampling_params = []
+        if isinstance(request.no_eos_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_eos_trim": (
+                        request.no_eos_trim
+                        if not isinstance(request.no_eos_trim, list)
+                        else request.no_eos_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
         prompt = prompts[0]
...
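
The adapter now fans a list-valued `no_eos_trim` out to one sampling-params dict per prompt, while a plain bool keeps the old single-dict shape. A condensed sketch of just that branching, with a plain dict standing in for the pydantic request:

```python
# Minimal sketch of the per-prompt fan-out above.
request = {
    "prompt": ["first prompt", "second prompt"],
    "no_eos_trim": [True, False],  # one flag per prompt
}

no_eos_trim = request["no_eos_trim"]
num_reqs = len(request["prompt"]) if isinstance(no_eos_trim, list) else 1

sampling_params = [
    {"no_eos_trim": no_eos_trim[i] if isinstance(no_eos_trim, list) else no_eos_trim}
    for i in range(num_reqs)
]
print(sampling_params)  # [{'no_eos_trim': True}, {'no_eos_trim': False}]
```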
@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
     min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_eos_trim: Union[bool, List[bool]] = False


 class CompletionResponseChoice(BaseModel):
...
@@ -40,6 +40,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_eos_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_eos_trim = no_eos_trim

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
...
@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
         prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
         step_counter += 1
     return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass