Unverified commit 93dffd69, authored by zifeitong and committed by GitHub
Browse files

Add constrained_json_whitespace_pattern to ServerArgs (#1438)

parent 2abe4f1c
......@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
constrained_json_whitespace_pattern=None,
):
super().__init__(enable=enable)
......@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
regex = build_regex_from_schema(
key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
)
elif key_type == "regex":
regex = key_string
else:
......
......@@ -198,6 +198,7 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.jump_forward_cache = JumpForwardCache()
......@@ -807,12 +808,10 @@ class ModelTpServer:
unfinished_indices.append(i)
if req.finished() or (
(
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
):
output_rids.append(req.rid)
......
......@@ -70,6 +70,7 @@ class ServerArgs:
tp_size: int = 1
stream_interval: int = 1
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
# Logging
log_level: str = "info"
......@@ -370,6 +371,12 @@ class ServerArgs:
default=ServerArgs.random_seed,
help="The random seed.",
)
parser.add_argument(
"--constrained-json-whitespace-pattern",
type=str,
default=ServerArgs.constrained_json_whitespace_pattern,
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
)
parser.add_argument(
"--log-level",
type=str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment