Unverified commit 93dffd69, authored by zifeitong and committed by GitHub
Browse files

Add constrained_json_whitespace_pattern to ServerArgs (#1438)

parent 2abe4f1c
...@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache): ...@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict, tokenizer_args_dict,
enable=True, enable=True,
skip_tokenizer_init=False, skip_tokenizer_init=False,
constrained_json_whitespace_pattern=None,
): ):
super().__init__(enable=enable) super().__init__(enable=enable)
...@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache): ...@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
self.outlines_tokenizer.vocabulary = ( self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab() self.outlines_tokenizer.tokenizer.get_vocab()
) )
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
def init_value(self, key): def init_value(self, key):
key_type, key_string = key key_type, key_string = key
if key_type == "json": if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*") regex = build_regex_from_schema(
key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
)
elif key_type == "regex": elif key_type == "regex":
regex = key_string regex = key_string
else: else:
......
...@@ -198,6 +198,7 @@ class ModelTpServer: ...@@ -198,6 +198,7 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code, "trust_remote_code": server_args.trust_remote_code,
}, },
skip_tokenizer_init=server_args.skip_tokenizer_init, skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
) )
self.jump_forward_cache = JumpForwardCache() self.jump_forward_cache = JumpForwardCache()
...@@ -807,13 +808,11 @@ class ModelTpServer: ...@@ -807,13 +808,11 @@ class ModelTpServer:
unfinished_indices.append(i) unfinished_indices.append(i)
if req.finished() or ( if req.finished() or (
(
req.stream req.stream
and ( and (
self.decode_forward_ct % self.stream_interval == 0 self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1 or len(req.output_ids) == 1
) )
)
): ):
output_rids.append(req.rid) output_rids.append(req.rid)
output_finished_reason.append(req.finished_reason) output_finished_reason.append(req.finished_reason)
......
...@@ -70,6 +70,7 @@ class ServerArgs: ...@@ -70,6 +70,7 @@ class ServerArgs:
tp_size: int = 1 tp_size: int = 1
stream_interval: int = 1 stream_interval: int = 1
random_seed: Optional[int] = None random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
# Logging # Logging
log_level: str = "info" log_level: str = "info"
...@@ -370,6 +371,12 @@ class ServerArgs: ...@@ -370,6 +371,12 @@ class ServerArgs:
default=ServerArgs.random_seed, default=ServerArgs.random_seed,
help="The random seed.", help="The random seed.",
) )
parser.add_argument(
"--constrained-json-whitespace-pattern",
type=str,
default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
type=str, type=str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment