"examples/cpp/vscode:/vscode.git/clone" did not exist on "4b12118014056b3ede06b7a5d41ea173dc55a548"
Unverified commit 37b42297, authored by Liangsheng Yin, committed by GitHub

import outlines (#168)

parent cba50273
@@ -5,10 +5,9 @@ python json_decode.py
 """
 from enum import Enum
-from pydantic import BaseModel, constr
 import sglang as sgl
-from sglang.srt.constrained.json_schema import build_regex_from_object
+from pydantic import BaseModel
+from sglang.srt.constrained import build_regex_from_object
 character_regex = (
     r"""\{\n"""
@@ -30,7 +29,10 @@ character_regex = (
 @sgl.function
 def character_gen(s, name):
-    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
+    s += (
+        name
+        + " is a character in Harry Potter. Please fill in the following information about this character.\n"
+    )
     s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
@@ -65,11 +67,6 @@ def pydantic_wizard_gen(s):
     )
-def driver_character_gen():
-    state = character_gen.run(name="Hermione Granger")
-    print(state.text())
 def driver_pydantic_wizard_gen():
     state = pydantic_wizard_gen.run()
     print(state.text())
...
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
-       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
+       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow", "outlines>=0.0.27"]
 openai = ["openai>=1.0", "numpy"]
 anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
...
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
__all__ = [
"RegexFSM",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
]
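Downstream code now pulls these names from the package itself rather than from the vendored modules below. A minimal usage sketch, assuming outlines>=0.0.27 is installed as pinned in pyproject.toml (the `Character` model is invented for illustration):

from pydantic import BaseModel

from sglang.srt.constrained import build_regex_from_object


class Character(BaseModel):  # illustrative model, not part of the commit
    name: str
    age: int


# Regex that any JSON serialization of a Character instance must match;
# it can be passed to sgl.gen(..., regex=...) as in the example above.
character_regex = build_regex_from_object(Character)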
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional
import cloudpickle
from diskcache import Cache
home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True
def hash_arguments(*args, **kwargs) -> str:
"""Create a hash out of the args and kwargs provided"""
result = hashlib.md5()
for item in list(args) + sorted(kwargs.items()):
result.update(cloudpickle.dumps(item))
return result.hexdigest()
def disk_cache(key_function: Optional[Callable] = None):
def decorator(cached_function: Callable):
def wrapper(*args, **kwargs):
if not _caching_enabled:
return cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = cached_function(*args, **kwargs)
memory[cache_key] = result
return result
async def async_wrapper(*args, **kwargs):
if not _caching_enabled:
return await cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = await cached_function(*args, **kwargs)
memory[cache_key] = result
return result
if asyncio.iscoroutinefunction(cached_function):
return async_wrapper
else:
return wrapper
return decorator
def disable_cache():
global _caching_enabled
_caching_enabled = False
def clear_cache():
global memory
memory.clear()
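For reference, the decorator above is applied like `functools.lru_cache`, except that results persist in the diskcache store under `~/.cache/sglang`. A small sketch with an invented function:

@disk_cache()
def expensive_regex_length(regex_string: str) -> int:
    # Placeholder for real work; any picklable return value is cached.
    return len(regex_string)


expensive_regex_length("[0-9]+")  # computed once, then served from the on-disk cache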
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
from typing import List, NewType, Protocol, Tuple
import interegular
from lark import Lark
from sglang.srt.constrained.disk_cache import disk_cache
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
create_fsm_index_tokenizer,
make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer
FSMState = NewType("FSMState", int)
class FSM(Protocol):
def allowed_token_ids(self, state: FSMState) -> List[int]:
...
def next_state(self, state: FSMState, token_id: int) -> FSMState:
...
def is_final_state(self, state: FSMState) -> bool:
...
def copy(self) -> "FSM":
...
class StopAtTokenFSM(FSM):
"""FSM to generate text until a specified token id is generated or
a specified number of tokens has been generated.
Text is usually produced until the EOS token is generated by the
model.
"""
def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
self.stop_token_id = stop_token_id
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Parameters
----------
state
The current state of the FSM.
Returns
-------
A list of the token ids allowed for the next step.
"""
if state == 0:
return list(self.vocabulary)
else:
return [self.stop_token_id]
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
has been generated or the maximum number of tokens has been reached, in
which case the FSM moves to the final state `1`.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
Returns
-------
The new state of the FSM.
"""
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(0)
def is_final_state(self, state: FSMState) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def copy(self) -> "StopAtTokenFSM":
"""Create a copy of the FSM."""
return self
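A short illustration of the two states of this FSM; `tokenizer` stands for any object implementing the Tokenizer protocol defined later in this diff:

fsm = StopAtTokenFSM(tokenizer, stop_token_id=tokenizer.eos_token_id)
state = FSMState(0)
fsm.allowed_token_ids(state)                      # every vocabulary id
state = fsm.next_state(state, tokenizer.eos_token_id)
fsm.allowed_token_ids(state)                      # [tokenizer.eos_token_id]
fsm.is_final_state(state)                         # True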
class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""
def __init__(
self,
regex_string: str,
tokenizer: "Tokenizer",
):
@disk_cache()
def create_states_mapping(
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
(
states_to_token_maps,
empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
return states_to_token_maps, empty_token_ids, final_states
(
self.states_to_token_maps,
self.empty_token_ids,
self.final_states,
) = create_states_mapping(
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
)
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index that maps each FSM state to a
map from authorized tokens to the state the FSM moves to if said token is
generated. Therefore the authorized tokens at the current state are the keys
of the map returned by the index for the current state.
If the current state is not contained in the index, we are in a final state
of the FSM, and the only authorized token is EOS.
Parameters
----------
state
The current state of the FSM.
Returns
-------
A list of the token ids allowed for the next step.
"""
next_tokens_to_end_states = self.states_to_token_maps.get(state)
if next_tokens_to_end_states is None:
return [self.end_token_id]
else:
return list(next_tokens_to_end_states.keys())
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
We use the index to determine to which state the FSM should transition
given the token that was just generated.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
Returns
-------
The new state of the FSM.
"""
if token_id == self.end_token_id:
return FSMState(-1)
last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1
return FSMState(next_state)
def is_final_state(self, state: FSMState) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def copy(self) -> "RegexFSM":
"""Create a copy of the FSM."""
return self
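A rough sketch of how a decoding loop consumes this class; the `choose` helper is a stand-in for the model's sampler, and `tokenizer` is again an assumed Tokenizer instance:

def choose(allowed_token_ids):
    # Stand-in for sampling from the model's masked logits.
    return allowed_token_ids[0]


fsm = RegexFSM(r"[0-9]{4}", tokenizer)
state = FSMState(0)
output_ids = []
while not fsm.is_final_state(state):
    allowed = fsm.allowed_token_ids(state)  # every other token would be masked out
    token_id = choose(allowed)
    output_ids.append(token_id)
    state = fsm.next_state(state, token_id)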
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""
def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
self.cfg_string = cfg_string
self.tokenizer = tokenizer
self.parser = Lark(
cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
)
self.terminal_regexps = dict()
for terminal in self.parser.terminals:
if terminal.pattern is not None:
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
self.terminal_regexps["$END"] = tokenizer.eos_token
self.generation = ""
self.reset_state = False
self.allow_eos = False
self.done = False
self.regex_fsm: RegexFSM
def _set_next_regex_fsm(self) -> None:
"""Use the CFG incremental parser to set the next regex FSM.
Check what the CFG incremental parser proposes next:
- If the only proposal is the EOS token we set the state to done and
return.
- If there are other proposals, we set a new regex FSM and return.
"""
interactive = self.parser.parse_interactive(self.generation)
interactive.exhaust_lexer()
options = {self.terminal_regexps[x] for x in interactive.accepts()}
if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
if len(options) == 0:
self.done = True
return
self.allow_eos = True
options.add("")
assert len(options) > 1
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
self.reset_state = True
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Upon initialization, the CFG incremental parser is used to determine the
first regex.
This regex is used for proposals until either:
- The regex is exhausted, and its only remaining option is the EOS
token, in which case we always transition to the next regex
- The regex can be exhausted, but the EOS token is not the only
remaining option, in which case we transition to the next regex with
probability P (TODO) or remove the possibility of generating the EOS
token and continue with the current regex
The CFG incremental parser is allowed to propose the EOS token from any final state,
and once it is generated, the FSM will continue to always generate the EOS token.
Parameters
----------
state
The current state of the FSM.
Returns
-------
A list of the token ids allowed for the next step.
"""
if self.generation != "":
proposal = self.regex_fsm.allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
if False: # TODO: THIS NEEDS TO BE SAMPLED
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self._set_next_regex_fsm()
if self.done:
return [self.tokenizer.eos_token_id]
if self.reset_state:
state = FSMState(0)
proposal = self.regex_fsm.allowed_token_ids(state)
if self.allow_eos:
self.allow_eos = False
else:
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
assert len(proposal) > 0
return proposal
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
Transitions the underlying regex FSM to its next state.
If at max tokens or EOS token, transition permanently to the final state.
Update stored partial generations for subsequent incremental parsing.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
Returns
-------
The new state of the FSM.
"""
if token_id == self.tokenizer.eos_token_id:
self.done = True
return FSMState(-1)
if self.reset_state:
self.reset_state = False
state = FSMState(0)
self.generation += self.tokenizer.decode([token_id])[0]
return self.regex_fsm.next_state(state, token_id)
def is_final_state(self, state: FSMState) -> bool:
"""Return whether the current state of the FSM is a final state."""
return self.done
def copy(self) -> "CFGFSM":
"""Create a copy of the FSM."""
return CFGFSM(self.cfg_string, self.tokenizer)
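A minimal instantiation sketch; the grammar is an invented two-word language, and this assumes Lark's optional `regex` backend is available:

tiny_grammar = r"""
start: ANSWER
ANSWER: "yes" | "no"
"""

cfg_fsm = CFGFSM(tiny_grammar, tokenizer)  # `tokenizer` as in the examples above
# cfg_fsm.allowed_token_ids(FSMState(0)) would then propose tokens matching "yes" or "no"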
+from sglang.srt.constrained import RegexFSM, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.fsm import RegexFSM
-from sglang.srt.constrained.tokenizer import TransformerTokenizer
 class FSMCache(BaseCache):
...
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union
from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012
STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"
type_to_regex = {
"string": STRING,
"integer": INTEGER,
"number": NUMBER,
"boolean": BOOLEAN,
"null": NULL,
}
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
JSON Schema is a declarative language for annotating JSON documents with
types and descriptions. These schemas can be generated from any Python data
structure that has type annotations: namedtuples, dataclasses, Pydantic
models. By ensuring that the generation respects the schema, we ensure that
the output can be parsed into these objects.
This function parses the provided schema and builds a generation schedule that
mixes deterministic generation (fixed strings) and sampling with constraints.
Parameters
----------
object
A Pydantic model, a callable with type annotations, or a string that
represents a JSON Schema.
Returns
-------
A generation schedule. A list of strings that represent the JSON
schema's structure and regular expression that define the structure of
the fields.
References
----------
.. [0] JSON Schema. https://json-schema.org/
"""
if isinstance(object, type(BaseModel)):
schema = object.model_json_schema()
elif callable(object):
schema = get_schema_from_signature(object)
else:
schema = json.loads(object)
Validator.check_schema(schema)
# Build reference resolver
schema = Resource(contents=schema, specification=DRAFT202012)
uri = schema.id() if schema.id() is not None else ""
registry = Registry().with_resource(uri=uri, resource=schema)
resolver = registry.resolver()
content = schema.contents
return to_regex(resolver, content)
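As a quick illustration of the return value (not part of the original file), a bare string schema reduces to the STRING pattern defined above:

import re

regex = build_regex_from_object('{"type": "string"}')
assert re.fullmatch(regex, '"hello"') is not None
assert re.fullmatch(regex, "hello") is None  # unquoted text is not a JSON string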
def to_regex(resolver: Resolver, instance: dict):
"""Translate a JSON Schema instance into a regex that validates the schema.
Note
----
Many features of JSON schema are missing:
- Handle `additionalProperties` keyword
- Handle types defined as a list
- Handle constraints on numbers
- Handle special patterns: `date`, `uri`, etc.
This does not support recursive definitions.
Parameters
----------
resolver
An object that resolves references to other instances within a schema
instance
The instance to translate
"""
whitespace = r"[\n ]*"
if "properties" in instance:
regex = ""
regex += r"\{"
properties = instance["properties"]
required_properties = instance.get("required", [])
is_required = [item in required_properties for item in properties]
# If at least one property is required, we include the last required one
# without any trailing comma.
# Each property before it (optional or required) is added with a comma after the property.
# Each property after it (optional) is added with a comma before the property.
if any(is_required):
last_required_pos = max([i for i, value in enumerate(is_required) if value])
for i, (name, value) in enumerate(properties.items()):
subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
subregex += to_regex(resolver, value)
if i < last_required_pos:
subregex = f"{subregex}{whitespace},"
elif i > last_required_pos:
subregex = f"{whitespace},{subregex}"
regex += subregex if is_required[i] else f"({subregex})?"
# If no property is required, we have to create a possible pattern for each property in which
# it is necessarily the last one present. Then we add the others as optional before and after it,
# following the same strategy as described above.
# The whole block is made optional to allow the case in which no property is returned.
else:
property_subregexes = []
for i, (name, value) in enumerate(properties.items()):
subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
subregex += to_regex(resolver, value)
property_subregexes.append(subregex)
possible_patterns = []
for i in range(len(property_subregexes)):
pattern = ""
for subregex in property_subregexes[:i]:
pattern += f"({subregex}{whitespace},)?"
pattern += property_subregexes[i]
for subregex in property_subregexes[i + 1 :]:
pattern += f"({whitespace},{subregex})?"
possible_patterns.append(pattern)
regex += f"({'|'.join(possible_patterns)})?"
regex += f"{whitespace}" + r"\}"
return regex
# To validate against allOf, the given data must be valid against all of the
# given subschemas.
elif "allOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
subregexes_str = [f"{subregex}" for subregex in subregexes]
return rf"({''.join(subregexes_str)})"
# To validate against `anyOf`, the given data must be valid against
# any (one or more) of the given subschemas.
elif "anyOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
return rf"({'|'.join(subregexes)})"
# To validate against oneOf, the given data must be valid against exactly
# one of the given subschemas.
elif "oneOf" in instance:
subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]
xor_patterns = []
# json schema validation ensured there is no overlapping schemas in oneOf
for subregex in subregexes:
other_subregexes = filter(lambda r: r != subregex, subregexes)
other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
negative_lookahead = f"(?!.*({other_subregexes_str}))"
xor_patterns.append(f"({subregex}){negative_lookahead}")
return rf"({'|'.join(xor_patterns)})"
# The enum keyword is used to restrict a value to a fixed set of values. It
# must be an array with at least one element, where each element is unique.
elif "enum" in instance:
choices = []
for choice in instance["enum"]:
if type(choice) in [int, float, bool, type(None)]:
choices.append(re.escape(str(choice)))
elif type(choice) == str:
choices.append(f'"{re.escape(choice)}"')
return f"({'|'.join(choices)})"
elif "$ref" in instance:
path = f"{instance['$ref']}"
instance = resolver.lookup(path).contents
return to_regex(resolver, instance)
# The type keyword may either be a string or an array:
# - If it's a string, it is the name of one of the basic types.
# - If it is an array, it must be an array of strings, where each string is
# the name of one of the basic types, and each element is unique. In this
# case, the JSON snippet is valid if it matches any of the given types.
elif "type" in instance:
instance_type = instance["type"]
if instance_type == "string":
if "maxLength" in instance or "minLength" in instance:
max_items = instance.get("maxLength", "")
min_items = instance.get("minLength", "")
try:
if int(max_items) < int(min_items):
raise ValueError(
"maxLength must be greater than or equal to minLength"
)
except ValueError:
pass
return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
elif "pattern" in instance:
pattern = instance["pattern"]
if pattern[0] == "^" and pattern[-1] == "$":
return rf'(^"{pattern[1:-1]}"$)'
else:
return rf'("{pattern}")'
else:
return type_to_regex["string"]
elif instance_type == "number":
return type_to_regex["number"]
elif instance_type == "integer":
return type_to_regex["integer"]
elif instance_type == "array":
min_items = instance.get("minItems", "0")
max_items = instance.get("maxItems", "")
if min_items == max_items:
num_repeats = "{" + str(int(min_items) - 1) + "}"
else:
num_repeats = "*"
if "items" in instance:
items_regex = to_regex(resolver, instance["items"])
return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
else:
# Here we need to make the choice to exclude generating lists of objects
# if the item specification is not given, even though a JSON array that
# contains objects here would be valid under the specification.
types = [
{"type": "boolean"},
{"type": "null"},
{"type": "number"},
{"type": "integer"},
{"type": "string"},
]
regexes = [to_regex(resolver, t) for t in types]
return (
rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
)
elif instance_type == "boolean":
return type_to_regex["boolean"]
elif instance_type == "null":
return type_to_regex["null"]
elif isinstance(instance_type, list):
# Here we need to make the choice to exclude generating an object
# if the specification of the object is not given, even though a JSON
# object that contains an object here would be valid under the specification.
regexes = [
to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
]
return rf"({'|'.join(regexes)})"
raise NotImplementedError(
f"""Could not translate the instance {instance} to a
regular expression. Make sure it is valid against the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
)
def get_schema_from_signature(fn: Callable) -> str:
"""Turn a function signature into a JSON schema.
Every JSON object valid to the output JSON Schema can be passed
to `fn` using the ** unpacking syntax.
"""
signature = inspect.signature(fn)
arguments = {}
for name, arg in signature.parameters.items():
if arg.annotation == inspect._empty:
raise ValueError("Each argument must have a type annotation")
else:
arguments[name] = (arg.annotation, ...)
model = create_model("Arguments", **arguments)
return model.model_json_schema()
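An illustrative call, with an invented function; every parameter needs a type annotation or a ValueError is raised:

def add(a: int, b: int) -> int:  # hypothetical example function
    return a + b


schema = get_schema_from_signature(add)
# An object schema with required integer properties "a" and "b".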
 import interegular
+from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
 from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.disk_cache import disk_cache
-from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
...
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple
import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer
class BetterAlphabet(Alphabet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert anything_else in self._symbol_mapping
self.anything_value = self._symbol_mapping[anything_else]
def __getitem__(self, item):
return self._symbol_mapping.get(item, self.anything_value)
def copy(self):
return BetterAlphabet(self._symbol_mapping.copy())
class BetterFSM(FSM):
flat_transition_map: Dict[Tuple[int, int], int]
trans_key_to_states: Dict[int, List[int]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(self.alphabet, BetterAlphabet):
self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)
flat_transition_map = {}
trans_key_to_states = {}
for from_state, trans_map in self.map.items():
for trans_key, to_state in trans_map.items():
flat_transition_map[(from_state, trans_key)] = to_state
trans_key_to_states.setdefault(trans_key, set()).add(from_state)
self.__dict__["trans_key_to_states"] = trans_key_to_states
self.__dict__["flat_transition_map"] = flat_transition_map
self.__dict__["_fsm_info"] = None
def copy(self):
return BetterFSM(
alphabet=self.alphabet.copy(),
states=self.states.copy(),
initial=self.initial,
finals=self.finals.copy(),
map=self.map.copy(),
__no_validation__=True,
)
@property
def fsm_info(self):
if self._fsm_info is None:
flat_transition_map_items = np.fromiter(
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
dtype=np.dtype("i8, i8, i8"),
)
trans_key_to_states_items = np.fromiter(
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("i8, i8"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U1, i8"),
)
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
nb_finals,
flat_transition_map_items,
trans_key_to_states_items,
self.alphabet.anything_value,
alphabet_symbol_mapping_items,
)
return self._fsm_info
nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
@numba.njit(cache=True)
def create_fsm_info(
py_initial,
py_finals,
flat_transition_map_items,
trans_key_to_states_items,
py_anything_value,
alphabet_symbol_mapping_items,
):
trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
for trans_key_and_state in trans_key_to_states_items:
trans_key_to_states.setdefault(
trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
).append(trans_key_and_state[1])
flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
for trans_key_and_state in flat_transition_map_items:
flat_transition_map[
(trans_key_and_state[0], trans_key_and_state[1])
] = trans_key_and_state[2]
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
initial = numba.int64(py_initial)
finals = set()
for final in py_finals:
finals.add(final)
anything_value = numba.int64(py_anything_value)
return FSMInfo(
initial,
finals,
flat_transition_map,
trans_key_to_states,
anything_value,
alphabet_symbol_map,
)
FSMInfo = namedtuple(
"FSMInfo",
[
"initial",
"finals",
"transitions",
"trans_key_to_states",
"alphabet_anything_value",
"alphabet_symbol_mapping",
],
)
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
trans_key: i
for i, (trans_key, _) in enumerate(
sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
)
}
new_symbol_mapping = {
symbol: old_to_new_trans_keys[trans_key]
for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
}
new_alphabet = BetterAlphabet(new_symbol_mapping)
new_map = {
from_state: {
old_to_new_trans_keys[trans_key]: to_state
for trans_key, to_state in trans_map.items()
}
for from_state, trans_map in fsm.map.items()
}
old_to_new_states = {}
old_to_new_states[fsm.initial] = 0
i = 0
seen = {fsm.initial}
old_state_queue = [fsm.initial]
while old_state_queue:
old_state = old_state_queue.pop(-1)
transitions = new_map[old_state]
sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
for _, old_state in sorted_transitions:
if old_state not in seen:
old_state_queue.append(old_state)
seen.add(old_state)
if old_state not in old_to_new_states:
i += 1
old_to_new_states[old_state] = i
new_map = dict(
sorted(
(
(
old_to_new_states[from_state],
dict(
sorted(
(
(trans_key, old_to_new_states[to_state])
for trans_key, to_state in trans_map.items()
),
key=lambda v: v[0],
)
),
)
for from_state, trans_map in new_map.items()
),
key=lambda v: v[0],
)
)
new_initial = 0
new_finals = frozenset(
sorted(old_to_new_states[old_state] for old_state in fsm.finals)
)
new_states = frozenset(sorted(new_map.keys()))
new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)
return new_fsm, old_to_new_states
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return numba.typed.List.empty_list(numba.int64)
state = new_state
if state in fsm_finals:
last_final_idx = numba.uint64(i + 1)
accepted_states.append(_nonoptional(state))
if full_match and last_final_idx - 1 != i:
return numba.typed.List.empty_list(numba.int64)
return accepted_states
def walk_fsm(
fsm: BetterFSM,
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
fsm_finals = fsm.finals
state = start_state
accepted_states: List[int] = []
last_final_idx: int = 0
alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return []
state = new_state
if state in fsm_finals:
last_final_idx = i + 1
accepted_states.append(state)
if full_match and last_final_idx - 1 != i:
return []
return accepted_states
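Example wiring of the helpers above, mirroring how RegexFSM builds its deterministic FSM; the regex is arbitrary:

import interegular

digits_fsm, _ = make_deterministic_fsm(
    interegular.parse_pattern(r"[0-9]+").to_fsm().reduce()
)
walk_fsm(digits_fsm, "42", digits_fsm.initial)  # two accepting states are visited
walk_fsm(digits_fsm, "4a", digits_fsm.initial)  # [] because a full match is required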
def fsm_union(
fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
"""Construct an FSM representing the union of the FSMs in `fsms`.
This is an updated version of `interegular.fsm.FSM.union` made to return an
extra map of component FSMs to the sets of state transitions that
correspond to them in the new FSM.
"""
alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])
indexed_fsms = tuple(enumerate(fsms))
initial = {i: fsm.initial for (i, fsm) in indexed_fsms}
# Dedicated function accepting a "superset" and returning the next
# "superset" obtained by following this transition in the new FSM
def follow(current_state, new_transition: int):
next = {}
for i, f in indexed_fsms:
old_transition = new_to_old[i][new_transition]
if (
i in current_state
and current_state[i] in f.map
and old_transition in f.map[current_state[i]]
):
next[i] = f.map[current_state[i]][old_transition]
if not next:
raise OblivionError
return next
states = [initial]
finals: Set[int] = set()
map: Dict[int, Dict[int, int]] = {}
# Map component FSMs to their new state-to-state transitions, finals, and a
# map translating component FSM states to aggregate FSM states
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
] = {}
i = 0
while i < len(states):
state = states[i]
# Add to the finals of the aggregate FSM whenever we hit a final in a
# component FSM
if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
finals.add(i)
# Compute the map for this state
map[i] = {}
for transition in alphabet.by_transition:
try:
next = follow(state, transition)
except OblivionError:
# Reached an oblivion state; don't list it
continue
else:
try:
# TODO: Seems like this could--and should--be avoided
j = states.index(next)
except ValueError:
j = len(states)
states.append(next)
map[i][transition] = j
for fsm_id, fsm_state in next.items():
(
fsm_transitions,
fsm_finals,
fsm_old_to_new,
) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
old_from = state[fsm_id]
old_to = fsm_state
fsm_old_to_new.setdefault(old_from, set()).add(i)
fsm_old_to_new.setdefault(old_to, set()).add(j)
fsm_transitions.add((i, j))
if fsm_state in fsms[fsm_id].finals:
fsm_finals.add(j)
i += 1
fsm = FSM(
alphabet=alphabet,
states=range(len(states)),
initial=0,
finals=finals,
map=map,
__no_validation__=True,
)
fsm, old_to_new_states = make_deterministic_fsm(fsm)
_fsms_to_trans_finals = {
fsm_id: (
{(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
{old_to_new_states[s] for s in finals},
{
old_state: {old_to_new_states[new_state] for new_state in new_states}
for old_state, new_states in old_to_new.items()
},
)
for fsm_id, (transitions, finals, old_to_new) in sorted(
fsms_to_trans_finals.items(), key=lambda x: x[0]
)
}
return (
fsm,
_fsms_to_trans_finals,
)
def get_sub_fsms_from_seq(
state_seq: Sequence[int],
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
],
) -> Generator[Tuple[int, bool, bool], None, None]:
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
Parameters
----------
state_seq
A state sequence.
fsms_to_trans_finals
A map from FSM indices to tuples containing sets of their state transitions
and sets of the final/accept states.
Returns
-------
A generator returning tuples containing each sub-FSM index (in the order
they were union-ed to construct `fsm`) and booleans indicating whether or
not there is another valid transition from the last state in the sequence
for the associated sub-FSM (i.e. if the FSM can continue
accepting/matching) and whether or not the sequence ends in a final state
of the sub-FSM.
"""
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
last_fsm_state = state_seq[-1]
yield from (
(
# The sub-FSM index
fsm_idx,
# Is there another possible transition in this sub-FSM?
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
# Is this sub-FSM in a final state?
state_seq[-1] in finals,
)
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
if state_seq_transitions.issubset(transitions)
)
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()
for token, token_ids in vocabulary.items():
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
start_state,
False,
)
if state_seq is not None and len(state_seq) < len(token):
continue
for token_id in token_ids:
res.add((token_id, state_seq[-1]))
return res
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
seen: Set[int] = set()
next_states = {fsm_info.initial}
while next_states:
start_state = next_states.pop()
token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
vocabulary,
start_state,
)
for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)
seen.add(start_state)
return states_to_token_subsets
# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue
token_str = tokenizer.convert_token_to_string(token)
if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
else:
empty_token_ids.add(numba.int64(token_idx))
return vocabulary, empty_token_ids
def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
.. warning::
`fsm` needs to be deterministically ordered so that future caching makes sense.
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
# Allow transitions to EOS from all terminal FSM states that are
# reachable
# TODO: Do we really need this anymore?
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset.add((tokenizer.eos_token_id, state))
# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
return states_to_token_subsets, empty_token_ids
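A sketch of how the pieces fit together, mirroring RegexFSM.__init__ earlier in the diff; "gpt2" is only an example checkpoint name:

import interegular

tokenizer = TransformerTokenizer("gpt2")  # defined in the tokenizer module below
regex_fsm, _ = make_deterministic_fsm(
    interegular.parse_pattern(r"[0-9]+").to_fsm().reduce()
)
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(regex_fsm, tokenizer)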
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union
import numpy as np
import torch
from numpy.typing import NDArray
class Tokenizer(Protocol, Hashable):
eos_token: str
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]
@abstractmethod
def encode(
self, prompt: Union[str, List[str]]
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
"""Translate the input prompts into NumPy arrays of token ids and attention mask."""
...
@abstractmethod
def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
"""Translate an array of token ids to a string or list of strings."""
...
@abstractmethod
def convert_token_to_string(self, token: str) -> str:
"""Convert a token to its equivalent string.
This is for instance useful for BPE tokenizers where whitespaces are
represented by the special character `Ġ`. This prevents matching a raw
token that includes `Ġ` with a string.
"""
...
def get_llama_tokenizer_types():
"""Get all the Llama tokenizer types/classes that need work-arounds.
When they can't be imported, a dummy class is created.
"""
try:
from transformers.models.llama import LlamaTokenizer
except ImportError:
class LlamaTokenizer: # type: ignore
pass
try:
from transformers.models.llama import LlamaTokenizerFast
except ImportError:
class LlamaTokenizerFast: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizer
except ImportError:
class CodeLlamaTokenizer: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizerFast
except ImportError:
class CodeLlamaTokenizerFast: # type: ignore
pass
return (
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
)
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""
def __init__(self, model_name: str, **kwargs):
from transformers import AutoTokenizer
kwargs.setdefault("padding_side", "left")
self.model_name = model_name
# TODO: Do something to make this hashable?
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token
self.special_tokens = set(self.tokenizer.all_special_tokens)
self.vocabulary = self.tokenizer.get_vocab()
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]
def decode(self, token_ids: torch.LongTensor) -> List[str]:
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return text
def convert_token_to_string(self, token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = self.tokenizer.convert_tokens_to_string([token])
if self.is_llama:
# A hack to handle missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def __eq__(self, other):
if isinstance(other, type(self)):
return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented
def __hash__(self):
from datasets.fingerprint import Hasher
return hash(Hasher.hash(self.tokenizer))
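Illustrative use of the wrapper; "gpt2" stands in for whichever model the server actually loads:

tok = TransformerTokenizer("gpt2")
input_ids, attention_mask = tok.encode("Hello world")
tok.decode(input_ids)  # ["Hello world"]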
@@ -21,7 +21,7 @@ from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from pydantic import BaseModel
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained.disk_cache import disable_cache
+from sglang.srt.constrained import disable_cache
 from sglang.srt.conversation import (
     Conversation,
     SeparatorStyle,
...
@@ -2,12 +2,11 @@ import argparse
 from enum import Enum
 from pydantic import BaseModel, constr
-from sglang.srt.constrained.json_schema import build_regex_from_object
+from sglang.srt.constrained import build_regex_from_object
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 import sglang as sgl
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
...