Unverified Commit ca13f3b8 authored by Liangsheng Yin, committed by GitHub

Disk FSM cache and adjust code. (#63)

parent 0b2efc2a
 from sglang import function, gen, set_default_backend, Runtime

+IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

 @function
 def regex_gen(s):
     s += "Q: What is the IP address of the Google DNS servers?\n"
     s += "A: " + gen(
         "answer",
         temperature=0,
-        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+        regex=IP_ADDR_REGEX,
     )
...
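For reference, a minimal sketch of how this example can be run end to end. The model path below is a placeholder, not part of this commit; with a regex-constrained gen call, the returned answer is guaranteed to match IP_ADDR_REGEX.

# Illustrative only: launch a local runtime and run the constrained generation.
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")  # placeholder model
set_default_backend(runtime)

state = regex_gen.run()
print(state["answer"])  # e.g. "8.8.8.8"

runtime.shutdown()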
@@ -19,7 +19,7 @@ dependencies = [

 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba", "pydantic"]
+       "interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
...
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional

import cloudpickle
from diskcache import Cache

home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True


def hash_arguments(*args, **kwargs) -> str:
    """Create a hash out of the args and kwargs provided"""
    result = hashlib.md5()
    for item in list(args) + sorted(kwargs.items()):
        result.update(cloudpickle.dumps(item))
    return result.hexdigest()


def disk_cache(key_function: Optional[Callable] = None):
    def decorator(cached_function: Callable):
        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        else:
            return wrapper

    return decorator


def disable_cache():
    global _caching_enabled
    _caching_enabled = False


def clear_cache():
    global memory
    memory.clear()
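For illustration, a minimal sketch of how this decorator can be applied. The function `build_index` and its key function are hypothetical stand-ins, not part of the commit; the only requirement is that the wrapped function is deterministic in the arguments exposed by the key.

# Hypothetical usage of the disk_cache decorator defined above.
@disk_cache(key_function=lambda regex, vocab: (regex, tuple(sorted(vocab.items()))))
def build_index(regex: str, vocab: dict) -> dict:
    # stand-in for an expensive, pure computation
    return {"regex": regex, "vocab_size": len(vocab)}

# The first call computes and stores the result under ~/.cache/sglang
# (or $SGLANG_CACHE_DIR); identical later calls are served from disk.
build_index(r"\d+", {"0": 0, "1": 1})

# disable_cache() turns caching off process-wide, e.g. in tests;
# clear_cache() wipes the on-disk entries.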
 # Adapted from:
-# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
+# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
-from typing import List, NewType, Protocol
+from typing import List, NewType, Protocol, Tuple

 import interegular
 from lark import Lark
+
+from sglang.srt.constrained.disk_cache import disk_cache

 # from outlines.fsm.parsing import PartialLark
 from sglang.srt.constrained.regex import (
@@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int)

 class FSM(Protocol):
-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         ...

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         ...

-    def reset(self) -> None:
+    def copy(self) -> "FSM":
         ...

@@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM):
     """

-    def __init__(
-        self,
-        tokenizer: "Tokenizer",
-        stop_token_id: int,
-    ):
+    def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.final_states = {1}

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

         When in the initial state we allow every token to be generated.
@@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
@@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM):
         else:
             return [self.stop_token_id]

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         The FSM stays in the initial state `0` unless the specified stop token
@@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.
         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.stop_token_id:
             return FSMState(1)

         return FSMState(0)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "StopAtTokenFSM":
+        """Create a copy of the FSM."""
+        return self


 class RegexFSM(FSM):
@@ -117,32 +106,48 @@ class RegexFSM(FSM):
         regex_string: str,
         tokenizer: "Tokenizer",
     ):
-        regex_pattern = interegular.parse_pattern(regex_string)
-        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+        @disk_cache()
+        def create_states_mapping(
+            regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
+        ) -> Tuple[dict, set, set]:
+            """Create the variables related to the mapping between states and tokens
+            The parameters of the function are used for caching purpose
+            """
+            regex_pattern = interegular.parse_pattern(regex_string)
+            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+            (
+                states_to_token_maps,
+                empty_token_ids,
+            ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
+
+            # We make sure that it is possible to generate strings in the language
+            # of the regular expression with the tokens present in the model's
+            # vocabulary.
+            if not any(
+                regex_fsm.finals.intersection(v.values())
+                for v in states_to_token_maps.values()
+            ):
+                raise ValueError(
+                    "The vocabulary does not allow us to build a sequence that matches the input regex"
+                )
+
+            final_states = regex_fsm.finals | {
+                -1
+            }  # Include the EOS token in final states
+
+            return states_to_token_maps, empty_token_ids, final_states
+
         (
             self.states_to_token_maps,
             self.empty_token_ids,
-        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
-
-        # We make sure that it is possible to generate strings in the language
-        # of the regular expression with the tokens present in the model's
-        # vocabulary.
-        if not any(
-            regex_fsm.finals.intersection(v.values())
-            for v in self.states_to_token_maps.values()
-        ):
-            raise ValueError(
-                "The vocabulary does not allow us to build a sequence that matches the input regex"
-            )
-
-        self.final_states = regex_fsm.finals | {
-            -1
-        }  # Include the EOS token in final states
+            self.final_states,
+        ) = create_states_mapping(
+            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
+        )

         self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.end_token_id = tokenizer.eos_token_id

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

         The initialization of the FSM builds an index which maps FSM states to a
@@ -159,8 +164,6 @@ class RegexFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
@@ -174,7 +177,7 @@ class RegexFSM(FSM):
         else:
             return list(next_tokens_to_end_states.keys())

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         We use the index to determine to which state the FSM should transition
@@ -186,17 +189,12 @@ class RegexFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.
         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.end_token_id:
             return FSMState(-1)
@@ -207,24 +205,22 @@ class RegexFSM(FSM):
         return FSMState(next_state)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "RegexFSM":
+        """Create a copy of the FSM."""
+        return self


 class CFGFSM(FSM):
     """FSM to generate text that is in the language of a context-free grammar."""

-    def __init__(
-        self,
-        cfg_string: str,
-        tokenizer: "Tokenizer",
-    ):
-        # self.parser = PartialLark(cfg_string, parser="lalr")
+    def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
+        self.cfg_string = cfg_string
+        self.tokenizer = tokenizer
+
         self.parser = Lark(
             cfg_string,
             parser="lalr",
@@ -239,59 +235,52 @@ class CFGFSM(FSM):
             self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
         self.terminal_regexps["$END"] = tokenizer.eos_token

-        self.tokenizer = tokenizer
-        self.num_tokens_generated = 0
-        self.generations: List[str] = []
-        self.regex_fsms: List[RegexFSM] = []
-        self.reset_state: List[bool] = []
-        self.allow_eos: List[bool] = []
-        self.done: List[bool] = []
+        self.generation = ""
+        self.reset_state = False
+        self.allow_eos = False
+        self.done = False
+        self.regex_fsm: RegexFSM

-    def _set_next_regex_fsm(self, idx: int = 0) -> None:
+    def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.

-        Check what the CFG incremental parser proposes next.
-        If the only proposal is the EOS token,
-        we set the state to done and return.
-        If there are other proposals,
-        we set a new regex FSM and return.
+        Check what the CFG incremental parser proposes next:
+        - If the only proposal is the EOS token we set the state to done and
+          return.
+        - If there are other proposals, we set a new regex FSM and return.
         """
-        interactive = self.parser.parse_interactive(self.generations[idx])
+        interactive = self.parser.parse_interactive(self.generation)
         interactive.exhaust_lexer()
         options = {self.terminal_regexps[x] for x in interactive.accepts()}

         if self.terminal_regexps["$END"] in options:
             options.remove(self.terminal_regexps["$END"])
             if len(options) == 0:
-                self.done[idx] = True
+                self.done = True
                 return
-            self.allow_eos[idx] = True
+            self.allow_eos = True
             options.add("")
             assert len(options) > 1

         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
-        args = (
-            regex_string,
-            self.tokenizer,
-        )
-        if len(self.regex_fsms) <= idx:
-            self.regex_fsms.append(RegexFSM(*args))
-        else:
-            self.regex_fsms[idx] = RegexFSM(*args)
-        self.reset_state[idx] = True
+        self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
+        self.reset_state = True

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

-        Upon initialization, the CFG incremental parser is used to determine the first regex.
+        Upon initialization, the CFG incremental parser is used to determine the
+        first regex.
+
         This regex is used for proposals until either:
-        - the regex is exhausted, and its only remaining option is the EOS token,
-        in which case we always transition to the next regex
-        - the regex can be exhausted, but the EOS token is not the only remaining option,
-        in which case we transition to the next regex with probability P (TODO)
-        or remove the possibility of generating the EOS token and continue with the current regex
+
+        - The regex is exhausted, and its only remaining option is the EOS
+          token, in which case we always transition to the next regex
+        - The regex can be exhausted, but the EOS token is not the only
+          remaining option, in which case we transition to the next regex with
+          probability P (TODO) or remove the possibility of generating the EOS
+          token and continue with the current regex

         The CFG incremental parser is allowed to propose the EOS token from any final state,
         and once it is generated, the FSM will continue to always generate the EOS token.
@@ -300,22 +289,14 @@ class CFGFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
        -------
         A list that contains the tokens to mask.
         """
-        if len(self.generations) <= idx:
-            self.generations.append("")
-            self.reset_state.append(False)
-            self.allow_eos.append(False)
-            self.done.append(False)
-
-        if len(self.regex_fsms) > idx:
-            proposal = self.regex_fsms[idx].allowed_token_ids(state)
+        if self.generation != "":
+            proposal = self.regex_fsm.allowed_token_ids(state)
             if self.tokenizer.eos_token_id not in proposal:
                 return proposal
             if set(proposal) != {self.tokenizer.eos_token_id}:
@@ -323,23 +304,23 @@ class CFGFSM(FSM):
                 proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                 return proposal

-        self._set_next_regex_fsm(idx)
+        self._set_next_regex_fsm()

-        if self.done[idx]:
+        if self.done:
             return [self.tokenizer.eos_token_id]

-        if self.reset_state[idx]:
+        if self.reset_state:
             state = FSMState(0)

-        proposal = self.regex_fsms[idx].allowed_token_ids(state)
-        if self.allow_eos[idx]:
-            self.allow_eos[idx] = False
+        proposal = self.regex_fsm.allowed_token_ids(state)
+        if self.allow_eos:
+            self.allow_eos = False
         else:
             proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
             assert len(proposal) > 0
         return proposal

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         Transitions the underlying regex FSM to its next state.
@@ -352,34 +333,26 @@ class CFGFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.
         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.tokenizer.eos_token_id:
-            self.done[idx] = True
+            self.done = True
             return FSMState(-1)

-        if self.reset_state[idx]:
-            self.reset_state[idx] = False
+        if self.reset_state:
+            self.reset_state = False
             state = FSMState(0)

-        self.generations[idx] += self.tokenizer.decode([token_id])[0]
+        self.generation += self.tokenizer.decode([token_id])[0]

-        return self.regex_fsms[idx].next_state(state, token_id, idx)
+        return self.regex_fsm.next_state(state, token_id)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Return whether the current state of the FSM is a final state."""
-        return self.done[idx]
+        return self.done

-    def reset(self) -> None:
-        """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
-        self.num_tokens_generated = 0
-        self.generations = []
-        self.regex_fsms = []
-        self.reset_state = []
-        self.done = []
+    def copy(self) -> "CFGFSM":
+        """Create a copy of the FSM."""
+        return CFGFSM(self.cfg_string, self.tokenizer)
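To make the interface change concrete (per-request FSMs with copy() instead of a batch index), here is an illustrative decoding loop, not code from the commit. The `tokenizer`, `sample`, and `logits_for` names are placeholders.

# Illustrative sketch of driving the per-request FSM API.
fsm = RegexFSM(r"\d{3}", tokenizer)  # one FSM per request; copy() returns the same object for regex FSMs
state = FSMState(0)
output_ids = []

while not fsm.is_final_state(state):
    allowed = fsm.allowed_token_ids(state)               # token ids that keep the regex satisfiable
    token_id = sample(logits_for(output_ids), allowed)   # placeholder: mask everything else, then sample
    output_ids.append(token_id)
    state = fsm.next_state(state, token_id)              # advance the FSM with the chosen token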
-import threading
-
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer

-def get_fsm(regex, tokenizer, fsm_cache_entry):
-    outlines_tokenizer = TransformerTokenizer(tokenizer)
-    fsm = RegexFSM(regex, outlines_tokenizer)
-    fsm_cache_entry.fsm = fsm
-    fsm_cache_entry.event.set()
-
-
-class FSMCacheEntry:
-    def __init__(self):
-        self.fsm = None
-        self.event = threading.Event()
-
-
 class FSMCache:
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer_path, tokenizer_args_dict):
         self.cache = {}
-        self.tokenizer = tokenizer
+        self.outlines_tokenizer = TransformerTokenizer(
+            tokenizer_path, **tokenizer_args_dict
+        )

-    def init_fsm_in_background(self, regex):
+    def init_fsm(self, regex):
         if regex not in self.cache:
-            self.cache[regex] = FSMCacheEntry()
-            threading.Thread(
-                target=get_fsm,
-                args=(
-                    regex,
-                    self.tokenizer,
-                    self.cache[regex],
-                ),
-            ).start()
-
-    def get_fsm(self, regex):
-        self.init_fsm_in_background(regex)
-        entry = self.cache[regex]
-        entry.event.wait()
-        return entry.fsm
+            fsm = RegexFSM(regex, self.outlines_tokenizer)
+            self.cache[regex] = fsm
+        return self.cache[regex]
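A brief illustrative sketch of the simplified cache (the tokenizer path is a placeholder): repeated regexes return the same in-memory RegexFSM, while the expensive state-machine construction inside RegexFSM is additionally memoized on disk by disk_cache, so it also survives process restarts.

# Illustrative only; the tokenizer path is a placeholder.
cache = FSMCache(
    "meta-llama/Llama-2-7b-chat-hf",
    {"tokenizer_mode": "auto", "trust_remote_code": False},
)

fsm_a = cache.init_fsm(r"(yes|no)")
fsm_b = cache.init_fsm(r"(yes|no)")
assert fsm_a is fsm_b  # in-memory hit: the per-process dict returns the same FSM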
@@ -2,17 +2,7 @@
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
 from abc import abstractmethod
-from typing import (
-    TYPE_CHECKING,
-    Dict,
-    Hashable,
-    List,
-    Optional,
-    Protocol,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union

 import numpy as np
 import torch
@@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable):
         ...

-if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
-
-__all__ = ["transformers"]
-
-KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
-
-
 def get_llama_tokenizer_types():
     """Get all the Llama tokenizer types/classes that need work-arounds.
@@ -101,76 +82,17 @@ def get_llama_tokenizer_types():
     )

-class Transformer:
-    """Represents a `transformers` model."""
-
-    def __init__(
-        self,
-        model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
-    ):
-        self.device = model.device
-        self.model = model
-        self.tokenizer = tokenizer
-
-    @torch.inference_mode
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
-        """Compute a forward pass through the transformer model.
-
-        Parameters
-        ----------
-        input_ids
-            The input token ids. Must be one or two dimensional.
-        attention_mask
-            The attention mask. Must be one or two dimensional.
-        past_key_values
-            A tuple of tuples containing the cached key and value tensors for each
-            attention head.
-
-        Returns
-        -------
-        The computed logits and the new cached key and value tensors.
-        """
-        assert 0 < input_ids.ndim < 3
-
-        if past_key_values:
-            input_ids = input_ids[..., -1].unsqueeze(-1)
-
-        output = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            past_key_values=past_key_values,
-        )
-
-        return output.logits, output.past_key_values
-
-    def __call__(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> torch.FloatTensor:
-        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
-        next_token_logits = logits[..., -1, :]
-
-        return next_token_logits, kv_cache
-
-
 class TransformerTokenizer(Tokenizer):
     """Represents a tokenizer for models in the `transformers` library."""

-    def __init__(self, tokenizer):
+    def __init__(self, model_name: str, **kwargs):
+        from transformers import AutoTokenizer
+
+        kwargs.setdefault("padding_side", "left")
+        self.model_name = model_name
         # TODO: Do something to make this hashable?
-        self.tokenizer = tokenizer
+        self.kwargs = kwargs
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)

         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token
@@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer):
     def __eq__(self, other):
         if isinstance(other, type(self)):
-            return False
-            # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
-            # return other.model_name == self.model_name and other.kwargs == self.kwargs
+            return other.model_name == self.model_name and other.kwargs == self.kwargs
         return NotImplemented

     def __hash__(self):
         from datasets.fingerprint import Hasher

         return hash(Hasher.hash(self.tokenizer))

-def transformers(
-    model_name: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    tokenizer_kwargs: dict = {},
-):
-    """Instantiate a model from the `transformers` library and its tokenizer.
-
-    Parameters
-    ----------
-    model_name
-        The name of the model as listed on Hugging Face's model page.
-    device
-        The device(s) on which the model should be loaded. This overrides
-        the `device_map` entry in `model_kwargs` when provided.
-    model_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the model.
-    tokenizer_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the tokenizer.
-
-    Returns
-    -------
-    A `TransformersModel` model instance.
-    """
-    try:
-        from transformers import AutoModelForCausalLM
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
-
-    if device is not None:
-        model_kwargs["device_map"] = device
-
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
-    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
-
-    return Transformer(model, tokenizer)
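Because the tokenizer is now reconstructed from a model name and kwargs, equality is by value again, which matters wherever the tokenizer is used as a cache key (the removed TODO hints at an lru_cache). An illustrative check, with a placeholder model name:

# Illustrative only; the model name is a placeholder and is downloaded on first use.
tok_a = TransformerTokenizer("gpt2", trust_remote_code=False)
tok_b = TransformerTokenizer("gpt2", trust_remote_code=False)
assert tok_a == tok_b  # same model_name and kwargs compare equal
# __hash__ is still derived from the underlying tokenizer via datasets' Hasher.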
@@ -45,7 +45,7 @@ class Req:

         # for constrained decoding
         self.regex_fsm = None
-        self.regex_fsm_state = None
+        self.regex_fsm_state = 0

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
...
@@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service):
         self.stream_interval = server_args.stream_interval

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(self.tokenizer)
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+        )

         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service):
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

+        # Init regex fsm
+        if req.sampling_params.regex is not None:
+            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
@@ -322,11 +332,10 @@ class ModelRpcServer(rpyc.Service):
                 self.model_config.vocab_size, self.int_token_logit_bias
             )

-        # init the regex fsm before first sampling
+        # Reset regex fsm state before first sampling due to retractions
         for req in batch.reqs:
             if req.sampling_params.regex is not None:
                 req.regex_fsm_state = 0
-                req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

         if batch.extend_num_tokens != 0:
             # Forward
...
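The diff wires the FSM into request setup and batch preparation but does not show the sampling side. For orientation only, an illustrative sketch of how a per-request regex_fsm and regex_fsm_state are typically consumed when picking the next token; the function, the 1-D logits tensor, and the greedy choice are placeholders, not sglang APIs.

import torch

# Illustrative sketch: mask one request's logits with its regex FSM,
# pick a token, and advance the FSM state.
def constrained_pick(req, logits: torch.Tensor) -> int:
    allowed = req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed] = 0.0                           # keep only tokens the FSM allows
    token_id = int(torch.argmax(logits + mask))   # greedy here; real code samples
    req.regex_fsm_state = req.regex_fsm.next_state(req.regex_fsm_state, token_id)
    return token_id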