Unverified commit c3845d82, authored by Robert Caulk and committed by GitHub
Browse files

Allow user to define whitespace pattern for outlines (#4305)

parent a822eb34
...@@ -57,7 +57,9 @@ def test_guided_logits_processors(): ...@@ -57,7 +57,9 @@ def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) json_LP = JSONLogitsProcessor(TEST_SCHEMA,
tokenizer,
whitespace_pattern=None)
regex_LP.init_state() regex_LP.init_state()
token_ids = tokenizer.encode( token_ids = tokenizer.encode(
......
...@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either " "of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
...@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel):
"If specified, will override the default guided decoding backend " "If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of " "of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'")) "'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params # doc: end-completion-extra-params
......
...@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result = await loop.run_in_executor(global_thread_pool, result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide, _get_cached_logits_processor, guide,
tokenizer, mode) tokenizer, mode,
request.guided_whitespace_pattern)
logits_processor = copy(result) logits_processor = copy(result)
# reset logits processor's internal state # reset logits processor's internal state
...@@ -117,9 +118,10 @@ def _get_guide_and_mode( ...@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, def _get_cached_logits_processor(guide: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode): mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None]):
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer) return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer) return RegexLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.GRAMMAR: elif mode == GuidedDecodingMode.GRAMMAR:
......
...@@ -18,7 +18,7 @@ import json ...@@ -18,7 +18,7 @@ import json
import math import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union from typing import Callable, DefaultDict, Dict, List, Union
import torch import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
...@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class JSONLogitsProcessor(RegexLogitsProcessor): class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, def __init__(self, schema: Union[str, Dict, BaseModel],
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None): whitespace_pattern: Union[str, None]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment