Unverified Commit 2725f8da authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

[Minor] Rename no_eos_trim to no_stop_trim (#1661)

parent da1ffed6
...@@ -75,8 +75,8 @@ class DetokenizerManager: ...@@ -75,8 +75,8 @@ class DetokenizerManager:
self.decode_status = LimitedCapacityDict() self.decode_status = LimitedCapacityDict()
def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim): def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
if no_eos_trim: if no_stop_trim:
return output return output
# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
...@@ -141,7 +141,7 @@ class DetokenizerManager: ...@@ -141,7 +141,7 @@ class DetokenizerManager:
self.trim_eos( self.trim_eos(
s.decode_ids[s.surr_offset :], s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i], recv_obj.finished_reason[i],
recv_obj.no_eos_trim[i], recv_obj.no_stop_trim[i],
) )
) )
surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset]) surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
...@@ -177,7 +177,7 @@ class DetokenizerManager: ...@@ -177,7 +177,7 @@ class DetokenizerManager:
self.trim_eos( self.trim_eos(
s.decoded_text + new_text, s.decoded_text + new_text,
recv_obj.finished_reason[i], recv_obj.finished_reason[i],
recv_obj.no_eos_trim[i], recv_obj.no_stop_trim[i],
) )
) )
......
...@@ -295,7 +295,7 @@ class BatchTokenIDOut: ...@@ -295,7 +295,7 @@ class BatchTokenIDOut:
spaces_between_special_tokens: List[bool] spaces_between_special_tokens: List[bool]
meta_info: List[Dict] meta_info: List[Dict]
finished_reason: List[BaseFinishReason] finished_reason: List[BaseFinishReason]
no_eos_trim: List[bool] no_stop_trim: List[bool]
@dataclass @dataclass
......
...@@ -885,7 +885,7 @@ class Scheduler: ...@@ -885,7 +885,7 @@ class Scheduler:
output_read_offsets = [] output_read_offsets = []
output_skip_special_tokens = [] output_skip_special_tokens = []
output_spaces_between_special_tokens = [] output_spaces_between_special_tokens = []
output_no_eos_trim = [] output_no_stop_trim = []
else: # embedding or reward model else: # embedding or reward model
output_embeddings = [] output_embeddings = []
unfinished_indices = [] unfinished_indices = []
...@@ -917,7 +917,7 @@ class Scheduler: ...@@ -917,7 +917,7 @@ class Scheduler:
output_spaces_between_special_tokens.append( output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens req.sampling_params.spaces_between_special_tokens
) )
output_no_eos_trim.append(req.sampling_params.no_eos_trim) output_no_stop_trim.append(req.sampling_params.no_stop_trim)
meta_info = { meta_info = {
"prompt_tokens": len(req.origin_input_ids), "prompt_tokens": len(req.origin_input_ids),
...@@ -968,7 +968,7 @@ class Scheduler: ...@@ -968,7 +968,7 @@ class Scheduler:
output_spaces_between_special_tokens, output_spaces_between_special_tokens,
output_meta_info, output_meta_info,
output_finished_reason, output_finished_reason,
output_no_eos_trim, output_no_stop_trim,
) )
) )
else: # embedding or reward model else: # embedding or reward model
......
...@@ -494,7 +494,7 @@ def v1_generate_request( ...@@ -494,7 +494,7 @@ def v1_generate_request(
request.logprobs if request.logprobs is not None else 0 request.logprobs if request.logprobs is not None else 0
) )
sampling_params = [] sampling_params = []
if isinstance(request.no_eos_trim, list): if isinstance(request.no_stop_trim, list):
num_reqs = len(request.prompt) num_reqs = len(request.prompt)
else: else:
num_reqs = 1 num_reqs = 1
...@@ -514,10 +514,10 @@ def v1_generate_request( ...@@ -514,10 +514,10 @@ def v1_generate_request(
"json_schema": request.json_schema, "json_schema": request.json_schema,
"n": request.n, "n": request.n,
"ignore_eos": request.ignore_eos, "ignore_eos": request.ignore_eos,
"no_eos_trim": ( "no_stop_trim": (
request.no_eos_trim request.no_stop_trim
if not isinstance(request.no_eos_trim, list) if not isinstance(request.no_stop_trim, list)
else request.no_eos_trim[i] else request.no_stop_trim[i]
), ),
} }
) )
......
...@@ -174,7 +174,7 @@ class CompletionRequest(BaseModel): ...@@ -174,7 +174,7 @@ class CompletionRequest(BaseModel):
min_tokens: int = 0 min_tokens: int = 0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
no_eos_trim: Union[bool, List[bool]] = False no_stop_trim: Union[bool, List[bool]] = False
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):
......
...@@ -40,7 +40,7 @@ class SamplingParams: ...@@ -40,7 +40,7 @@ class SamplingParams:
regex: Optional[str] = None, regex: Optional[str] = None,
n: int = 1, n: int = 1,
json_schema: Optional[str] = None, json_schema: Optional[str] = None,
no_eos_trim: bool = False, no_stop_trim: bool = False,
) -> None: ) -> None:
self.temperature = temperature self.temperature = temperature
self.top_p = top_p self.top_p = top_p
...@@ -61,7 +61,7 @@ class SamplingParams: ...@@ -61,7 +61,7 @@ class SamplingParams:
self.regex = regex self.regex = regex
self.n = n self.n = n
self.json_schema = json_schema self.json_schema = json_schema
self.no_eos_trim = no_eos_trim self.no_stop_trim = no_stop_trim
# Process some special cases # Process some special cases
if self.temperature < _SAMPLING_EPS: if self.temperature < _SAMPLING_EPS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment