sglang — Commit 37b42297 (unverified)

import outlines (#168)

Authored Feb 09, 2024 by Liangsheng Yin; committed via GitHub on Feb 09, 2024.
Parent: cba50273
Showing 12 changed files with 27 additions and 1464 deletions (+27, −1464).
Changed files:

examples/usage/json_decode.py                   +6   −9
python/pyproject.toml                           +1   −1
python/sglang/srt/constrained/__init__.py       +16  −0
python/sglang/srt/constrained/disk_cache.py     +0   −70
python/sglang/srt/constrained/fsm.py            +0   −358
python/sglang/srt/constrained/fsm_cache.py      +1   −2
python/sglang/srt/constrained/json_schema.py    +0   −290
python/sglang/srt/constrained/jump_forward.py   +1   −2
python/sglang/srt/constrained/regex.py          +0   −586
python/sglang/srt/constrained/tokenizer.py      +0   −143
python/sglang/srt/server.py                     +1   −1
test/srt/test_jump_forward.py                   +1   −2

examples/usage/json_decode.py

@@ -5,10 +5,9 @@ python json_decode.py
"""
from enum import Enum

-from pydantic import BaseModel, constr
import sglang as sgl
-from sglang.srt.constrained.json_schema import build_regex_from_object
+from pydantic import BaseModel
+from sglang.srt.constrained import build_regex_from_object

character_regex = (
    r"""\{\n"""

@@ -30,7 +29,10 @@ character_regex = (
@sgl.function
def character_gen(s, name):
-    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
+    s += (
+        name
+        + " is a character in Harry Potter. Please fill in the following information about this character.\n"
+    )
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)

@@ -65,11 +67,6 @@ def pydantic_wizard_gen(s):
)

def driver_character_gen():
    state = character_gen.run(name="Hermione Granger")
    print(state.text())

def driver_pydantic_wizard_gen():
    state = pydantic_wizard_gen.run()
    print(state.text())

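The example now pulls build_regex_from_object from the sglang.srt.constrained facade and, as the second hunk shows, passes the resulting regex to sgl.gen. A compact sketch of the same pattern follows; the Wand model, prompt, and server URL are placeholders, not part of the commit, and running it requires a launched SRT backend.

import sglang as sgl
from pydantic import BaseModel
from sglang.srt.constrained import build_regex_from_object


class Wand(BaseModel):  # hypothetical schema, for illustration only
    wood: str
    length_inches: int


@sgl.function
def wand_gen(s, owner):
    s += owner + " owns a wand. Describe it as JSON.\n"
    # The regex keeps the generated JSON inside the schema's language.
    s += sgl.gen("json_output", max_tokens=128, regex=build_regex_from_object(Wand))


# sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
# print(wand_gen.run(owner="Hermione Granger").text())
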
python/pyproject.toml

@@ -20,7 +20,7 @@ dependencies = [

[project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
+srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "pydantic", "referencing", "diskcache", "cloudpickle", "pillow", "outlines>=0.0.27"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

python/sglang/srt/constrained/__init__.py (new file)

from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer

__all__ = [
    "RegexFSM",
    "FSMInfo",
    "make_deterministic_fsm",
    "build_regex_from_object",
    "TransformerTokenizer",
    "disk_cache",
    "disable_cache",
]

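This new module turns sglang.srt.constrained into a thin facade over outlines, so the rest of the tree keeps stable import paths while the vendored copies below are deleted. A quick interpreter check of what the facade exposes; it assumes outlines>=0.0.27 is installed, as pinned in pyproject.toml above.

import sglang.srt.constrained as constrained

# The names are re-exports, not copies: they resolve into the outlines package.
print(constrained.RegexFSM.__module__)                  # outlines.fsm.fsm
print(constrained.build_regex_from_object.__module__)   # outlines.fsm.json_schema
print(sorted(constrained.__all__))
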
python/sglang/srt/constrained/disk_cache.py (deleted)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional

import cloudpickle
from diskcache import Cache

home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True


def hash_arguments(*args, **kwargs) -> str:
    """Create a hash out of the args and kwargs provided"""
    result = hashlib.md5()
    for item in list(args) + sorted(kwargs.items()):
        result.update(cloudpickle.dumps(item))
    return result.hexdigest()


def disk_cache(key_function: Optional[Callable] = None):
    def decorator(cached_function: Callable):
        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        else:
            return wrapper

    return decorator


def disable_cache():
    global _caching_enabled
    _caching_enabled = False


def clear_cache():
    global memory
    memory.clear()

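The removed decorator memoized results in a diskcache.Cache under ~/.cache/sglang, keyed by an MD5 hash of the cloudpickled arguments; outlines.caching.cache now plays this role. A minimal sketch of how the old decorator was used (the decorated function is illustrative, not taken from the repository):

from sglang.srt.constrained.disk_cache import disk_cache  # pre-commit import path


@disk_cache()
def compile_regex_index(regex_string: str) -> dict:
    # Expensive work (e.g. building an FSM token index) runs only on a cache miss.
    return {"regex": regex_string}


compile_regex_index(r"[0-9]+")  # computed and written to ~/.cache/sglang
compile_regex_index(r"[0-9]+")  # served from the on-disk cache
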
python/sglang/srt/constrained/fsm.py (deleted; summary of the 358 removed lines)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
from typing import List, NewType, Protocol, Tuple

import interegular
from lark import Lark

from sglang.srt.constrained.disk_cache import disk_cache
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
    create_fsm_index_tokenizer,
    make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer

FSMState = NewType("FSMState", int)


class FSM(Protocol):
    def allowed_token_ids(self, state: FSMState) -> List[int]: ...

    def next_state(self, state: FSMState, token_id: int) -> FSMState: ...

    def is_final_state(self, state: FSMState) -> bool: ...

    def copy(self) -> "FSM": ...


class StopAtTokenFSM(FSM):
    """FSM to generate text until a specified token id is generated or
    a specified number of tokens has been generated."""

    # Allows every vocabulary token while in state 0 and only `stop_token_id`
    # once the stop token has been generated (final state 1).


class RegexFSM(FSM):
    """FSM to generate text that is in the language of a regular expression."""

    # __init__(regex_string, tokenizer) parses the regex with interegular,
    # makes it deterministic with make_deterministic_fsm, and builds a
    # disk_cache-d states_to_token_maps index with create_fsm_index_tokenizer,
    # raising ValueError if the vocabulary cannot realize the regex.
    # allowed_token_ids(state) returns the keys of that map, or only the EOS
    # token once no transitions remain; next_state(state, token_id) follows the
    # map and falls back to the sink state -1; is_final_state checks
    # regex_fsm.finals | {-1}.


class CFGFSM(FSM):
    """FSM to generate text that is in the language of a context-free grammar."""

    # Wraps a Lark LALR parser with a contextual lexer. _set_next_regex_fsm()
    # exhausts the incremental parser on the text generated so far, collects
    # the regexes of the terminals it accepts next, and compiles their union
    # into a fresh RegexFSM; EOS is proposed only when "$END" is acceptable.
    # allowed_token_ids/next_state delegate to the current RegexFSM, and the
    # stored partial generation is extended with each decoded token.

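The FSM protocol is what the scheduler drives during constrained decoding: ask for the allowed tokens in the current state, let the model sample one of them, then advance the state. A self-contained toy illustration of that loop; DigitsFSM is invented for the example and is not part of sglang or outlines.

from typing import List


class DigitsFSM:
    """Toy FSM: any number of digit tokens, then an EOS token (id 0)."""

    def __init__(self, eos_token_id: int = 0):
        self.eos_token_id = eos_token_id
        self.digit_token_ids = list(range(1, 11))  # pretend ids 1..10 encode "0".."9"

    def allowed_token_ids(self, state: int) -> List[int]:
        if state == -1:  # final state: only EOS may repeat
            return [self.eos_token_id]
        return self.digit_token_ids + [self.eos_token_id]

    def next_state(self, state: int, token_id: int) -> int:
        return -1 if token_id == self.eos_token_id else 0

    def is_final_state(self, state: int) -> bool:
        return state == -1


fsm, state = DigitsFSM(), 0
for token_id in [3, 7, 0]:  # stand-in for tokens the model samples
    assert token_id in fsm.allowed_token_ids(state)
    state = fsm.next_state(state, token_id)
print(fsm.is_final_state(state))  # True once EOS has been generated
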
python/sglang/srt/constrained/fsm_cache.py

+from sglang.srt.constrained import RegexFSM, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.fsm import RegexFSM
-from sglang.srt.constrained.tokenizer import TransformerTokenizer


class FSMCache(BaseCache):

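FSMCache exists because compiling a RegexFSM (regex parsing plus a full vocabulary scan) is far more expensive than a dictionary lookup, so identical requests should reuse one compiled object. A rough, self-contained sketch of the idea; the real class builds on BaseCache, whose exact signature is not shown in this diff.

class SimpleFSMCache:
    """Illustrative only: memoize compiled FSMs per regex string."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # e.g. lambda r: RegexFSM(r, tokenizer)
        self._entries = {}

    def query(self, regex_string: str):
        if regex_string not in self._entries:
            self._entries[regex_string] = self._compile_fn(regex_string)
        return self._entries[regex_string]


cache = SimpleFSMCache(compile_fn=lambda r: ("compiled", r))
assert cache.query(r"[0-9]+") is cache.query(r"[0-9]+")  # compiled once, reused
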
python/sglang/srt/constrained/json_schema.py (deleted; summary of the 290 removed lines)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012

STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}


def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Turn a JSON schema into a regex that matches any JSON object that follows
    this schema."""
    # Accepts a Pydantic model class, a callable (whose signature is converted
    # to a schema via get_schema_from_signature), or a JSON schema string;
    # validates the schema, builds a `referencing` resolver against
    # DRAFT202012, and delegates to to_regex().


def to_regex(resolver: Resolver, instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema."""
    # Handles "properties" (required vs. optional fields and comma placement),
    # "allOf", "anyOf", "oneOf", "enum", "$ref", and the primitive "type"
    # keyword (string with minLength/maxLength/pattern, number, integer,
    # array with minItems/maxItems, boolean, null, and lists of types);
    # raises NotImplementedError for anything it cannot translate.
    # additionalProperties, number constraints, special string formats, and
    # recursive definitions are not supported.


def get_schema_from_signature(fn: Callable) -> str:
    """Turn a function signature into a JSON schema."""
    # Requires a type annotation on every parameter and builds a pydantic
    # model with create_model("Arguments", **arguments).

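Those primitive patterns are the leaves from which the module assembled one large regex per schema. A quick self-contained check of what they accept; the sample values and the toy object regex are chosen purely for illustration.

import re

INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"

assert re.fullmatch(INTEGER, "42")
assert not re.fullmatch(INTEGER, "042")  # no leading zeros
assert re.fullmatch(NUMBER, "-3.5e+2")
assert re.fullmatch(BOOLEAN, "false")

# A toy object regex in the same spirit: {"age": <integer>}
obj = r'\{[\n ]*"age"[\n ]*:[\n ]*' + INTEGER + r"[\n ]*\}"
assert re.fullmatch(obj, '{"age": 17}')
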
python/sglang/srt/constrained/jump_forward.py

import interegular

+from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.disk_cache import disk_cache
-from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

python/sglang/srt/constrained/regex.py (deleted; summary of the 586 removed lines)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple

import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional

from sglang.srt.constrained.tokenizer import Tokenizer

FSMInfo = namedtuple(
    "FSMInfo",
    [
        "initial",
        "finals",
        "transitions",
        "trans_key_to_states",
        "alphabet_anything_value",
        "alphabet_symbol_mapping",
    ],
)

# The module contained numba-accelerated machinery for turning an interegular
# FSM plus a tokenizer vocabulary into a token-level transition index:
#
#   class BetterAlphabet(Alphabet)    - maps unknown symbols to the "anything" value.
#   class BetterFSM(FSM)              - keeps a flat (state, trans_key) -> state map and a
#                                       lazily built, numba-typed FSMInfo.
#   create_fsm_info(...)              - @numba.njit constructor for FSMInfo's typed dicts.
#   make_deterministic_fsm(fsm)       - relabels states and transition keys deterministically,
#                                       returning (BetterFSM, old_to_new_states).
#   _walk_fsm(...) / walk_fsm(...)    - walk an FSM over a string and return the visited
#                                       states (njit and pure-Python variants).
#   fsm_union(fsms)                   - union of several FSMs that also reports, per
#                                       component FSM, its transitions and finals in the union.
#   get_sub_fsms_from_seq(...)        - which sub-FSMs could have produced a state sequence,
#                                       and whether each can continue or is in a final state.
#   state_scan_tokens(...)            - @numba.njit scan of the whole vocabulary from one
#                                       start state, yielding (token_id, end_state) pairs.
#   create_fsm_index_end_to_end(...)  - BFS over reachable states building
#                                       state -> {(token_id, end_state)}.
#   reduced_vocabulary(tokenizer)     - lru_cache'd map from decoded token strings to the
#                                       token ids that produce them (skipping special tokens).
#   create_fsm_index_tokenizer(fsm, tokenizer)
#                                     - top-level entry point returning
#                                       (states_to_token_maps, empty_token_ids), with EOS
#                                       transitions added at every reachable final state.

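The end product of this module is a per-state map from token id to next state, which RegexFSM consults at every decoding step. A hand-built miniature of that index, just to show its shape; the states, token ids, and strings are made up for the example.

# states_to_token_maps for the regex "ab", with a toy vocabulary:
# token 1 = "a", token 2 = "b", token 3 = "ab" (a multi-character token).
states_to_token_maps = {
    0: {1: 1, 3: 2},  # from the start, "a" reaches state 1 and "ab" jumps to state 2
    1: {2: 2},        # after "a", only "b" is allowed
}
finals = {2}

state = 0
for token_id in [1, 2]:  # decode "a" then "b"
    assert token_id in states_to_token_maps.get(state, {})
    state = states_to_token_maps[state][token_id]
print(state in finals)  # True: the generated text matches "ab"
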
python/sglang/srt/constrained/tokenizer.py (deleted; summary of the 143 removed lines)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union

import numpy as np
import torch
from numpy.typing import NDArray


class Tokenizer(Protocol, Hashable):
    eos_token: str
    eos_token_id: int
    pad_token_id: int
    vocabulary: Dict[str, int]
    special_tokens: Set[int]

    @abstractmethod
    def encode(self, prompt: Union[str, List[str]]) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids to a string or list of strings."""
        ...

    @abstractmethod
    def convert_token_to_string(self, token: str) -> str:
        """Convert a token to its equivalent string.

        This is for instance useful for BPE tokenizers where whitespaces are
        represented by the special character `Ġ`. This prevents matching a raw
        token that includes `Ġ` with a string.
        """
        ...


def get_llama_tokenizer_types():
    """Get all the Llama tokenizer types/classes that need work-arounds.

    When they can't be imported, a dummy class is created.
    """
    # Tries to import LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer and
    # CodeLlamaTokenizerFast from transformers, substituting an empty placeholder
    # class for any that is unavailable, and returns the 4-tuple.


class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    # __init__(model_name, **kwargs) loads AutoTokenizer.from_pretrained with
    # padding_side defaulting to "left", falls back to eos_token_id when no
    # pad_token_id exists, and records the vocabulary and special tokens.
    # encode() returns padded (input_ids, attention_mask) tensors;
    # decode() batch-decodes while skipping special tokens;
    # convert_token_to_string() re-inserts the leading space that HF Llama
    # tokenizers drop for tokens starting with SPIECE_UNDERLINE or "<0x20>".
    # __eq__ compares model_name and kwargs; __hash__ uses datasets' Hasher.

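The Tokenizer protocol is the only contract the deleted FSM code relied on, which is why outlines' TransformerTokenizer can slot in unchanged. A dependency-free toy implementation of the fields the vocabulary indexer actually reads; everything here is invented for illustration.

from typing import Dict, List, Set


class WordTokenizer:
    """Toy tokenizer exposing what the FSM index builder needs."""

    def __init__(self, words: List[str]):
        self.vocabulary: Dict[str, int] = {w: i + 1 for i, w in enumerate(words)}
        self.vocabulary["</s>"] = 0
        self.eos_token = "</s>"
        self.eos_token_id = 0
        self.special_tokens: Set[str] = {"</s>"}

    def convert_token_to_string(self, token: str) -> str:
        return token  # no BPE whitespace markers to undo in this toy


tok = WordTokenizer(["25", ".", "192"])
print(tok.vocabulary)  # {'25': 1, '.': 2, '192': 3, '</s>': 0}
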
python/sglang/srt/server.py

@@ -21,7 +21,7 @@ from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained.disk_cache import disable_cache
+from sglang.srt.constrained import disable_cache
from sglang.srt.conversation import (
    Conversation,
    SeparatorStyle,

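With the re-export in place, the server reaches outlines' cache switch through sglang.srt.constrained. A hedged sketch of the call-site pattern; only the import path comes from this diff, and the flag name below is illustrative.

from sglang.srt.constrained import disable_cache


def maybe_disable_constrained_cache(disable_disk_cache: bool) -> None:
    # Skip the on-disk FSM/regex cache, e.g. on read-only filesystems.
    if disable_disk_cache:
        disable_cache()


maybe_disable_constrained_cache(disable_disk_cache=True)
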
test/srt/test_jump_forward.py

@@ -2,12 +2,11 @@ import argparse
from enum import Enum

from pydantic import BaseModel, constr
-from sglang.srt.constrained.json_schema import build_regex_from_object
+from sglang.srt.constrained import build_regex_from_object
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

import sglang as sgl

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"