Unverified commit ca13f3b8 authored by Liangsheng Yin, committed by GitHub

Disk FSM cache and adjust code. (#63)

parent 0b2efc2a
from sglang import function, gen, set_default_backend, Runtime
IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
@function
def regex_gen(s):
s += "Q: What is the IP address of the Google DNS servers?\n"
s += "A: " + gen(
"answer",
temperature=0,
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
regex=IP_ADDR_REGEX,
)
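# A minimal sketch of how the function above could be driven, assuming the
# imports at the top of this file; the model path is only an illustration and
# any model servable by the SRT runtime works.
if __name__ == "__main__":
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    set_default_backend(runtime)

    state = regex_gen.run()
    print("answer:", state["answer"])  # constrained to match IP_ADDR_REGEX

    runtime.shutdown()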
@@ -19,7 +19,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark", "numba", "pydantic"]
"interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional
import cloudpickle
from diskcache import Cache
home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True
def hash_arguments(*args, **kwargs) -> str:
"""Create a hash out of the args and kwargs provided"""
result = hashlib.md5()
for item in list(args) + sorted(kwargs.items()):
result.update(cloudpickle.dumps(item))
return result.hexdigest()
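# For example, hash_arguments(r"\d+", vocab_size=32000) yields the same hex
# digest on every call with equal (picklable) arguments, which is what makes it
# usable as a stable on-disk cache key (the argument values are illustrative).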
def disk_cache(key_function: Optional[Callable] = None):
def decorator(cached_function: Callable):
def wrapper(*args, **kwargs):
if not _caching_enabled:
return cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = cached_function(*args, **kwargs)
memory[cache_key] = result
return result
async def async_wrapper(*args, **kwargs):
if not _caching_enabled:
return await cached_function(*args, **kwargs)
if key_function:
key_args = key_function(*args, **kwargs)
cache_key = hash_arguments(*key_args)
else:
cache_key = hash_arguments(*args, **kwargs)
if cache_key in memory:
return memory[cache_key]
result = await cached_function(*args, **kwargs)
memory[cache_key] = result
return result
if asyncio.iscoroutinefunction(cached_function):
return async_wrapper
else:
return wrapper
return decorator
def disable_cache():
global _caching_enabled
_caching_enabled = False
def clear_cache():
global memory
memory.clear()
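# Usage sketch for the decorator above (the decorated function and its workload
# are hypothetical): the first call computes the result and stores it under a
# hash of the arguments; later calls with equal arguments are answered from the
# on-disk cache, even across processes.
#
#     @disk_cache()
#     def compile_regex_index(regex_string: str) -> dict:
#         ...  # stand-in for an expensive computation, e.g. building an FSM index
#         return {"pattern": regex_string}
#
#     compile_regex_index(r"\d+")  # computed and written under ~/.cache/sglang
#     compile_regex_index(r"\d+")  # returned from the disk cache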
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
from typing import List, NewType, Protocol
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
from typing import List, NewType, Protocol, Tuple
import interegular
from lark import Lark
from sglang.srt.constrained.disk_cache import disk_cache
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
@@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int)
class FSM(Protocol):
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> List[int]:
...
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
def next_state(self, state: FSMState, token_id: int) -> FSMState:
...
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
def is_final_state(self, state: FSMState) -> bool:
...
def reset(self) -> None:
def copy(self) -> "FSM":
...
@@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM):
"""
def __init__(
self,
tokenizer: "Tokenizer",
stop_token_id: int,
):
def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
self.stop_token_id = stop_token_id
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
@@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM):
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
@@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM):
else:
return [self.stop_token_id]
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
@@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM):
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(0)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
def is_final_state(self, state: FSMState) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
def copy(self) -> "StopAtTokenFSM":
"""Create a copy of the FSM."""
return self
class RegexFSM(FSM):
@@ -117,32 +106,48 @@ class RegexFSM(FSM):
regex_string: str,
tokenizer: "Tokenizer",
):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
@disk_cache()
def create_states_mapping(
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
(
states_to_token_maps,
empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
return states_to_token_maps, empty_token_ids, final_states
(
self.states_to_token_maps,
self.empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in self.states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
self.final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
self.final_states,
) = create_states_mapping(
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
)
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index which maps FSM states to a
@@ -159,8 +164,6 @@ class RegexFSM(FSM):
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
@@ -174,7 +177,7 @@ class RegexFSM(FSM):
else:
return list(next_tokens_to_end_states.keys())
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
We use the index to determine to which state the FSM should transition
@@ -186,17 +189,12 @@ class RegexFSM(FSM):
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.end_token_id:
return FSMState(-1)
@@ -207,24 +205,22 @@ class RegexFSM(FSM):
return FSMState(next_state)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
def is_final_state(self, state: FSMState) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
def copy(self) -> "RegexFSM":
"""Create a copy of the FSM."""
return self
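# Sketch of how a decode loop can drive a RegexFSM (tokenizer construction is
# omitted; any TransformerTokenizer-compatible wrapper works, and the greedy
# pick below is only a stand-in for real sampling):
#
#     fsm = RegexFSM(r"[0-9]{1,3}", tokenizer)
#     state = FSMState(0)
#     while not fsm.is_final_state(state):
#         allowed = fsm.allowed_token_ids(state)  # token ids the sampler may use
#         token_id = allowed[0]                   # stand-in for real sampling
#         state = fsm.next_state(state, token_id)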
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""
def __init__(
self,
cfg_string: str,
tokenizer: "Tokenizer",
):
# self.parser = PartialLark(cfg_string, parser="lalr")
def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
self.cfg_string = cfg_string
self.tokenizer = tokenizer
self.parser = Lark(
cfg_string,
parser="lalr",
@@ -239,59 +235,52 @@ class CFGFSM(FSM):
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
self.terminal_regexps["$END"] = tokenizer.eos_token
self.tokenizer = tokenizer
self.num_tokens_generated = 0
self.generations: List[str] = []
self.regex_fsms: List[RegexFSM] = []
self.reset_state: List[bool] = []
self.allow_eos: List[bool] = []
self.done: List[bool] = []
self.generation = ""
self.reset_state = False
self.allow_eos = False
self.done = False
self.regex_fsm: RegexFSM
def _set_next_regex_fsm(self, idx: int = 0) -> None:
def _set_next_regex_fsm(self) -> None:
"""Use the CFG incremental parser to set the next regex FSM.
Check what the CFG incremental parser proposes next.
If the only proposal is the EOS token,
we set the state to done and return.
If there are other proposals,
we set a new regex FSM and return.
Check what the CFG incremental parser proposes next:
- If the only proposal is the EOS token we set the state to done and
return.
- If there are other proposals, we set a new regex FSM and return.
"""
interactive = self.parser.parse_interactive(self.generations[idx])
interactive = self.parser.parse_interactive(self.generation)
interactive.exhaust_lexer()
options = {self.terminal_regexps[x] for x in interactive.accepts()}
if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
if len(options) == 0:
self.done[idx] = True
self.done = True
return
self.allow_eos[idx] = True
self.allow_eos = True
options.add("")
assert len(options) > 1
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
args = (
regex_string,
self.tokenizer,
)
if len(self.regex_fsms) <= idx:
self.regex_fsms.append(RegexFSM(*args))
else:
self.regex_fsms[idx] = RegexFSM(*args)
self.reset_state[idx] = True
self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
self.reset_state = True
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
def allowed_token_ids(self, state: FSMState) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Upon initialization, the CFG incremental parser is used to determine the first regex.
Upon initialization, the CFG incremental parser is used to determine the
first regex.
This regex is used for proposals until either:
- the regex is exhausted, and its only remaining option is the EOS token,
in which case we always transition to the next regex
- the regex can be exhausted, but the EOS token is not the only remaining option,
in which case we transition to the next regex with probability P (TODO)
or remove the possibility of generating the EOS token and continue with the current regex
- The regex is exhausted, and its only remaining option is the EOS
token, in which case we always transition to the next regex
- The regex can be exhausted, but the EOS token is not the only
remaining option, in which case we transition to the next regex with
probability P (TODO) or remove the possibility of generating the EOS
token and continue with the current regex
The CFG incremental parser is allowed to propose the EOS token from any final state,
and once it is generated, the FSM will continue to always generate the EOS token.
@@ -300,22 +289,14 @@ class CFGFSM(FSM):
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
if len(self.generations) <= idx:
self.generations.append("")
self.reset_state.append(False)
self.allow_eos.append(False)
self.done.append(False)
if len(self.regex_fsms) > idx:
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.generation != "":
proposal = self.regex_fsm.allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
@@ -323,23 +304,23 @@ class CFGFSM(FSM):
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self._set_next_regex_fsm(idx)
self._set_next_regex_fsm()
if self.done[idx]:
if self.done:
return [self.tokenizer.eos_token_id]
if self.reset_state[idx]:
if self.reset_state:
state = FSMState(0)
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.allow_eos[idx]:
self.allow_eos[idx] = False
proposal = self.regex_fsm.allowed_token_ids(state)
if self.allow_eos:
self.allow_eos = False
else:
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
assert len(proposal) > 0
return proposal
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
def next_state(self, state: FSMState, token_id: int) -> FSMState:
"""Update the state of the FSM.
Transitions the underlying regex FSM to its next state.
@@ -352,34 +333,26 @@ class CFGFSM(FSM):
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.tokenizer.eos_token_id:
self.done[idx] = True
self.done = True
return FSMState(-1)
if self.reset_state[idx]:
self.reset_state[idx] = False
if self.reset_state:
self.reset_state = False
state = FSMState(0)
self.generations[idx] += self.tokenizer.decode([token_id])[0]
self.generation += self.tokenizer.decode([token_id])[0]
return self.regex_fsms[idx].next_state(state, token_id, idx)
return self.regex_fsm.next_state(state, token_id)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
def is_final_state(self, state: FSMState) -> bool:
"""Return whether the current state of the FSM is a final state."""
return self.done[idx]
return self.done
def reset(self) -> None:
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
self.num_tokens_generated = 0
self.generations = []
self.regex_fsms = []
self.reset_state = []
self.done = []
def copy(self) -> "CFGFSM":
"""Create a copy of the FSM."""
return CFGFSM(self.cfg_string, self.tokenizer)
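# Illustrative sketch for the CFG variant (the grammar below is a made-up
# example): CFGFSM follows the same protocol as RegexFSM, so the decode loop
# sketched above applies unchanged, and copy() now returns a fresh parser for a
# new request instead of sharing mutable state.
#
#     number_or_string = r"""
#     ?start: value
#     ?value: NUMBER | ESCAPED_STRING
#     %import common.NUMBER
#     %import common.ESCAPED_STRING
#     """
#     cfg_fsm = CFGFSM(number_or_string, tokenizer)
#     fresh = cfg_fsm.copy()  # independent CFGFSM for another request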
import threading
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
def get_fsm(regex, tokenizer, fsm_cache_entry):
outlines_tokenizer = TransformerTokenizer(tokenizer)
fsm = RegexFSM(regex, outlines_tokenizer)
fsm_cache_entry.fsm = fsm
fsm_cache_entry.event.set()
class FSMCacheEntry:
def __init__(self):
self.fsm = None
self.event = threading.Event()
class FSMCache:
def __init__(self, tokenizer):
def __init__(self, tokenizer_path, tokenizer_args_dict):
self.cache = {}
self.tokenizer = tokenizer
self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict
)
def init_fsm_in_background(self, regex):
def init_fsm(self, regex):
if regex not in self.cache:
self.cache[regex] = FSMCacheEntry()
threading.Thread(
target=get_fsm,
args=(
regex,
self.tokenizer,
self.cache[regex],
),
).start()
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
def get_fsm(self, regex):
self.init_fsm_in_background(regex)
entry = self.cache[regex]
entry.event.wait()
return entry.fsm
return self.cache[regex]
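# Usage sketch for the reworked cache (model path and tokenizer args are
# illustrative): the cache is keyed by the regex string, so repeated requests
# with the same constraint reuse a single compiled RegexFSM, and the expensive
# regex-to-token index is additionally persisted by the disk cache used inside
# RegexFSM.
#
#     fsm_cache = FSMCache(
#         "meta-llama/Llama-2-7b-chat-hf",
#         {"tokenizer_mode": "auto", "trust_remote_code": False},
#     )
#     fsm_cache.init_fsm(r"(yes|no)")       # compile (or load from disk) once
#     fsm = fsm_cache.get_fsm(r"(yes|no)")  # cheap dictionary lookup afterwards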
@@ -2,17 +2,7 @@
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union
import numpy as np
import torch
@@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable):
...
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
__all__ = ["transformers"]
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
def get_llama_tokenizer_types():
"""Get all the Llama tokenizer types/classes that need work-arounds.
@@ -101,76 +82,17 @@ def get_llama_tokenizer_types():
)
class Transformer:
"""Represents a `transformers` model."""
def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
):
self.device = model.device
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.
Parameters
----------
input_ids
The input token ids. Must be one or two dimensional.
attention_mask
The attention mask. Must be one or two dimensional.
past_key_values
A tuple of tuples containing the cached key and value tensors for each
attention head.
Returns
-------
The computed logits and the new cached key and value tensors.
"""
assert 0 < input_ids.ndim < 3
if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)
output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)
return output.logits, output.past_key_values
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
next_token_logits = logits[..., -1, :]
return next_token_logits, kv_cache
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""
def __init__(self, tokenizer):
def __init__(self, model_name: str, **kwargs):
from transformers import AutoTokenizer
kwargs.setdefault("padding_side", "left")
self.model_name = model_name
# TODO: Do something to make this hashable?
self.tokenizer = tokenizer
self.kwargs = kwargs
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token
@@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer):
def __eq__(self, other):
if isinstance(other, type(self)):
return False
# TODO(lsyin): the lru_cache for the TransformerTokenizer is useless?
# return other.model_name == self.model_name and other.kwargs == self.kwargs
return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented
def __hash__(self):
from datasets.fingerprint import Hasher
return hash(Hasher.hash(self.tokenizer))
def transformers(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
tokenizer_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the tokenizer.
Returns
-------
A `TransformersModel` model instance.
"""
try:
from transformers import AutoModelForCausalLM
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if device is not None:
model_kwargs["device_map"] = device
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
return Transformer(model, tokenizer)
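# Sketch of the new tokenizer wrapper (the model name is illustrative): it is
# now constructed from a name plus from_pretrained kwargs rather than from an
# existing tokenizer object, so it can be compared by (model_name, kwargs) and
# built from just a path, as the FSMCache above does.
#
#     tok = TransformerTokenizer("meta-llama/Llama-2-7b-chat-hf")
#     tok.eos_token_id, len(tok.vocabulary)  # fields consumed by the FSM classes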
@@ -45,7 +45,7 @@ class Req:
# for constrained decoding
self.regex_fsm = None
self.regex_fsm_state = None
self.regex_fsm_state = 0
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
@@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service):
self.stream_interval = server_args.stream_interval
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
)
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service):
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
@@ -322,11 +332,10 @@ class ModelRpcServer(rpyc.Service):
self.model_config.vocab_size, self.int_token_logit_bias
)
# init the regex fsm before first sampling
# Reset regex fsm state before first sampling due to retractions
for req in batch.reqs:
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
if batch.extend_num_tokens != 0:
# Forward