Unverified Commit 01ee0fbc authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

fast regex decode

Auto-detect constant str path in regex FSM, then extend instead.
parent 711d3435
## Run benchmark
### Dependencies
```
llama_cpp_python 0.2.32
guidance 0.1.10
vllm 0.2.7
outlines 0.0.24
```
### Benchmark sglang
Run Llama-7B
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Benchmark
```
python3 bench_sglang.py
```
### Benchmark vllm
Run Llama-7B
```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark
```
python3 bench_other.py --backend vllm
```
### Benchmark guidance (parallel execution appears unsupported, so run with `--parallel 1`)
Run Llama-7B and benchmark
```
python3 bench_other.py --backend guidance --parallel 1
```
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import guidance
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text
from tqdm import tqdm
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
# Regex that pins the generated character sheet to a fixed JSON shape:
# constant keys and punctuation, with bounded free-text fields and
# enumerated choices. The constant segments are what the fast-forward
# decoder can skip through.
character_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r""" "wand": \{\n"""
    + r""" "wood": "[\w\d\s]{1,16}",\n"""
    + r""" "core": "[\w\d\s]{1,16}",\n"""
    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "alive": "(Alive|Deceased)",\n"""
    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)
# fmt: off
def character_gen(name, generate):
    """Build the prompt for `name` and append the constrained completion.

    `generate` is a callable (prompt, max_tokens=..., regex=...) -> str.
    """
    prompt = (
        name
        + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    )
    completion = generate(prompt, max_tokens=256, regex=character_regex)
    return prompt + completion
# fmt: on
@guidance
def character_maker(lm, name):
    """Guidance program that fills the same character-sheet JSON as
    `character_regex`, but via per-field gen/select calls instead of one
    regex-constrained generation."""
    # Field-level constraints: free text without quotes, and a float.
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    # The f-string is a prompt template; each embedded guidance.gen/select
    # call generates one field. Its lines are part of the emitted prompt,
    # so their layout must not be re-indented.
    lm += f"""\
{name} is a character in Harry Potter. Please fill in the following information about him/her.
{{
"name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
"house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
"blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
"occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
"wand": {{
"wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
"core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
"length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
}},
"alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
"patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
"bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
}}
"""
    return lm
def main(args):
    """Benchmark constrained JSON generation on the selected backend and
    append a latency record to the result file."""
    # One {"name": ...} argument dict per line of the dataset file.
    with open(args.data_path, "r") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_outlines, url=url, temperature=0)

        def get_one_answer(i):
            states[i] = character_gen(**arguments[i], generate=generate)

    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def get_one_answer(i):
            states[i] = model + character_maker(**arguments[i])

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            # Drain the iterator so exceptions raised in workers propagate.
            for _ in executor.map(get_one_answer, range(len(arguments))):
                pass
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    record = {
        "task": "json_fast_forward",
        "backend": args.backend,
        "latency": round(latency, 3),
        "num_jsons": args.num_jsons,
        "parallel": args.parallel,
    }
    with open(args.result_file, "a") as fout:
        fout.write(json.dumps(record) + "\n")
if __name__ == "__main__":
    # Dataset path and record count, plus the shared backend/host/port args.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--data-path", type=str, default="dataset.txt")
    arg_parser.add_argument("--num-jsons", type=int, default=50)
    main(add_common_other_args_and_parse(arg_parser))
import argparse
import json
import time
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
# Regex constraining the generated character sheet to a fixed JSON shape:
# constant keys and punctuation (fast-forwardable) with bounded free-text
# fields and enumerated values.
character_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r""" "wand": \{\n"""
    + r""" "wood": "[\w\d\s]{1,16}",\n"""
    + r""" "core": "[\w\d\s]{1,16}",\n"""
    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "alive": "(Alive|Deceased)",\n"""
    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)
# fmt: off
@sgl.function
def character_gen(s, name):
    # Natural-language prompt, then one regex-constrained generation: the
    # regex both forces valid JSON and exposes constant segments that the
    # runtime can fast forward instead of sampling token by token.
    s += name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on
def bench_character(args):
    """Run the character-JSON benchmark and return (states, latency_seconds)."""
    # One {"name": ...} argument dict per dataset line, truncated to the
    # requested number of records.
    with open(args.data_path, "r") as f:
        arguments = [{"name": line.strip()} for line in f]
    arguments = arguments[: args.num_jsons]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    states = character_gen.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=(args.parallel == 1),
    )
    return states, time.time() - tic
def main(args):
    """Benchmark sglang, dump the outputs, and append a latency record."""
    states, latency = bench_character(args)

    print(f"Latency: {latency:.3f}")

    # Dump full program states and the raw generated JSON strings.
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    with open(f"{args.backend}.json", "w") as fout:
        fout.writelines(state["json_output"] + "\n" for state in states)

    record = {
        "task": "json_fast_forward",
        "backend": args.backend,
        "latency": round(latency, 3),
        "num_jsons": args.num_jsons,
        "parallel": args.parallel,
    }
    with open(args.result_file, "a") as fout:
        fout.write(json.dumps(record) + "\n")
if __name__ == "__main__":
    # Dataset path and record count, plus the shared sglang backend args.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--data-path", type=str, default="dataset.txt")
    arg_parser.add_argument("--num-jsons", type=int, default=50)
    main(add_common_sglang_args_and_parse(arg_parser))
Harry Potter
Hermione Granger
Ron Weasley
Albus Dumbledore
Severus Snape
Rubeus Hagrid
Draco Malfoy
Ginny Weasley
Fred Weasley
George Weasley
Percy Weasley
Sirius Black
Remus Lupin
Neville Longbottom
Luna Lovegood
Cedric Diggory
Cho Chang
Lord Voldemort
Minerva McGonagall
Filius Flitwick
Dolores Umbridge
Bellatrix Lestrange
Lucius Malfoy
Molly Weasley
Arthur Weasley
Nymphadora Tonks
Dobby
Moaning Myrtle
Peter Pettigrew
Alastor 'Mad-Eye' Moody
Horace Slughorn
Vernon Dursley
Petunia Dursley
Dudley Dursley
Argus Filch
Sybill Trelawney
Gilderoy Lockhart
Fleur Delacour
Viktor Krum
Bill Weasley
Oliver Wood
Cornelius Fudge
Barty Crouch Sr.
Barty Crouch Jr.
Kingsley Shacklebolt
Quirinus Quirrell
Nearly Headless Nick
Aunt Marge
Griphook
Ludo Bagman
\ No newline at end of file
......@@ -91,12 +91,32 @@ def run_program_batch(
if num_threads == 1:
rets = []
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, True
if progress_bar:
for arguments in tqdm.tqdm(batch_arguments):
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
else:
for arguments in batch_arguments:
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))
......
import interegular
from sglang.srt.constrained.disk_cache import disk_cache
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
class FastForwardMap:
    """Maps DFA states of a regex FSM to the constant string that must follow.

    A state all of whose outgoing transitions carry one single symbol is
    deterministic: the next character is forced by the regex, so the decoder
    can append ("fast forward") it instead of sampling a token for it.
    """

    def __init__(self, regex_string):
        @disk_cache()
        def _init_state_to_fast_forward(regex_string):
            # Build the deterministic FSM for the regex.
            regex_pattern = interegular.parse_pattern(regex_string)
            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
            fsm_info: FSMInfo = regex_fsm.fsm_info

            # Invert the alphabet mapping: transition-class id -> symbols.
            symbol_to_id = fsm_info.alphabet_symbol_mapping
            id_to_symbol = {}
            for symbol, id_ in symbol_to_id.items():
                id_to_symbol.setdefault(id_, []).append(symbol)

            transitions = fsm_info.transitions

            # A state is "dirty" (not fast-forwardable) when it has more than
            # one outgoing transition, or its transition class matches more
            # than one symbol.
            dirty_states = set()
            state_to_fast_forward = {}
            for (state, id_), next_state in transitions.items():
                if state in dirty_states:
                    continue
                if state in state_to_fast_forward:
                    # Second outgoing transition seen: not a constant path.
                    dirty_states.add(state)
                    del state_to_fast_forward[state]
                    continue
                if len(id_to_symbol[id_]) > 1:
                    dirty_states.add(state)
                    continue
                state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)

            return state_to_fast_forward

        self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)

    def valid_states(self):
        """States from which at least one symbol can be fast forwarded."""
        return self.state_to_fast_forward.keys()

    def fast_forward(self, state):
        """Return (constant_str, state_after) reachable from `state`, or None.

        Follows single-symbol transitions as far as they go. A visited-set
        guards against constant cycles — e.g. the minimal DFA of `a+` has a
        state whose only transition loops back to itself, which made the
        original unguarded walk spin forever.
        """
        if state not in self.state_to_fast_forward:
            return None
        symbols = []
        seen = set()
        next_state = None
        while state in self.state_to_fast_forward and state not in seen:
            seen.add(state)
            symbol, next_state = self.state_to_fast_forward[state]
            symbols.append(symbol)
            state = next_state
        # join avoids the quadratic cost of repeated += on long constant runs.
        return "".join(symbols), next_state
class FastForwardCache:
    """Process-wide memoization of FastForwardMap objects, keyed by regex."""

    def __init__(self):
        # regex string -> FastForwardMap
        self.cache = {}

    def init_fast_forward_map(self, regex_string):
        """Return the cached map for `regex_string`, building it on first use."""
        try:
            return self.cache[regex_string]
        except KeyError:
            fast_forward_map = FastForwardMap(regex_string)
            self.cache[regex_string] = fast_forward_map
            return fast_forward_map
def test_main():
    """Smoke test: print the fast-forward result for every valid state of a
    mostly-constant regex."""
    pattern = r"The google's DNS sever address is " + IP_REGEX
    ff_map = FastForwardMap(pattern)
    for state in ff_map.valid_states():
        forwarded = ff_map.fast_forward(state)
        print(state, f'"{forwarded}"')


if __name__ == "__main__":
    test_main()
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
_enable_memory_cache = True
class FSMCache:
def __init__(self, tokenizer_path, tokenizer_args_dict):
......@@ -10,8 +12,10 @@ class FSMCache:
)
def init_fsm(self, regex):
if regex not in self.cache:
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
if _enable_memory_cache:
if regex not in self.cache:
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
return self.cache[regex]
return self.cache[regex]
return RegexFSM(regex, self.outlines_tokenizer)
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union
from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012
# Building blocks for JSON value regexes (adapted from outlines).
# STRING_INNER matches one character inside a JSON string: anything except a
# quote, backslash, or control character — or any backslash escape.
STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

# JSON Schema primitive type name -> regex matching a serialized value.
type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Turn a JSON schema into a regex that matches any JSON object that follows
    this schema.

    JSON Schema is a declarative language that allows to annotate JSON documents
    with types and descriptions. These schemas can be generated from any Python
    datastructure that has type annotation: namedtuples, dataclasses, Pydantic
    models. And by ensuring that the generation respects the schema we ensure
    that the output can be parsed into these objects.

    Parameters
    ----------
    object
        A pydantic model class, a callable with a fully annotated signature,
        or a string containing the JSON Schema document.

    Returns
    -------
    str
        A regular expression matching serialized objects that conform to the
        schema.

    References
    ----------
    .. [0] JSON Schema. https://json-schema.org/
    """
    # NOTE(review): the parameter shadows the builtin `object`; renaming it
    # would break keyword callers, so the name is kept.
    if isinstance(object, type(BaseModel)):
        # Pydantic model class (the isinstance check is against the metaclass).
        schema = object.model_json_schema()
    elif callable(object):
        schema = get_schema_from_signature(object)
    else:
        # Assume a JSON string containing the schema document.
        schema = json.loads(object)

    Validator.check_schema(schema)

    # Build reference resolver
    schema = Resource(contents=schema, specification=DRAFT202012)
    uri = schema.id() if schema.id() is not None else ""
    registry = Registry().with_resource(uri=uri, resource=schema)
    resolver = registry.resolver()

    content = schema.contents
    return to_regex(resolver, content)
def to_regex(resolver: "Resolver", instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema.

    Note
    ----
    Many features of JSON schema are missing:
    - Handle `additionalProperties` keyword
    - Handle types defined as a list
    - Handle constraints on numbers
    - Handle special patterns: `date`, `uri`, etc.

    This does not support recursive definitions.

    Parameters
    ----------
    resolver
        An object that resolves references to other instances within a schema
    instance
        The instance to translate
    """
    whitespace = r"[\n ]*"

    if "properties" in instance:
        regex = ""
        regex += r"\{"
        properties = instance["properties"]
        required_properties = instance.get("required", [])
        is_required = [item in required_properties for item in properties]
        # If at least one property is required, we include the one in the last
        # required position without any comma.
        # For each property before it (optional or required), we add with a comma after the property.
        # For each property after it (optional), we add with a comma before the property.
        if any(is_required):
            last_required_pos = max([i for i, value in enumerate(is_required) if value])
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                if i < last_required_pos:
                    subregex = f"{subregex}{whitespace},"
                elif i > last_required_pos:
                    subregex = f"{whitespace},{subregex}"
                regex += subregex if is_required[i] else f"({subregex})?"
        # If no property is required, we have to create a possible pattern for each property in which
        # it's the last one necessarily present. Then, we add the others as optional before and after
        # following the same strategy as described above.
        # The whole block is made optional to allow the case in which no property is returned.
        else:
            property_subregexes = []
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                property_subregexes.append(subregex)
            possible_patterns = []
            for i in range(len(property_subregexes)):
                pattern = ""
                for subregex in property_subregexes[:i]:
                    pattern += f"({subregex}{whitespace},)?"
                pattern += property_subregexes[i]
                for subregex in property_subregexes[i + 1 :]:
                    pattern += f"({whitespace},{subregex})?"
                possible_patterns.append(pattern)
            regex += f"({'|'.join(possible_patterns)})?"

        regex += f"{whitespace}" + r"\}"
        return regex

    # To validate against allOf, the given data must be valid against all of the
    # given subschemas.
    elif "allOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
        subregexes_str = [f"{subregex}" for subregex in subregexes]
        return rf"({''.join(subregexes_str)})"

    # To validate against `anyOf`, the given data must be valid against
    # any (one or more) of the given subschemas.
    elif "anyOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
        return rf"({'|'.join(subregexes)})"

    # To validate against oneOf, the given data must be valid against exactly
    # one of the given subschemas.
    elif "oneOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]
        xor_patterns = []
        # json schema validation ensured there is no overlapping schemas in oneOf
        for subregex in subregexes:
            other_subregexes = filter(lambda r: r != subregex, subregexes)
            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
            negative_lookahead = f"(?!.*({other_subregexes_str}))"
            xor_patterns.append(f"({subregex}){negative_lookahead}")
        return rf"({'|'.join(xor_patterns)})"

    # The enum keyword is used to restrict a value to a fixed set of values. It
    # must be an array with at least one element, where each element is unique.
    elif "enum" in instance:
        # Fix: the original tested `type(choice) in [int, float, bool, None]`,
        # which never matches None (its type is NoneType, not None), and it
        # rendered True/None as the Python literals "True"/"None" instead of
        # the JSON literals "true"/"null". json.dumps emits the JSON forms.
        choices = []
        for choice in instance["enum"]:
            if choice is None or isinstance(choice, bool):
                choices.append(json.dumps(choice))
            elif isinstance(choice, (int, float)):
                choices.append(re.escape(json.dumps(choice)))
            elif isinstance(choice, str):
                choices.append(f'"{re.escape(choice)}"')
            # NOTE(review): list/object enum members are still silently
            # skipped, matching the original behavior.
        return f"({'|'.join(choices)})"

    elif "$ref" in instance:
        path = f"{instance['$ref']}"
        instance = resolver.lookup(path).contents
        return to_regex(resolver, instance)

    # The type keyword may either be a string or an array:
    # - If it's a string, it is the name of one of the basic types.
    # - If it is an array, it must be an array of strings, where each string is
    # the name of one of the basic types, and each element is unique. In this
    # case, the JSON snippet is valid if it matches any of the given types.
    elif "type" in instance:
        instance_type = instance["type"]
        if instance_type == "string":
            if "maxLength" in instance or "minLength" in instance:
                max_items = instance.get("maxLength", "")
                min_items = instance.get("minLength", "")
                # Fix: the original raised ValueError inside its own
                # `try ... except ValueError: pass`, so the bounds check was
                # always swallowed. Only the int() conversion failures (a
                # missing "" bound) are ignorable; a real violation must raise.
                try:
                    bounds_invalid = int(max_items) < int(min_items)
                except ValueError:
                    bounds_invalid = False
                if bounds_invalid:
                    raise ValueError(
                        "maxLength must be greater than or equal to minLength"
                    )
                return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
            elif "pattern" in instance:
                pattern = instance["pattern"]
                if pattern[0] == "^" and pattern[-1] == "$":
                    return rf'(^"{pattern[1:-1]}"$)'
                else:
                    return rf'("{pattern}")'
            else:
                return type_to_regex["string"]

        elif instance_type == "number":
            return type_to_regex["number"]

        elif instance_type == "integer":
            return type_to_regex["integer"]

        elif instance_type == "array":
            min_items = instance.get("minItems", "0")
            max_items = instance.get("maxItems", "")
            if min_items == max_items:
                # NOTE(review): if minItems == maxItems == 0 this produces the
                # invalid quantifier {-1}; an empty array is not representable
                # here — confirm upstream schemas never request one.
                num_repeats = "{" + str(int(min_items) - 1) + "}"
            else:
                num_repeats = "*"

            if "items" in instance:
                items_regex = to_regex(resolver, instance["items"])
                return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
            else:
                # Here we need to make the choice to exclude generating list of objects
                # if the specification of the object is not given, even though a JSON
                # object that contains an object here would be valid under the specification.
                types = [
                    {"type": "boolean"},
                    {"type": "null"},
                    {"type": "number"},
                    {"type": "integer"},
                    {"type": "string"},
                ]
                regexes = [to_regex(resolver, t) for t in types]
                return (
                    rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
                )

        elif instance_type == "boolean":
            return type_to_regex["boolean"]

        elif instance_type == "null":
            return type_to_regex["null"]

        elif isinstance(instance_type, list):
            # Here we need to make the choice to exclude generating an object
            # if the specification of the object is not given, even though a JSON
            # object that contains an object here would be valid under the specification.
            regexes = [
                to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
            ]
            return rf"({'|'.join(regexes)})"

    raise NotImplementedError(
        f"""Could not translate the instance {instance} to a
regular expression. Make sure it is valid to the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
    )
def get_schema_from_signature(fn: Callable) -> dict:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    Returns
    -------
    dict
        The JSON Schema of a pydantic model whose fields mirror the
        function's parameters. (The original `-> str` annotation was wrong:
        ``model_json_schema()`` returns a dict.)

    Raises
    ------
    ValueError
        If any parameter of `fn` lacks a type annotation.
    """
    signature = inspect.signature(fn)
    arguments = {}
    for name, arg in signature.parameters.items():
        # `is` against the sentinel is the documented idiom (inspect.Parameter.empty).
        if arg.annotation is inspect.Parameter.empty:
            raise ValueError("Each argument must have a type annotation")
        arguments[name] = (arg.annotation, ...)

    model = create_model("Arguments", **arguments)
    return model.model_json_schema()
......@@ -60,6 +60,8 @@ class DetokenizerManager:
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
recv_obj.rids,
......
......@@ -59,6 +59,7 @@ class GenerateReqInput:
@dataclass
class TokenizedGenerateReqInput:
rid: str
input_text: str
input_ids: List[int]
pixel_values: List[float]
image_hash: int
......@@ -73,6 +74,7 @@ class TokenizedGenerateReqInput:
class BatchTokenIDOut:
rids: List[str]
output_tokens: List[List[int]]
output_and_fast_forward_strs: List[str]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
meta_info: List[Dict]
......
......@@ -23,6 +23,7 @@ class FinishReason(Enum):
class Req:
def __init__(self, rid):
self.rid = rid
self.input_text = None
self.input_ids = []
self.output_ids = []
self.pixel_values = None
......@@ -48,10 +49,44 @@ class Req:
# for constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.fast_forward_map = None
self.output_and_fast_forward_str = ""
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
def tokenize_fast_forward(self, fast_forward_str, next_state):
    """Splice the forced constant string into this request and re-tokenize.

    Token boundaries can shift when text is appended, so the whole
    (input_text + previously fast-forwarded text + decoded output +
    fast_forward_str) string is re-encoded and becomes the new prompt;
    the tokens generated so far are folded into the prompt.

    Parameters
    ----------
    fast_forward_str : str
        Constant text forced by the regex FSM from the current state.
    next_state : int
        FSM state after consuming `fast_forward_str`.
    """
    # Decode what has been generated so far.
    # NOTE(review): indexes output_ids[0] — assumes at least one token has
    # been generated before a fast forward; confirm callers guarantee this.
    old_output_str = self.tokenizer.decode(self.output_ids)
    # SentencePiece marks a leading space with "▁"; decode() drops it, so
    # restore the space before concatenating.
    if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
        old_output_str = " " + old_output_str
    new_input_string = (
        self.input_text
        + self.output_and_fast_forward_str
        + old_output_str
        + fast_forward_str
    )
    new_input_ids = self.tokenizer.encode(new_input_string)
    # Tokens added beyond the old prompt+output are charged against the
    # request's max_new_tokens budget.
    fast_forward_tokens_len = (
        len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
    )
    self.input_ids = new_input_ids
    self.output_ids = []
    self.sampling_params.max_new_tokens = max(
        self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
    )
    self.regex_fsm_state = next_state
    # Accumulate all text that bypassed incremental detokenization so the
    # detokenizer can prepend it to the final output string.
    self.output_and_fast_forward_str = (
        self.output_and_fast_forward_str + old_output_str + fast_forward_str
    )
def check_finished(self):
if self.finished:
return
......@@ -263,6 +298,8 @@ class Batch:
req.last_node = None
req.extend_input_len = 0
req.output_ids = []
req.regex_fsm_state = 0
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
......@@ -274,6 +311,46 @@ class Batch:
return retracted_reqs
def check_for_fast_forward(self):
    """Pull out requests whose FSM state forces a multi-char constant string.

    Each such request is removed from the running batch: its KV cache is
    handed to the radix tree cache, its pool slots are freed, and the
    request itself is re-tokenized with the forced text appended.

    Returns
    -------
    list
        The fast-forwarded requests, to be re-queued for prefill/extend.
    """
    fast_forward_reqs = []
    filter_indices = [i for i in range(len(self.reqs))]
    req_pool_indices_cpu = None
    for i, req in enumerate(self.reqs):
        if req.fast_forward_map is not None:
            res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
            if res is not None:
                fast_forward_str, next_state = res
                # A single forced character is not worth the re-tokenization
                # and batch-filtering overhead.
                if len(fast_forward_str) <= 1:
                    continue
                # insert the old request into tree_cache
                # NOTE(review): the last token is dropped — presumably its KV
                # entry is not yet written at this point; confirm.
                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                if req_pool_indices_cpu is None:
                    # Copy indices to CPU lazily, only if something forwards.
                    req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
                req_pool_idx = req_pool_indices_cpu[i]
                indices = self.req_to_token_pool.req_to_token[
                    req_pool_idx, : len(token_ids_in_memory)
                ]
                prefix_len = self.tree_cache.insert(
                    token_ids_in_memory, indices.clone()
                )
                # The shared prefix now lives in the tree cache; release the
                # duplicate KV slots and this request's pool row.
                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

                # fast forward
                req.tokenize_fast_forward(fast_forward_str, next_state)
                fast_forward_reqs.append(req)
                filter_indices.remove(i)

    if len(filter_indices) < len(self.reqs):
        self.filter_batch(filter_indices)

    return fast_forward_reqs
def prepare_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [
......
......@@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.constrained.fast_forward import FastForwardCache
from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias,
......@@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service):
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.no_regex_fast_forward = server_args.no_regex_fast_forward
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service):
"trust_remote_code": server_args.trust_remote_code,
},
)
self.fast_forward_cache = FastForwardCache()
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
......@@ -201,6 +204,7 @@ class ModelRpcServer(rpyc.Service):
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid)
req.input_text = recv_req.input_text
req.input_ids = recv_req.input_ids
req.pixel_values = recv_req.pixel_values
req.image_size = recv_req.image_size
......@@ -223,6 +227,10 @@ class ModelRpcServer(rpyc.Service):
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
if not self.no_regex_fast_forward:
req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
req.sampling_params.regex
)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
......@@ -334,11 +342,6 @@ class ModelRpcServer(rpyc.Service):
self.model_config.vocab_size, self.int_token_logit_bias
)
# Reset regex fsm state before first sampling due to retractions
for req in batch.reqs:
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
if batch.extend_num_tokens != 0:
# Forward
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
......@@ -388,6 +391,13 @@ class ModelRpcServer(rpyc.Service):
self.min_new_token_ratio,
)
if not self.no_regex_fast_forward:
# check for fast forward
fast_forward_reqs = batch.check_for_fast_forward()
self.forward_queue.extend(fast_forward_reqs)
if batch.is_empty():
return
# Update batch tensors
self.decode_forward_ct += 1
batch.prepare_for_decode()
......@@ -408,6 +418,7 @@ class ModelRpcServer(rpyc.Service):
def handle_finished_requests(self, batch: Batch):
output_rids = []
output_tokens = []
output_and_fast_forward_strs = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_meta_info = []
......@@ -425,6 +436,7 @@ class ModelRpcServer(rpyc.Service):
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)
output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
......@@ -445,6 +457,7 @@ class ModelRpcServer(rpyc.Service):
BatchTokenIDOut(
output_rids,
output_tokens,
output_and_fast_forward_strs,
output_hit_stop_str,
output_skip_special_tokens,
output_meta_info,
......
......@@ -157,6 +157,7 @@ class TokenizerManager:
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,
......
......@@ -23,6 +23,7 @@ class ServerArgs:
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
no_regex_fast_forward: bool = False
def __post_init__(self):
if self.tokenizer_path is None:
......@@ -150,6 +151,11 @@ class ServerArgs:
default=ServerArgs.log_stats_interval,
help="Log stats interval in second.",
)
parser.add_argument(
"--no-regex-fast-forward",
action="store_true",
help="Disable regex fast forward",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
import argparse
from enum import Enum
import sglang as sgl
from pydantic import BaseModel, constr
from sglang.srt.constrained.json_schema import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Dotted-quad IPv4 address: four octets, each 0-255.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Answer template with long constant substrings separated by short free
# slots, designed to exercise the regex fast-forward path.
# NOTE(review): "sever" is a typo in the constant text, but it is runtime
# output the recorded expected results below match — do not "fix" it alone.
ip_fast_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)
# fmt: off
@sgl.function
def regex_gen(s):
    # Q/A prompt whose answer is constrained by `ip_fast_forward`: most of
    # the answer is constant text the runtime can skip through.
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_fast_forward,
    )
# fmt: on
# Regex for a fixed-shape JSON object about Hogwarts; the literal keys and
# punctuation are constant segments eligible for fast forward, while values
# remain free within their character classes.
json_fast_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r""" "name": "[\w\d\s]*",\n"""
    + r""" "country": "[\w\d\s]*",\n"""
    + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r""" "population": [-+]?[0-9]+,\n"""
    + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)
# fmt: off
@sgl.function
def json_gen(s):
    # Pure constrained generation: no natural-language prefix, the regex
    # alone drives the entire output.
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_fast_forward,
    )
# fmt: on
class Weapon(str, Enum):
    # str-valued enum so each member serializes as its string value in the
    # JSON schema derived from the Character model.
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    # str-valued enum: members become string choices in the derived schema.
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
class Character(BaseModel):
    # Pydantic model from which a constraining regex is derived via
    # build_regex_from_object in character_gen below.
    name: constr(max_length=10)  # bounded-length string field
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
@sgl.function
def character_gen(s):
    # Constrain the generation with a regex compiled from the pydantic
    # Character model's JSON schema.
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
def main(args):
    """Run the three fast-forward regression prompts against the chosen backend."""
    # Select backend
    sgl.set_default_backend(select_sglang_backend(args))

    # Run each program and print its full text under a banner.
    for title, program in (
        ("IP TEST", regex_gen),
        ("JSON TEST", json_gen),
        ("CHARACTER TEST", character_gen),
    ):
        state = program.run(temperature=0)
        print("=" * 20, title, "=" * 20)
        print(state.text())
if __name__ == "__main__":
    # Only the shared sglang backend args are needed for this test script.
    arg_parser = argparse.ArgumentParser()
    main(add_common_sglang_args_and_parse(arg_parser))
# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
# {
# "name": "Hogwarts School of Witchcraft and Wizardry",
# "country": "Scotland",
# "latitude": 55.566667,
# "population": 1000,
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
......@@ -2,14 +2,13 @@ import argparse
import random
import string
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
TOKENIZER = None
RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment