Unverified commit 48761171, authored by Ying Sheng, committed by GitHub

[Fix] fix eos trim inconsistency (#1650)

parent c3f2fc5a
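
This commit makes EOS/stop trimming consistent between the detokenizer's two output paths: the incremental token-id path and the decoded-string path now both go through a single trim_eos helper, and a new per-request no_eos_trim flag is threaded from the OpenAI-compatible completion API through SamplingParams, the Scheduler, and BatchTokenIDOut down to the DetokenizerManager. Setting no_eos_trim=True keeps the matched stop string or stop token in the output instead of trimming it.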
@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
+        if no_eos_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
@@ -122,7 +137,13 @@ class DetokenizerManager:
                 s = self.decode_status[rid]
                 s.decode_ids = recv_obj.decode_ids[i]
 
-            read_ids.append(s.decode_ids[s.surr_offset :])
+            read_ids.append(
+                self.trim_eos(
+                    s.decode_ids[s.surr_offset :],
+                    recv_obj.finished_reason[i],
+                    recv_obj.no_eos_trim[i],
+                )
+            )
             surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
         # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                 else:
                     new_text = find_printable_text(new_text)
-                output_strs.append(s.decoded_text + new_text)
-
-                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_eos_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
@@ -295,6 +295,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_eos_trim: List[bool]
 
 
 @dataclass
@@ -883,6 +883,7 @@ class Scheduler:
             output_read_offsets = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
+            output_no_eos_trim = []
         else:  # embedding or reward model
             output_embeddings = []
         unfinished_indices = []
@@ -914,6 +915,7 @@ class Scheduler:
                 output_spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
                 )
+                output_no_eos_trim.append(req.sampling_params.no_eos_trim)
 
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -961,6 +963,7 @@ class Scheduler:
                         output_spaces_between_special_tokens,
                         output_meta_info,
                         output_finished_reason,
+                        output_no_eos_trim,
                     )
                 )
         else:  # embedding or reward model
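
The scheduler-side changes keep BatchTokenIDOut columnar: one no_eos_trim entry is appended per finished request, index-aligned with finished_reason and the other per-request lists, so the detokenizer can read recv_obj.no_eos_trim[i] for request i. A toy illustration of that layout (the field subset here is invented for the example):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyBatchTokenIDOut:  # invented subset of BatchTokenIDOut
    rids: List[str]
    finished_reason: List[Optional[str]]
    no_eos_trim: List[bool]  # index-aligned with rids / finished_reason


batch = ToyBatchTokenIDOut(
    rids=["req-0", "req-1"],
    finished_reason=[None, "FINISH_MATCHED_TOKEN"],
    no_eos_trim=[False, True],
)
for i, rid in enumerate(batch.rids):
    print(rid, batch.finished_reason[i], batch.no_eos_trim[i])
```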
@@ -493,23 +493,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params_list.append(
-            {
-                "temperature": request.temperature,
-                "max_new_tokens": request.max_tokens,
-                "min_new_tokens": request.min_tokens,
-                "stop": request.stop,
-                "stop_token_ids": request.stop_token_ids,
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "repetition_penalty": request.repetition_penalty,
-                "regex": request.regex,
-                "json_schema": request.json_schema,
-                "n": request.n,
-                "ignore_eos": request.ignore_eos,
-            }
-        )
+        sampling_params = []
+        if isinstance(request.no_eos_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_eos_trim": (
+                        request.no_eos_trim
+                        if not isinstance(request.no_eos_trim, list)
+                        else request.no_eos_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)
 
     if len(all_requests) == 1:
         prompt = prompts[0]
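
On the API side, no_eos_trim accepts either a single bool or one bool per prompt; when it is a list, v1_generate_request builds one sampling-params dict per prompt. A usage sketch against a local sglang server (the URL, port, and model name are assumptions):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",  # assumed local endpoint
    json={
        "model": "default",  # assumed model name
        "prompt": ["Say hi.", "Count to three."],
        "max_tokens": 16,
        # One flag per prompt: keep the stop str / stop token for the
        # first completion, trim it for the second.
        "no_eos_trim": [True, False],
    },
)
print(resp.json())
```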
@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
     min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_eos_trim: Union[bool, List[bool]] = False
 
 
 class CompletionResponseChoice(BaseModel):
@@ -40,6 +40,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_eos_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_eos_trim = no_eos_trim
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
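
Engine-side, no_eos_trim defaults to False, so existing callers keep the old trimming behavior. Constructed directly it might look like this (module path assumed from the sglang layout at the time):

```python
from sglang.srt.sampling.sampling_params import SamplingParams

# Ask the detokenizer to keep the matched stop str / stop token.
params = SamplingParams(max_new_tokens=16, no_eos_trim=True)
```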
@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
         prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
     step_counter += 1
     return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass