Unverified Commit 183df472 authored by ZhouXingg's avatar ZhouXingg Committed by GitHub
Browse files

SamplingParams add "spaces_between_special_tokens" argument (#392)

parent 5c5aba59
......@@ -107,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
......@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
......@@ -145,6 +147,7 @@ class RuntimeEndpoint(BaseBackend):
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
......@@ -153,6 +156,7 @@ class RuntimeEndpoint(BaseBackend):
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
......
......@@ -12,6 +12,7 @@ class GlobalConfig:
# Output configs
self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True
# Optimization configs
self.eager_fill_image = False
......
......@@ -38,10 +38,11 @@ class DetokenizerManager:
if isinstance(recv_obj, BatchTokenIDOut):
output_tokens = recv_obj.output_tokens
# TODO(lmzheng): handle skip_special_tokens per request
# TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
output_strs = self.tokenizer.batch_decode(
output_tokens,
skip_special_tokens=recv_obj.skip_special_tokens[0],
spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
)
# Trim stop str
......
......@@ -97,6 +97,7 @@ class BatchTokenIDOut:
output_and_jump_forward_strs: List[str]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
meta_info: List[Dict]
finished: List[bool]
......
......@@ -549,6 +549,7 @@ class ModelRpcServer:
output_and_jump_forward_strs = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_meta_info = []
output_finished = []
finished_indices = []
......@@ -575,6 +576,9 @@ class ModelRpcServer:
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
meta_info = {
"prompt_tokens": req.prompt_tokens,
......@@ -609,6 +613,7 @@ class ModelRpcServer:
output_and_jump_forward_strs,
output_hit_stop_str,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished,
)
......
......@@ -17,6 +17,7 @@ class SamplingParams:
presence_penalty: float = 0.0,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
dtype: Optional[str] = None,
regex: Optional[str] = None,
) -> None:
......@@ -29,6 +30,7 @@ class SamplingParams:
self.max_new_tokens = max_new_tokens
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.dtype = dtype
self.regex = regex
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment