Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
01ee0fbc
Unverified
Commit
01ee0fbc
authored
Jan 25, 2024
by
Liangsheng Yin
Committed by
GitHub
Jan 25, 2024
Browse files
fast regex decode
Auto-detect constant str path in regex FSM, then extend instead.
parent
711d3435
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
968 additions
and
16 deletions
+968
-16
benchmark/json_fast_forward/README.md
benchmark/json_fast_forward/README.md
+46
-0
benchmark/json_fast_forward/bench_other.py
benchmark/json_fast_forward/bench_other.py
+135
-0
benchmark/json_fast_forward/bench_sglang.py
benchmark/json_fast_forward/bench_sglang.py
+92
-0
benchmark/json_fast_forward/dataset.txt
benchmark/json_fast_forward/dataset.txt
+50
-0
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+25
-5
python/sglang/srt/constrained/fast_forward.py
python/sglang/srt/constrained/fast_forward.py
+78
-0
python/sglang/srt/constrained/fsm_cache.py
python/sglang/srt/constrained/fsm_cache.py
+8
-4
python/sglang/srt/constrained/json_schema.py
python/sglang/srt/constrained/json_schema.py
+290
-0
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+2
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+2
-0
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+77
-0
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+18
-5
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+6
-0
test/srt/test_fast_forward.py
test/srt/test_fast_forward.py
+137
-0
test/srt/test_robust.py
test/srt/test_robust.py
+1
-2
No files found.
benchmark/json_fast_forward/README.md
0 → 100644
View file @
01ee0fbc
## Run benchmark
### Dependencies
```
llama_cpp_python 0.2.32
guidance 0.1.10
vllm 0.2.7
outlines 0.0.24
```
### Benchmark sglang
Run Llama-7B
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Benchmark
```
python3 bench_sglang.py
```
### Benchmark vllm
Run Llama-7B
```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark
```
python3 bench_other.py --backend vllm
```
### Benchmark guidance (seems not supported)
Run Llama-7B and benchmark
```
python3 bench_other.py --backend guidance --parallel 1
```
benchmark/json_fast_forward/bench_other.py
0 → 100644
View file @
01ee0fbc
import
argparse
import
json
import
time
from
concurrent.futures
import
ThreadPoolExecutor
from
functools
import
partial
import
guidance
from
sglang.test.test_utils
import
(
add_common_other_args_and_parse
,
call_generate_outlines
,
)
from
sglang.utils
import
dump_state_text
from
tqdm
import
tqdm
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex
=
(
r
"""\{\n"""
+
r
""" "name": "[\w\d\s]{1,16}",\n"""
+
r
""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+
r
""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+
r
""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+
r
""" "wand": \{\n"""
+
r
""" "wood": "[\w\d\s]{1,16}",\n"""
+
r
""" "core": "[\w\d\s]{1,16}",\n"""
+
r
""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+
r
""" \},\n"""
+
r
""" "alive": "(Alive|Deceased)",\n"""
+
r
""" "patronus": "[\w\d\s]{1,16}",\n"""
+
r
""" "bogart": "[\w\d\s]{1,16}"\n"""
+
r
"""\}"""
)
# fmt: off
def
character_gen
(
name
,
generate
):
s
=
name
+
" is a character in Harry Potter. Please fill in the following information about him/her.
\n
"
s
+=
generate
(
s
,
max_tokens
=
256
,
regex
=
character_regex
)
return
s
# fmt: on
@
guidance
def
character_maker
(
lm
,
name
):
regex_str_no_quote
=
r
"[\w\d\s]+"
regex_float
=
r
"[0-9]+\.[0-9]+"
lm
+=
f
"""
\
{
name
}
is a character in Harry Potter. Please fill in the following information about him/her.
{{
"name": "
{
guidance
.
gen
(
"name"
,
max_tokens
=
16
,
regex
=
regex_str_no_quote
)
}
",
"house": "
{
guidance
.
select
(
options
=
[
'Gryffindor'
,
'Slytherin'
,
'Ravenclaw'
,
'Hufflepuff'
],
name
=
'house'
)
}
",
"blood status": "
{
guidance
.
select
(
options
=
[
'Pure-blood'
,
'Half-blood'
,
'Muggle-born'
],
name
=
'blood status'
)
}
",
"occupation": "
{
guidance
.
select
(
options
=
[
'student'
,
'teacher'
,
'auror'
,
'ministry of magic'
,
'death eater'
,
'order of the phoenix'
],
name
=
'occupation'
)
}
",
"wand": {{
"wood": "
{
guidance
.
gen
(
"wood"
,
max_tokens
=
16
,
regex
=
regex_str_no_quote
)
}
",
"core": "
{
guidance
.
gen
(
'core'
,
max_tokens
=
16
,
regex
=
regex_str_no_quote
)
}
",
"length":
{
guidance
.
gen
(
'length'
,
max_tokens
=
10
,
regex
=
regex_float
)
}
}},
"alive": "
{
guidance
.
select
(
options
=
[
'Alive'
,
'Deceased'
],
name
=
'alive'
)
}
",
"patronus": "
{
guidance
.
gen
(
'patronus'
,
max_tokens
=
16
,
regex
=
regex_str_no_quote
)
}
",
"bogart": "
{
guidance
.
gen
(
'bogart'
,
max_tokens
=
16
,
regex
=
regex_str_no_quote
)
}
"
}}
"""
return
lm
def
main
(
args
):
arguments
=
[]
with
open
(
args
.
data_path
,
"r"
)
as
f
:
for
line
in
f
:
arguments
.
append
({
"name"
:
line
.
strip
()})
arguments
=
arguments
[:
args
.
num_jsons
]
states
=
[
None
]
*
len
(
arguments
)
# Select backend
if
args
.
backend
==
"vllm"
:
url
=
f
"
{
args
.
host
}
:
{
args
.
port
}
/generate"
generate
=
partial
(
call_generate_outlines
,
url
=
url
,
temperature
=
0
)
def
func
(
i
):
states
[
i
]
=
character_gen
(
**
arguments
[
i
],
generate
=
generate
)
get_one_answer
=
func
elif
args
.
backend
==
"guidance"
:
model
=
guidance
.
models
.
LlamaCpp
(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf"
,
n_gpu_layers
=-
1
,
n_ctx
=
4096
,
)
def
func
(
i
):
lm
=
model
+
character_maker
(
**
arguments
[
i
])
states
[
i
]
=
lm
get_one_answer
=
func
else
:
raise
ValueError
(
f
"Invalid backend:
{
args
.
backend
}
"
)
tic
=
time
.
time
()
if
args
.
parallel
==
1
:
for
i
in
tqdm
(
range
(
len
(
arguments
))):
get_one_answer
(
i
)
else
:
with
ThreadPoolExecutor
(
args
.
parallel
)
as
executor
:
rets
=
executor
.
map
(
get_one_answer
,
list
(
range
(
len
(
arguments
))))
for
_
in
rets
:
pass
latency
=
time
.
time
()
-
tic
# Compute accuracy
print
(
f
"Latency:
{
latency
:.
3
f
}
"
)
# Write results
dump_state_text
(
f
"tmp_output_
{
args
.
backend
}
.txt"
,
states
)
with
open
(
args
.
result_file
,
"a"
)
as
fout
:
value
=
{
"task"
:
"json_fast_forward"
,
"backend"
:
args
.
backend
,
"latency"
:
round
(
latency
,
3
),
"num_jsons"
:
args
.
num_jsons
,
"parallel"
:
args
.
parallel
,
}
fout
.
write
(
json
.
dumps
(
value
)
+
"
\n
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--data-path"
,
type
=
str
,
default
=
"dataset.txt"
)
parser
.
add_argument
(
"--num-jsons"
,
type
=
int
,
default
=
50
)
args
=
add_common_other_args_and_parse
(
parser
)
main
(
args
)
benchmark/json_fast_forward/bench_sglang.py
0 → 100644
View file @
01ee0fbc
import
argparse
import
json
import
time
import
sglang
as
sgl
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
select_sglang_backend
,
)
from
sglang.utils
import
dump_state_text
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex
=
(
r
"""\{\n"""
+
r
""" "name": "[\w\d\s]{1,16}",\n"""
+
r
""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+
r
""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+
r
""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+
r
""" "wand": \{\n"""
+
r
""" "wood": "[\w\d\s]{1,16}",\n"""
+
r
""" "core": "[\w\d\s]{1,16}",\n"""
+
r
""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+
r
""" \},\n"""
+
r
""" "alive": "(Alive|Deceased)",\n"""
+
r
""" "patronus": "[\w\d\s]{1,16}",\n"""
+
r
""" "bogart": "[\w\d\s]{1,16}"\n"""
+
r
"""\}"""
)
# fmt: off
@
sgl
.
function
def
character_gen
(
s
,
name
):
s
+=
name
+
" is a character in Harry Potter. Please fill in the following information about him/her.
\n
"
s
+=
sgl
.
gen
(
"json_output"
,
max_tokens
=
256
,
regex
=
character_regex
)
# fmt: on
def
bench_character
(
args
):
arguments
=
[]
with
open
(
args
.
data_path
,
"r"
)
as
f
:
for
line
in
f
:
arguments
.
append
({
"name"
:
line
.
strip
()})
arguments
=
arguments
[:
args
.
num_jsons
]
# Select backend
backend
=
select_sglang_backend
(
args
)
sgl
.
set_default_backend
(
backend
)
# Run requests
tic
=
time
.
time
()
states
=
character_gen
.
run_batch
(
arguments
,
temperature
=
0
,
num_threads
=
args
.
parallel
,
progress_bar
=
(
args
.
parallel
==
1
),
)
latency
=
time
.
time
()
-
tic
return
states
,
latency
def
main
(
args
):
states
,
latency
=
bench_character
(
args
)
# Compute accuracy
print
(
f
"Latency:
{
latency
:.
3
f
}
"
)
# Write results
dump_state_text
(
f
"tmp_output_
{
args
.
backend
}
.txt"
,
states
)
with
open
(
f
"
{
args
.
backend
}
.json"
,
"w"
)
as
fout
:
for
state
in
states
:
fout
.
write
(
state
[
"json_output"
]
+
"
\n
"
)
with
open
(
args
.
result_file
,
"a"
)
as
fout
:
value
=
{
"task"
:
"json_fast_forward"
,
"backend"
:
args
.
backend
,
"latency"
:
round
(
latency
,
3
),
"num_jsons"
:
args
.
num_jsons
,
"parallel"
:
args
.
parallel
,
}
fout
.
write
(
json
.
dumps
(
value
)
+
"
\n
"
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--data-path"
,
type
=
str
,
default
=
"dataset.txt"
)
parser
.
add_argument
(
"--num-jsons"
,
type
=
int
,
default
=
50
)
args
=
add_common_sglang_args_and_parse
(
parser
)
main
(
args
)
benchmark/json_fast_forward/dataset.txt
0 → 100644
View file @
01ee0fbc
Harry Potter
Hermione Granger
Ron Weasley
Albus Dumbledore
Severus Snape
Rubeus Hagrid
Draco Malfoy
Ginny Weasley
Fred Weasley
George Weasley
Percy Weasley
Sirius Black
Remus Lupin
Neville Longbottom
Luna Lovegood
Cedric Diggory
Cho Chang
Lord Voldemort
Minerva McGonagall
Filius Flitwick
Dolores Umbridge
Bellatrix Lestrange
Lucius Malfoy
Molly Weasley
Arthur Weasley
Nymphadora Tonks
Dobby
Moaning Myrtle
Peter Pettigrew
Alastor 'Mad-Eye' Moody
Horace Slughorn
Vernon Dursley
Petunia Dursley
Dudley Dursley
Argus Filch
Sybill Trelawney
Gilderoy Lockhart
Fleur Delacour
Viktor Krum
Bill Weasley
Oliver Wood
Cornelius Fudge
Barty Crouch Sr.
Barty Crouch Jr.
Kingsley Shacklebolt
Quirinus Quirrell
Nearly Headless Nick
Aunt Marge
Griphook
Ludo Bagman
\ No newline at end of file
python/sglang/lang/interpreter.py
View file @
01ee0fbc
...
...
@@ -91,12 +91,32 @@ def run_program_batch(
if
num_threads
==
1
:
rets
=
[]
for
arguments
in
batch_arguments
:
rets
.
append
(
run_program
(
program
,
backend
,
(),
arguments
,
default_sampling_para
,
False
,
True
if
progress_bar
:
for
arguments
in
tqdm
.
tqdm
(
batch_arguments
):
rets
.
append
(
run_program
(
program
,
backend
,
(),
arguments
,
default_sampling_para
,
False
,
True
,
)
)
else
:
for
arguments
in
batch_arguments
:
rets
.
append
(
run_program
(
program
,
backend
,
(),
arguments
,
default_sampling_para
,
False
,
True
,
)
)
)
else
:
if
progress_bar
:
pbar
=
tqdm
.
tqdm
(
total
=
len
(
batch_arguments
))
...
...
python/sglang/srt/constrained/fast_forward.py
0 → 100644
View file @
01ee0fbc
import
interegular
from
sglang.srt.constrained.disk_cache
import
disk_cache
from
sglang.srt.constrained.regex
import
FSMInfo
,
make_deterministic_fsm
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
class
FastForwardMap
:
def
__init__
(
self
,
regex_string
):
@
disk_cache
()
def
_init_state_to_fast_forward
(
regex_string
):
regex_pattern
=
interegular
.
parse_pattern
(
regex_string
)
regex_fsm
,
_
=
make_deterministic_fsm
(
regex_pattern
.
to_fsm
().
reduce
())
fsm_info
:
FSMInfo
=
regex_fsm
.
fsm_info
symbol_to_id
=
fsm_info
.
alphabet_symbol_mapping
id_to_symbol
=
{}
for
symbol
,
id_
in
symbol_to_id
.
items
():
id_to_symbol
.
setdefault
(
id_
,
[]).
append
(
symbol
)
transitions
=
fsm_info
.
transitions
dirty_states
=
set
()
state_to_fast_forward
=
{}
for
(
state
,
id_
),
next_state
in
transitions
.
items
():
if
state
in
dirty_states
:
continue
if
state
in
state_to_fast_forward
:
dirty_states
.
add
(
state
)
del
state_to_fast_forward
[
state
]
continue
if
len
(
id_to_symbol
[
id_
])
>
1
:
dirty_states
.
add
(
state
)
continue
state_to_fast_forward
[
state
]
=
(
id_to_symbol
[
id_
][
0
],
next_state
)
return
state_to_fast_forward
self
.
state_to_fast_forward
=
_init_state_to_fast_forward
(
regex_string
)
def
valid_states
(
self
):
return
self
.
state_to_fast_forward
.
keys
()
def
fast_forward
(
self
,
state
):
if
state
not
in
self
.
state_to_fast_forward
:
return
None
fast_forward_str
=
""
next_state
=
None
while
state
in
self
.
state_to_fast_forward
:
symbol
,
next_state
=
self
.
state_to_fast_forward
[
state
]
fast_forward_str
+=
symbol
state
=
next_state
return
fast_forward_str
,
next_state
class
FastForwardCache
:
def
__init__
(
self
):
self
.
cache
=
{}
def
init_fast_forward_map
(
self
,
regex_string
):
if
regex_string
not
in
self
.
cache
:
fast_forward_map
=
FastForwardMap
(
regex_string
)
self
.
cache
[
regex_string
]
=
fast_forward_map
return
self
.
cache
[
regex_string
]
def
test_main
():
regex_string
=
r
"The google's DNS sever address is "
+
IP_REGEX
fast_forward_map
=
FastForwardMap
(
regex_string
)
for
state
in
fast_forward_map
.
valid_states
():
print
(
state
,
f
'"
{
fast_forward_map
.
fast_forward
(
state
)
}
"'
)
if
__name__
==
"__main__"
:
test_main
()
python/sglang/srt/constrained/fsm_cache.py
View file @
01ee0fbc
from
sglang.srt.constrained.fsm
import
RegexFSM
from
sglang.srt.constrained.tokenizer
import
TransformerTokenizer
_enable_memory_cache
=
True
class
FSMCache
:
def
__init__
(
self
,
tokenizer_path
,
tokenizer_args_dict
):
...
...
@@ -10,8 +12,10 @@ class FSMCache:
)
def
init_fsm
(
self
,
regex
):
if
regex
not
in
self
.
cache
:
fsm
=
RegexFSM
(
regex
,
self
.
outlines_tokenizer
)
self
.
cache
[
regex
]
=
fsm
if
_enable_memory_cache
:
if
regex
not
in
self
.
cache
:
fsm
=
RegexFSM
(
regex
,
self
.
outlines_tokenizer
)
self
.
cache
[
regex
]
=
fsm
return
self
.
cache
[
regex
]
return
self
.
cache
[
regex
]
return
RegexFSM
(
regex
,
self
.
outlines_tokenizer
)
python/sglang/srt/constrained/json_schema.py
0 → 100644
View file @
01ee0fbc
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import
inspect
import
json
import
re
from
typing
import
Callable
,
Union
from
jsonschema.protocols
import
Validator
from
pydantic
import
BaseModel
,
create_model
from
referencing
import
Registry
,
Resource
from
referencing._core
import
Resolver
from
referencing.jsonschema
import
DRAFT202012
STRING_INNER
=
r
'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING
=
f
'"
{
STRING_INNER
}
*"'
INTEGER
=
r
"(0|[1-9][0-9]*)"
NUMBER
=
rf
"(-)?(
{
INTEGER
}
)(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN
=
r
"(true|false)"
NULL
=
r
"null"
type_to_regex
=
{
"string"
:
STRING
,
"integer"
:
INTEGER
,
"number"
:
NUMBER
,
"boolean"
:
BOOLEAN
,
"null"
:
NULL
,
}
def
build_regex_from_object
(
object
:
Union
[
str
,
Callable
,
BaseModel
]):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
JSON Schema is a declarative language that allows to annotate JSON documents
with types and descriptions. These schemas can be generated from any Python
datastructure that has type annotation: namedtuples, dataclasses, Pydantic
models. And by ensuring that the generation respects the schema we ensure
that the output can be parsed into these objects.
This function parses the provided schema and builds a generation schedule which
mixes deterministic generation (fixed strings), and sampling with constraints.
Parameters
----------
schema
A string that represents a JSON Schema.
Returns
-------
A generation schedule. A list of strings that represent the JSON
schema's structure and regular expression that define the structure of
the fields.
References
----------
.. [0] JSON Schema. https://json-schema.org/
"""
if
isinstance
(
object
,
type
(
BaseModel
)):
schema
=
object
.
model_json_schema
()
elif
callable
(
object
):
schema
=
get_schema_from_signature
(
object
)
else
:
schema
=
json
.
loads
(
object
)
Validator
.
check_schema
(
schema
)
# Build reference resolver
schema
=
Resource
(
contents
=
schema
,
specification
=
DRAFT202012
)
uri
=
schema
.
id
()
if
schema
.
id
()
is
not
None
else
""
registry
=
Registry
().
with_resource
(
uri
=
uri
,
resource
=
schema
)
resolver
=
registry
.
resolver
()
content
=
schema
.
contents
return
to_regex
(
resolver
,
content
)
def
to_regex
(
resolver
:
Resolver
,
instance
:
dict
):
"""Translate a JSON Schema instance into a regex that validates the schema.
Note
----
Many features of JSON schema are missing:
- Handle `additionalProperties` keyword
- Handle types defined as a list
- Handle constraints on numbers
- Handle special patterns: `date`, `uri`, etc.
This does not support recursive definitions.
Parameters
----------
resolver
An object that resolves references to other instances within a schema
instance
The instance to translate
"""
whitespace
=
r
"[\n ]*"
if
"properties"
in
instance
:
regex
=
""
regex
+=
r
"\{"
properties
=
instance
[
"properties"
]
required_properties
=
instance
.
get
(
"required"
,
[])
is_required
=
[
item
in
required_properties
for
item
in
properties
]
# If at least one property is required, we include the one in the lastest position
# without any comma.
# For each property before it (optional or required), we add with a comma after the property.
# For each property after it (optional), we add with a comma before the property.
if
any
(
is_required
):
last_required_pos
=
max
([
i
for
i
,
value
in
enumerate
(
is_required
)
if
value
])
for
i
,
(
name
,
value
)
in
enumerate
(
properties
.
items
()):
subregex
=
f
'
{
whitespace
}
"
{
name
}
"
{
whitespace
}
:
{
whitespace
}
'
subregex
+=
to_regex
(
resolver
,
value
)
if
i
<
last_required_pos
:
subregex
=
f
"
{
subregex
}{
whitespace
}
,"
elif
i
>
last_required_pos
:
subregex
=
f
"
{
whitespace
}
,
{
subregex
}
"
regex
+=
subregex
if
is_required
[
i
]
else
f
"(
{
subregex
}
)?"
# If no property is required, we have to create a possible pattern for each property in which
# it's the last one necessarilly present. Then, we add the others as optional before and after
# following the same strategy as described above.
# The whole block is made optional to allow the case in which no property is returned.
else
:
property_subregexes
=
[]
for
i
,
(
name
,
value
)
in
enumerate
(
properties
.
items
()):
subregex
=
f
'
{
whitespace
}
"
{
name
}
"
{
whitespace
}
:
{
whitespace
}
'
subregex
+=
to_regex
(
resolver
,
value
)
property_subregexes
.
append
(
subregex
)
possible_patterns
=
[]
for
i
in
range
(
len
(
property_subregexes
)):
pattern
=
""
for
subregex
in
property_subregexes
[:
i
]:
pattern
+=
f
"(
{
subregex
}{
whitespace
}
,)?"
pattern
+=
property_subregexes
[
i
]
for
subregex
in
property_subregexes
[
i
+
1
:]:
pattern
+=
f
"(
{
whitespace
}
,
{
subregex
}
)?"
possible_patterns
.
append
(
pattern
)
regex
+=
f
"(
{
'|'
.
join
(
possible_patterns
)
}
)?"
regex
+=
f
"
{
whitespace
}
"
+
r
"\}"
return
regex
# To validate against allOf, the given data must be valid against all of the
# given subschemas.
elif
"allOf"
in
instance
:
subregexes
=
[
to_regex
(
resolver
,
t
)
for
t
in
instance
[
"allOf"
]]
subregexes_str
=
[
f
"
{
subregex
}
"
for
subregex
in
subregexes
]
return
rf
"(
{
''
.
join
(
subregexes_str
)
}
)"
# To validate against `anyOf`, the given data must be valid against
# any (one or more) of the given subschemas.
elif
"anyOf"
in
instance
:
subregexes
=
[
to_regex
(
resolver
,
t
)
for
t
in
instance
[
"anyOf"
]]
return
rf
"(
{
'|'
.
join
(
subregexes
)
}
)"
# To validate against oneOf, the given data must be valid against exactly
# one of the given subschemas.
elif
"oneOf"
in
instance
:
subregexes
=
[
to_regex
(
resolver
,
t
)
for
t
in
instance
[
"oneOf"
]]
xor_patterns
=
[]
# json schema validation ensured there is no overlapping schemas in oneOf
for
subregex
in
subregexes
:
other_subregexes
=
filter
(
lambda
r
:
r
!=
subregex
,
subregexes
)
other_subregexes_str
=
"|"
.
join
([
f
"
{
s
}
"
for
s
in
other_subregexes
])
negative_lookahead
=
f
"(?!.*(
{
other_subregexes_str
}
))"
xor_patterns
.
append
(
f
"(
{
subregex
}
)
{
negative_lookahead
}
"
)
return
rf
"(
{
'|'
.
join
(
xor_patterns
)
}
)"
# The enum keyword is used to restrict a value to a fixed set of values. It
# must be an array with at least one element, where each element is unique.
elif
"enum"
in
instance
:
choices
=
[]
for
choice
in
instance
[
"enum"
]:
if
type
(
choice
)
in
[
int
,
float
,
bool
,
None
]:
choices
.
append
(
re
.
escape
(
str
(
choice
)))
elif
type
(
choice
)
==
str
:
choices
.
append
(
f
'"
{
re
.
escape
(
choice
)
}
"'
)
return
f
"(
{
'|'
.
join
(
choices
)
}
)"
elif
"$ref"
in
instance
:
path
=
f
"
{
instance
[
'$ref'
]
}
"
instance
=
resolver
.
lookup
(
path
).
contents
return
to_regex
(
resolver
,
instance
)
# The type keyword may either be a string or an array:
# - If it's a string, it is the name of one of the basic types.
# - If it is an array, it must be an array of strings, where each string is
# the name of one of the basic types, and each element is unique. In this
# case, the JSON snippet is valid if it matches any of the given types.
elif
"type"
in
instance
:
instance_type
=
instance
[
"type"
]
if
instance_type
==
"string"
:
if
"maxLength"
in
instance
or
"minLength"
in
instance
:
max_items
=
instance
.
get
(
"maxLength"
,
""
)
min_items
=
instance
.
get
(
"minLength"
,
""
)
try
:
if
int
(
max_items
)
<
int
(
min_items
):
raise
ValueError
(
"maxLength must be greater than or equal to minLength"
)
except
ValueError
:
pass
return
f
'"
{
STRING_INNER
}
{{
{
min_items
}
,
{
max_items
}
}}"'
elif
"pattern"
in
instance
:
pattern
=
instance
[
"pattern"
]
if
pattern
[
0
]
==
"^"
and
pattern
[
-
1
]
==
"$"
:
return
rf
'(^"
{
pattern
[
1
:
-
1
]
}
"$)'
else
:
return
rf
'("
{
pattern
}
")'
else
:
return
type_to_regex
[
"string"
]
elif
instance_type
==
"number"
:
return
type_to_regex
[
"number"
]
elif
instance_type
==
"integer"
:
return
type_to_regex
[
"integer"
]
elif
instance_type
==
"array"
:
min_items
=
instance
.
get
(
"minItems"
,
"0"
)
max_items
=
instance
.
get
(
"maxItems"
,
""
)
if
min_items
==
max_items
:
num_repeats
=
"{"
+
str
(
int
(
min_items
)
-
1
)
+
"}"
else
:
num_repeats
=
"*"
if
"items"
in
instance
:
items_regex
=
to_regex
(
resolver
,
instance
[
"items"
])
return
rf
"\[(
{
items_regex
}
)(,(
{
items_regex
}
))
{
num_repeats
}
\]"
else
:
# Here we need to make the choice to exclude generating list of objects
# if the specification of the object is not given, even though a JSON
# object that contains an object here would be valid under the specification.
types
=
[
{
"type"
:
"boolean"
},
{
"type"
:
"null"
},
{
"type"
:
"number"
},
{
"type"
:
"integer"
},
{
"type"
:
"string"
},
]
regexes
=
[
to_regex
(
resolver
,
t
)
for
t
in
types
]
return
(
rf
"\[(
{
'|'
.
join
(
regexes
)
}
)(,(
{
'|'
.
join
(
regexes
)
}
))
{
num_repeats
}
\]"
)
elif
instance_type
==
"boolean"
:
return
type_to_regex
[
"boolean"
]
elif
instance_type
==
"null"
:
return
type_to_regex
[
"null"
]
elif
isinstance
(
instance_type
,
list
):
# Here we need to make the choice to exclude generating an object
# if the specification of the object is not give, even though a JSON
# object that contains an object here would be valid under the specification.
regexes
=
[
to_regex
(
resolver
,
{
"type"
:
t
})
for
t
in
instance_type
if
t
!=
"object"
]
return
rf
"(
{
'|'
.
join
(
regexes
)
}
)"
raise
NotImplementedError
(
f
"""Could not translate the instance
{
instance
}
to a
regular expression. Make sure it is valid to the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
)
def
get_schema_from_signature
(
fn
:
Callable
)
->
str
:
"""Turn a function signature into a JSON schema.
Every JSON object valid to the output JSON Schema can be passed
to `fn` using the ** unpacking syntax.
"""
signature
=
inspect
.
signature
(
fn
)
arguments
=
{}
for
name
,
arg
in
signature
.
parameters
.
items
():
if
arg
.
annotation
==
inspect
.
_empty
:
raise
ValueError
(
"Each argument must have a type annotation"
)
else
:
arguments
[
name
]
=
(
arg
.
annotation
,
...)
model
=
create_model
(
"Arguments"
,
**
arguments
)
return
model
.
model_json_schema
()
python/sglang/srt/managers/detokenizer_manager.py
View file @
01ee0fbc
...
...
@@ -60,6 +60,8 @@ class DetokenizerManager:
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
output_and_fast_forward_strs
[
i
]
+
output_strs
[
i
]
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
recv_obj
.
rids
,
...
...
python/sglang/srt/managers/io_struct.py
View file @
01ee0fbc
...
...
@@ -59,6 +59,7 @@ class GenerateReqInput:
@
dataclass
class
TokenizedGenerateReqInput
:
rid
:
str
input_text
:
str
input_ids
:
List
[
int
]
pixel_values
:
List
[
float
]
image_hash
:
int
...
...
@@ -73,6 +74,7 @@ class TokenizedGenerateReqInput:
class
BatchTokenIDOut
:
rids
:
List
[
str
]
output_tokens
:
List
[
List
[
int
]]
output_and_fast_forward_strs
:
List
[
str
]
hit_stop_str
:
List
[
Optional
[
str
]]
skip_special_tokens
:
List
[
bool
]
meta_info
:
List
[
Dict
]
...
...
python/sglang/srt/managers/router/infer_batch.py
View file @
01ee0fbc
...
...
@@ -23,6 +23,7 @@ class FinishReason(Enum):
class
Req
:
def
__init__
(
self
,
rid
):
self
.
rid
=
rid
self
.
input_text
=
None
self
.
input_ids
=
[]
self
.
output_ids
=
[]
self
.
pixel_values
=
None
...
...
@@ -48,10 +49,44 @@ class Req:
# for constrained decoding
self
.
regex_fsm
=
None
self
.
regex_fsm_state
=
0
self
.
fast_forward_map
=
None
self
.
output_and_fast_forward_str
=
""
def
max_new_tokens
(
self
):
return
self
.
sampling_params
.
max_new_tokens
def
tokenize_fast_forward
(
self
,
fast_forward_str
,
next_state
):
old_output_str
=
self
.
tokenizer
.
decode
(
self
.
output_ids
)
if
self
.
tokenizer
.
convert_ids_to_tokens
(
self
.
output_ids
[
0
]).
startswith
(
"▁"
):
old_output_str
=
" "
+
old_output_str
new_input_string
=
(
self
.
input_text
+
self
.
output_and_fast_forward_str
+
old_output_str
+
fast_forward_str
)
new_input_ids
=
self
.
tokenizer
.
encode
(
new_input_string
)
fast_forward_tokens_len
=
(
len
(
new_input_ids
)
-
len
(
self
.
input_ids
)
-
len
(
self
.
output_ids
)
)
# print("=" * 100)
# print(f"Catch fast forward:\n{fast_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
self
.
input_ids
=
new_input_ids
self
.
output_ids
=
[]
self
.
sampling_params
.
max_new_tokens
=
max
(
self
.
sampling_params
.
max_new_tokens
-
fast_forward_tokens_len
,
0
)
self
.
regex_fsm_state
=
next_state
self
.
output_and_fast_forward_str
=
(
self
.
output_and_fast_forward_str
+
old_output_str
+
fast_forward_str
)
# print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
# print("*" * 100)
def
check_finished
(
self
):
if
self
.
finished
:
return
...
...
@@ -263,6 +298,8 @@ class Batch:
req
.
last_node
=
None
req
.
extend_input_len
=
0
req
.
output_ids
=
[]
req
.
regex_fsm_state
=
0
# TODO: apply more fine-grained retraction
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
...
...
@@ -274,6 +311,46 @@ class Batch:
return
retracted_reqs
def
check_for_fast_forward
(
self
):
fast_forward_reqs
=
[]
filter_indices
=
[
i
for
i
in
range
(
len
(
self
.
reqs
))]
req_pool_indices_cpu
=
None
for
i
,
req
in
enumerate
(
self
.
reqs
):
if
req
.
fast_forward_map
is
not
None
:
res
=
req
.
fast_forward_map
.
fast_forward
(
req
.
regex_fsm_state
)
if
res
is
not
None
:
fast_forward_str
,
next_state
=
res
if
len
(
fast_forward_str
)
<=
1
:
continue
# insert the old request into tree_cache
token_ids_in_memory
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)[:
-
1
]
if
req_pool_indices_cpu
is
None
:
req_pool_indices_cpu
=
self
.
req_pool_indices
.
cpu
().
tolist
()
req_pool_idx
=
req_pool_indices_cpu
[
i
]
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids_in_memory
)
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids_in_memory
,
indices
.
clone
()
)
self
.
token_to_kv_pool
.
free
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
# fast forward
req
.
tokenize_fast_forward
(
fast_forward_str
,
next_state
)
fast_forward_reqs
.
append
(
req
)
filter_indices
.
remove
(
i
)
if
len
(
filter_indices
)
<
len
(
self
.
reqs
):
self
.
filter_batch
(
filter_indices
)
return
fast_forward_reqs
def
prepare_for_decode
(
self
,
input_ids
=
None
):
if
input_ids
is
None
:
input_ids
=
[
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
01ee0fbc
...
...
@@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from
sglang.srt.managers.router.scheduler
import
Scheduler
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.constrained.fast_forward
import
FastForwardCache
from
sglang.srt.utils
import
(
get_exception_traceback
,
get_int_token_logit_bias
,
...
...
@@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service):
self
.
tp_rank
=
tp_rank
self
.
tp_size
=
server_args
.
tp_size
self
.
schedule_heuristic
=
server_args
.
schedule_heuristic
self
.
no_regex_fast_forward
=
server_args
.
no_regex_fast_forward
# Init model and tokenizer
self
.
model_config
=
ModelConfig
(
...
...
@@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service):
"trust_remote_code"
:
server_args
.
trust_remote_code
,
},
)
self
.
fast_forward_cache
=
FastForwardCache
()
# Init new token estimation
self
.
new_token_ratio
=
min
(
0.4
*
server_args
.
schedule_conservativeness
,
1.0
)
...
...
@@ -201,6 +204,7 @@ class ModelRpcServer(rpyc.Service):
recv_req
:
TokenizedGenerateReqInput
,
):
req
=
Req
(
recv_req
.
rid
)
req
.
input_text
=
recv_req
.
input_text
req
.
input_ids
=
recv_req
.
input_ids
req
.
pixel_values
=
recv_req
.
pixel_values
req
.
image_size
=
recv_req
.
image_size
...
...
@@ -223,6 +227,10 @@ class ModelRpcServer(rpyc.Service):
# Init regex fsm
if
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm
=
self
.
regex_fsm_cache
.
init_fsm
(
req
.
sampling_params
.
regex
)
if
not
self
.
no_regex_fast_forward
:
req
.
fast_forward_map
=
self
.
fast_forward_cache
.
init_fast_forward_map
(
req
.
sampling_params
.
regex
)
# Truncate long prompts
req
.
input_ids
=
req
.
input_ids
[:
self
.
model_config
.
context_len
-
1
]
...
...
@@ -334,11 +342,6 @@ class ModelRpcServer(rpyc.Service):
self
.
model_config
.
vocab_size
,
self
.
int_token_logit_bias
)
# Reset regex fsm state before first sampling due to retractions
for
req
in
batch
.
reqs
:
if
req
.
sampling_params
.
regex
is
not
None
:
req
.
regex_fsm_state
=
0
if
batch
.
extend_num_tokens
!=
0
:
# Forward
logits
,
(
logprobs
,
normalized_logprobs
)
=
self
.
model_runner
.
forward
(
...
...
@@ -388,6 +391,13 @@ class ModelRpcServer(rpyc.Service):
self
.
min_new_token_ratio
,
)
if
not
self
.
no_regex_fast_forward
:
# check for fast forward
fast_forward_reqs
=
batch
.
check_for_fast_forward
()
self
.
forward_queue
.
extend
(
fast_forward_reqs
)
if
batch
.
is_empty
():
return
# Update batch tensors
self
.
decode_forward_ct
+=
1
batch
.
prepare_for_decode
()
...
...
@@ -408,6 +418,7 @@ class ModelRpcServer(rpyc.Service):
def
handle_finished_requests
(
self
,
batch
:
Batch
):
output_rids
=
[]
output_tokens
=
[]
output_and_fast_forward_strs
=
[]
output_hit_stop_str
=
[]
output_skip_special_tokens
=
[]
output_meta_info
=
[]
...
...
@@ -425,6 +436,7 @@ class ModelRpcServer(rpyc.Service):
):
output_rids
.
append
(
req
.
rid
)
output_tokens
.
append
(
req
.
output_ids
)
output_and_fast_forward_strs
.
append
(
req
.
output_and_fast_forward_str
)
output_hit_stop_str
.
append
(
req
.
hit_stop_str
)
output_skip_special_tokens
.
append
(
req
.
sampling_params
.
skip_special_tokens
...
...
@@ -445,6 +457,7 @@ class ModelRpcServer(rpyc.Service):
BatchTokenIDOut
(
output_rids
,
output_tokens
,
output_and_fast_forward_strs
,
output_hit_stop_str
,
output_skip_special_tokens
,
output_meta_info
,
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
01ee0fbc
...
...
@@ -157,6 +157,7 @@ class TokenizerManager:
)
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
=
rid
,
input_text
=
obj
.
text
,
input_ids
=
input_ids
,
pixel_values
=
pixel_values
,
image_hash
=
image_hash
,
...
...
python/sglang/srt/server_args.py
View file @
01ee0fbc
...
...
@@ -23,6 +23,7 @@ class ServerArgs:
disable_log_stats
:
bool
=
False
log_stats_interval
:
int
=
10
log_level
:
str
=
"info"
no_regex_fast_forward
:
bool
=
False
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
...
...
@@ -150,6 +151,11 @@ class ServerArgs:
default
=
ServerArgs
.
log_stats_interval
,
help
=
"Log stats interval in second."
,
)
parser
.
add_argument
(
"--no-regex-fast-forward"
,
action
=
"store_true"
,
help
=
"Disable regex fast forward"
,
)
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
test/srt/test_fast_forward.py
0 → 100644
View file @
01ee0fbc
import
argparse
from
enum
import
Enum
import
sglang
as
sgl
from
pydantic
import
BaseModel
,
constr
from
sglang.srt.constrained.json_schema
import
build_regex_from_object
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
select_sglang_backend
,
)
# Pattern for one dotted-quad IPv4 address (each octet 0-255).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

# Constrained-decoding regex: a fixed sentence containing two IPv4 addresses,
# then a "www.<word>.<word>" domain.
# Fix: the sentence periods after the second IP and at the very end were
# written as bare `.` (matches ANY character in a regex); they are escaped to
# `\.` so the FSM only accepts a literal period, consistent with the escaped
# periods used elsewhere in this file.
# NOTE(review): "sever" is a typo for "server", but it is part of the forced
# output text (and the recorded expected output below), so it is left as-is.
ip_fast_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r"\. "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"\."
)
# fmt: off
@sgl.function
def regex_gen(s):
    """Ask for the Google DNS IPs; the answer is constrained by ip_fast_forward.

    Greedy decoding (temperature=0) so the run is deterministic; the regex FSM
    forces the answer's fixed text and restricts the IP/domain characters.
    """
    s += "Q: What is the IP address of the Google DNS servers?\n"
    # The constant "A: " prefix is concatenated into the gen expression so the
    # whole constrained answer is produced by a single generation call.
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_fast_forward,
    )
# fmt: on
# Constrained-decoding regex describing a fixed-layout JSON object about
# Hogwarts.  All braces, quotes, and periods are escaped; only the values
# inside the quotes / number positions are free for the model to fill in.
# NOTE(review): the leading spaces inside each key line set the forced JSON
# indentation — confirm they match the recorded expected output below.
json_fast_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r"""    "name": "[\w\d\s]*",\n"""
    + r"""    "country": "[\w\d\s]*",\n"""
    + r"""    "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r"""    "population": [-+]?[0-9]+,\n"""
    + r"""    "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)
# fmt: off
@sgl.function
def json_gen(s):
    """Generate the Hogwarts JSON with decoding constrained by json_fast_forward.

    No prompt is added before the gen call: the regex itself forces the
    leading sentence and the JSON skeleton.
    """
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_fast_forward,
    )
# fmt: on
class Weapon(str, Enum):
    """Closed set of weapon choices for the Character schema.

    The str mixin makes each member usable as a plain string value when the
    schema is turned into a regex.
    """

    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    """Closed set of armor choices for the Character schema."""

    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
class Character(BaseModel):
    """Pydantic schema converted to a regex via build_regex_from_object.

    NOTE(review): field declaration order presumably fixes the key order in
    the generated JSON — verify against build_regex_from_object.
    """

    name: constr(max_length=10)  # bounded length keeps the forced name short
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
@sgl.function
def character_gen(s):
    """Generate a wizard character as JSON constrained by the Character schema.

    The regex is built from the pydantic model, so the output is guaranteed to
    parse into a valid Character.
    """
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )
def main(args):
    """Run the three constrained-decoding demos against the selected backend.

    Each demo runs with greedy decoding and prints its full state text under a
    banner line.
    """
    sgl.set_default_backend(select_sglang_backend(args))

    demos = (
        (regex_gen, "IP TEST"),
        (json_gen, "JSON TEST"),
        (character_gen, "CHARACTER TEST"),
    )
    for program, title in demos:
        state = program.run(temperature=0)
        banner = "=" * 20
        print(banner, title, banner)
        print(state.text())
if __name__ == "__main__":
    # Parse the shared sglang CLI options and run all demos.
    cli_parser = argparse.ArgumentParser()
    main(add_common_sglang_args_and_parse(cli_parser))
# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
# {
# "name": "Hogwarts School of Witchcraft and Wizardry",
# "country": "Scotland",
# "latitude": 55.566667,
# "population": 1000,
# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
test/srt/test_robust.py
View file @
01ee0fbc
...
...
@@ -2,14 +2,13 @@ import argparse
import
random
import
string
import
sglang
as
sgl
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
select_sglang_backend
,
)
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
import
sglang
as
sgl
TOKENIZER
=
None
RANDOM_PREFILL_LEN
=
None
RANDOM_DECODE_LEN
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment