Unverified Commit 183df472 authored by ZhouXingg, committed by GitHub

SamplingParams add "spaces_between_special_tokens" argument (#392)

parent 5c5aba59
@@ -107,6 +107,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
@@ -145,6 +147,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 **sampling_params.to_srt_kwargs(),
             },
         }
@@ -153,6 +156,7 @@ class RuntimeEndpoint(BaseBackend):
             "text": s.text_,
             "sampling_params": {
                 "skip_special_tokens": global_config.skip_special_tokens_in_output,
+                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                 "dtype": "int",
                 **sampling_params.to_srt_kwargs(),
             },
...
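For readers of this change: the four hunks above all extend the request body that RuntimeEndpoint posts to the serving runtime, so the global output flags travel alongside the per-call sampling parameters. Below is a minimal, illustrative sketch of such a payload; the endpoint URL and the use of `requests` are assumptions for demonstration, not part of this commit.

```python
import requests

# Mirrors the "sampling_params" dict assembled in the hunks above.
payload = {
    "text": "The capital of France is",
    "sampling_params": {
        # In the backend code these values come from global_config.
        "skip_special_tokens": True,
        "spaces_between_special_tokens": True,
        # **sampling_params.to_srt_kwargs() would merge per-call settings here.
    },
}

# Hypothetical local runtime address; adjust to your deployment.
resp = requests.post("http://localhost:30000/generate", json=payload)
print(resp.json())
```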
@@ -12,6 +12,7 @@ class GlobalConfig:
         # Output configs
         self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True

         # Optimization configs
         self.eager_fill_image = False
...
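The new flag defaults to True, matching the previous decoding behavior. A caller who wants special tokens emitted verbatim, with no spaces inserted between them, can flip the two global switches; a small sketch, assuming the module-level `global_config` instance is importable as used in the backend code above:

```python
from sglang.global_config import global_config  # import path assumed

# Keep special tokens in the decoded output and do not insert spaces
# between consecutive special tokens (both flags default to True).
global_config.skip_special_tokens_in_output = False
global_config.spaces_between_special_tokens_in_out = False
```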
@@ -38,10 +38,11 @@ class DetokenizerManager:
         if isinstance(recv_obj, BatchTokenIDOut):
             output_tokens = recv_obj.output_tokens
-            # TODO(lmzheng): handle skip_special_tokens per request
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             output_strs = self.tokenizer.batch_decode(
                 output_tokens,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             # Trim stop str
...
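The per-batch flag is forwarded straight to Hugging Face's `batch_decode`, whose `spaces_between_special_tokens` keyword controls whether decoded special tokens are joined with spaces. A self-contained sketch of the difference, using a placeholder checkpoint and the slow tokenizer (which documents the keyword explicitly); the exact strings depend on the tokenizer:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint; any tokenizer with added special tokens behaves similarly.
tokenizer = AutoTokenizer.from_pretrained(
    "hf-internal-testing/llama-tokenizer", use_fast=False
)

# Two consecutive special tokens followed by ordinary text tokens.
ids = [tokenizer.bos_token_id, tokenizer.eos_token_id] + tokenizer(
    "Hello", add_special_tokens=False
)["input_ids"]

# spaces_between_special_tokens=True: special tokens are separated by spaces.
with_spaces = tokenizer.decode(
    ids, skip_special_tokens=False, spaces_between_special_tokens=True
)

# The option exposed per request by this commit: no extra spaces are inserted,
# which matters when special tokens must appear back to back (e.g. chat markers).
without_spaces = tokenizer.decode(
    ids, skip_special_tokens=False, spaces_between_special_tokens=False
)

print(repr(with_spaces))
print(repr(without_spaces))
```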
@@ -97,6 +97,7 @@ class BatchTokenIDOut:
     output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished: List[bool]
...
@@ -549,6 +549,7 @@ class ModelRpcServer:
         output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished = []
         finished_indices = []
@@ -575,6 +576,9 @@ class ModelRpcServer:
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
+                output_spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )

                 meta_info = {
                     "prompt_tokens": req.prompt_tokens,
@@ -609,6 +613,7 @@ class ModelRpcServer:
                 output_and_jump_forward_strs,
                 output_hit_stop_str,
                 output_skip_special_tokens,
+                output_spaces_between_special_tokens,
                 output_meta_info,
                 output_finished,
             )
...
@@ -17,6 +17,7 @@ class SamplingParams:
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
         dtype: Optional[str] = None,
         regex: Optional[str] = None,
     ) -> None:
@@ -29,6 +30,7 @@ class SamplingParams:
         self.max_new_tokens = max_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.dtype = dtype
         self.regex = regex
...
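Taken together, the argument can now be set per request. A hedged usage sketch built only from the constructor arguments visible in the hunks above (the import path is assumed from the srt module layout; other parameters keep their defaults):

```python
from sglang.srt.sampling_params import SamplingParams  # path assumed

# Per-request override: keep special tokens verbatim and do not insert
# spaces between consecutive special tokens during detokenization.
params = SamplingParams(
    max_new_tokens=64,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)
```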