sglang / Commits / ca13f3b8

Commit ca13f3b8 (unverified)

    Disk FSM cache and adjust code. (#63)

Authored Jan 21, 2024 by Liangsheng Yin; committed via GitHub on Jan 20, 2024.
Parent: 0b2efc2a

Showing 8 changed files with 207 additions and 299 deletions (+207, -299):
examples/quick_start/srt_example_regex.py            +4    -1
python/pyproject.toml                                +1    -1
python/sglang/srt/constrained/disk_cache.py         +70    -0
python/sglang/srt/constrained/fsm.py               +102  -129
python/sglang/srt/constrained/fsm_cache.py           +8   -32
python/sglang/srt/constrained/tokenizer.py           +9  -132
python/sglang/srt/managers/router/infer_batch.py     +1    -1
python/sglang/srt/managers/router/model_rpc.py      +12    -3
examples/quick_start/srt_example_regex.py  (+4, -1)

 from sglang import function, gen, set_default_backend, Runtime

+IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+
+
 @function
 def regex_gen(s):
     s += "Q: What is the IP address of the Google DNS servers?\n"
     s += "A: " + gen(
         "answer",
         temperature=0,
-        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+        regex=IP_ADDR_REGEX,
     )
 ...
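For orientation, a minimal driver for the example above would look roughly like the sketch below. The model path is a placeholder and the calls simply mirror the imports the script already uses (Runtime, set_default_backend), so treat this as an illustration rather than part of the commit.

# Hypothetical driver for srt_example_regex.py (not part of this commit).
from sglang import Runtime, set_default_backend

runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")  # placeholder model
set_default_backend(runtime)

state = regex_gen.run()   # executes the @function defined above
print(state["answer"])    # output is constrained by IP_ADDR_REGEX
runtime.shutdown()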
python/pyproject.toml  (+1, -1)

...
@@ -19,7 +19,7 @@ dependencies = [

 [project.optional-dependencies]
-srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "pydantic"]
+srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", "interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
...
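The only packaging change is that the srt extra now pulls in diskcache and cloudpickle, which the new disk cache module below depends on. A quick, hedged sanity check after pip install "sglang[srt]" could be:

# Hypothetical sanity check that the two new srt dependencies resolve.
import cloudpickle
import diskcache

print("diskcache", diskcache.__version__)
print("cloudpickle", cloudpickle.__version__)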
python/sglang/srt/constrained/disk_cache.py  (new file, +70)

# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional

import cloudpickle
from diskcache import Cache

home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True


def hash_arguments(*args, **kwargs) -> str:
    """Create a hash out of the args and kwargs provided"""
    result = hashlib.md5()
    for item in list(args) + sorted(kwargs.items()):
        result.update(cloudpickle.dumps(item))
    return result.hexdigest()


def disk_cache(key_function: Optional[Callable] = None):
    def decorator(cached_function: Callable):
        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        else:
            return wrapper

    return decorator


def disable_cache():
    global _caching_enabled
    _caching_enabled = False


def clear_cache():
    global memory
    memory.clear()
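To illustrate how this module behaves (the decorated function below is hypothetical, not from the commit): arguments are serialized with cloudpickle, hashed with MD5 into a key, and results are stored in a diskcache Cache under $SGLANG_CACHE_DIR (default ~/.cache/sglang), so a repeated call with the same arguments, even in a new process, returns the stored result.

# Hypothetical usage sketch of the disk_cache decorator defined above.
from sglang.srt.constrained.disk_cache import disk_cache


@disk_cache()
def expensive_build(regex_string: str, vocabulary: tuple):
    # Stand-in for an expensive FSM/index construction.
    return {"regex": regex_string, "vocab_size": len(vocabulary)}


# First call computes and persists; later identical calls are served from disk.
result = expensive_build(r"\d+", (("0", 0), ("1", 1)))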
python/sglang/srt/constrained/fsm.py  (+102, -129)

 # Adapted from:
-# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
+# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
-from typing import List, NewType, Protocol
+from typing import List, NewType, Protocol, Tuple

 import interegular
 from lark import Lark

+from sglang.srt.constrained.disk_cache import disk_cache
+
 # from outlines.fsm.parsing import PartialLark
 from sglang.srt.constrained.regex import (
...
@@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int)
 class FSM(Protocol):
-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         ...

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         ...

-    def reset(self) -> None:
+    def copy(self) -> "FSM":
         ...
...
@@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM):
     """

-    def __init__(
-        self,
-        tokenizer: "Tokenizer",
-        stop_token_id: int,
-    ):
+    def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
-        self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.final_states = {1}

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

         When in the initial state we allow every token to be generated.
...
@@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
...
@@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM):
         else:
             return [self.stop_token_id]

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         The FSM stays in the initial state `0` unless the specified stop token
...
@@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.

         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.stop_token_id:
             return FSMState(1)

         return FSMState(0)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "StopAtTokenFSM":
+        """Create a copy of the FSM."""
+        return self


 class RegexFSM(FSM):
...
@@ -117,32 +106,48 @@ class RegexFSM(FSM):
         regex_string: str,
         tokenizer: "Tokenizer",
     ):
-        regex_pattern = interegular.parse_pattern(regex_string)
-        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
-        (
-            self.states_to_token_maps,
-            self.empty_token_ids,
-        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
-
-        # We make sure that it is possible to generate strings in the language
-        # of the regular expression with the tokens present in the model's
-        # vocabulary.
-        if not any(
-            regex_fsm.finals.intersection(v.values())
-            for v in self.states_to_token_maps.values()
-        ):
-            raise ValueError(
-                "The vocabulary does not allow us to build a sequence that matches the input regex"
-            )
-
-        self.final_states = regex_fsm.finals | {
-            -1
-        }  # Include the EOS token in final states
+        @disk_cache()
+        def create_states_mapping(
+            regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
+        ) -> Tuple[dict, set, set]:
+            """Create the variables related to the mapping between states and tokens
+            The parameters of the function are used for caching purpose
+            """
+            regex_pattern = interegular.parse_pattern(regex_string)
+            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+            (
+                states_to_token_maps,
+                empty_token_ids,
+            ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

+            # We make sure that it is possible to generate strings in the language
+            # of the regular expression with the tokens present in the model's
+            # vocabulary.
+            if not any(
+                regex_fsm.finals.intersection(v.values())
+                for v in states_to_token_maps.values()
+            ):
+                raise ValueError(
+                    "The vocabulary does not allow us to build a sequence that matches the input regex"
+                )

+            final_states = regex_fsm.finals | {
+                -1
+            }  # Include the EOS token in final states
+            return states_to_token_maps, empty_token_ids, final_states

+        (
+            self.states_to_token_maps,
+            self.empty_token_ids,
+            self.final_states,
+        ) = create_states_mapping(
+            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
+        )
         self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.end_token_id = tokenizer.eos_token_id

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

         The initialization of the FSM builds an index which maps FSM states to a
...
@@ -159,8 +164,6 @@ class RegexFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
...
@@ -174,7 +177,7 @@ class RegexFSM(FSM):
         else:
             return list(next_tokens_to_end_states.keys())

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         We use the index to determine to which state the FSM should transition
...
@@ -186,17 +189,12 @@ class RegexFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.

         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.end_token_id:
             return FSMState(-1)
...
@@ -207,24 +205,22 @@ class RegexFSM(FSM):
         return FSMState(next_state)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "RegexFSM":
+        """Create a copy of the FSM."""
+        return self


 class CFGFSM(FSM):
     """FSM to generate text that is in the language of a context-free grammar."""

-    def __init__(
-        self,
-        cfg_string: str,
-        tokenizer: "Tokenizer",
-    ):
+    def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
+        self.cfg_string = cfg_string
+        self.tokenizer = tokenizer
+
+        # self.parser = PartialLark(cfg_string, parser="lalr")
         self.parser = Lark(
             cfg_string,
             parser="lalr",
...
@@ -239,59 +235,52 @@ class CFGFSM(FSM):
             self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
         self.terminal_regexps["$END"] = tokenizer.eos_token
-        self.tokenizer = tokenizer
-        self.num_tokens_generated = 0
-        self.generations: List[str] = []
-        self.regex_fsms: List[RegexFSM] = []
-        self.reset_state: List[bool] = []
-        self.allow_eos: List[bool] = []
-        self.done: List[bool] = []
+        self.generation = ""
+        self.reset_state = False
+        self.allow_eos = False
+        self.done = False
+        self.regex_fsm: RegexFSM

-    def _set_next_regex_fsm(self, idx: int = 0) -> None:
+    def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.

-        Check what the CFG incremental parser proposes next.
-        If the only proposal is the EOS token,
-        we set the state to done and return.
-        If there are other proposals,
-        we set a new regex FSM and return.
+        Check what the CFG incremental parser proposes next:
+        - If the only proposal is the EOS token we set the state to done and
+          return.
+        - If there are other proposals, we set a new regex FSM and return.

         """
-        interactive = self.parser.parse_interactive(self.generations[idx])
+        interactive = self.parser.parse_interactive(self.generation)
         interactive.exhaust_lexer()
         options = {self.terminal_regexps[x] for x in interactive.accepts()}
         if self.terminal_regexps["$END"] in options:
             options.remove(self.terminal_regexps["$END"])
             if len(options) == 0:
-                self.done[idx] = True
+                self.done = True
                 return
-            self.allow_eos[idx] = True
+            self.allow_eos = True
             options.add("")
             assert len(options) > 1

         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
-        args = (
-            regex_string,
-            self.tokenizer,
-        )
-        if len(self.regex_fsms) <= idx:
-            self.regex_fsms.append(RegexFSM(*args))
-        else:
-            self.regex_fsms[idx] = RegexFSM(*args)
-        self.reset_state[idx] = True
+        self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
+        self.reset_state = True

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

-        Upon initialization, the CFG incremental parser is used to determine the first regex.
+        Upon initialization, the CFG incremental parser is used to determine the
+        first regex.
         This regex is used for proposals until either:
-        - the regex is exhausted, and its only remaining option is the EOS token,
-        in which case we always transition to the next regex
-        - the regex can be exhausted, but the EOS token is not the only remaining option,
-        in which case we transition to the next regex with probability P (TODO)
-        or remove the possibility of generating the EOS token and continue with the current regex
+        - The regex is exhausted, and its only remaining option is the EOS
+          token, in which case we always transition to the next regex
+        - The regex can be exhausted, but the EOS token is not the only
+          remaining option, in which case we transition to the next regex with
+          probability P (TODO) or remove the possibility of generating the EOS
+          token and continue with the current regex

         The CFG incremental parser is allowed to propose the EOS token from any final state,
         and once it is generated, the FSM will continue to always generate the EOS token.
...
@@ -300,22 +289,14 @@ class CFGFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         A list that contains the tokens to mask.

         """
-        if len(self.generations) <= idx:
-            self.generations.append("")
-            self.reset_state.append(False)
-            self.allow_eos.append(False)
-            self.done.append(False)
-
-        if len(self.regex_fsms) > idx:
-            proposal = self.regex_fsms[idx].allowed_token_ids(state)
+        if self.generation != "":
+            proposal = self.regex_fsm.allowed_token_ids(state)
             if self.tokenizer.eos_token_id not in proposal:
                 return proposal
             if set(proposal) != {self.tokenizer.eos_token_id}:
...
@@ -323,23 +304,23 @@ class CFGFSM(FSM):
                     proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                     return proposal

-        self._set_next_regex_fsm(idx)
+        self._set_next_regex_fsm()

-        if self.done[idx]:
+        if self.done:
             return [self.tokenizer.eos_token_id]

-        if self.reset_state[idx]:
+        if self.reset_state:
             state = FSMState(0)

-        proposal = self.regex_fsms[idx].allowed_token_ids(state)
-        if self.allow_eos[idx]:
-            self.allow_eos[idx] = False
+        proposal = self.regex_fsm.allowed_token_ids(state)
+        if self.allow_eos:
+            self.allow_eos = False
         else:
             proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
             assert len(proposal) > 0

         return proposal

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         Transitions the underlying regex FSM to its next state.
...
@@ -352,34 +333,26 @@ class CFGFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.

         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.tokenizer.eos_token_id:
-            self.done[idx] = True
+            self.done = True
             return FSMState(-1)

-        if self.reset_state[idx]:
-            self.reset_state[idx] = False
+        if self.reset_state:
+            self.reset_state = False
             state = FSMState(0)

-        self.generations[idx] += self.tokenizer.decode([token_id])[0]
+        self.generation += self.tokenizer.decode([token_id])[0]

-        return self.regex_fsms[idx].next_state(state, token_id, idx)
+        return self.regex_fsm.next_state(state, token_id)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Return whether the current state of the FSM is a final state."""
-        return self.done[idx]
+        return self.done

-    def reset(self) -> None:
-        """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
-        self.num_tokens_generated = 0
-        self.generations = []
-        self.regex_fsms = []
-        self.reset_state = []
-        self.done = []
+    def copy(self) -> "CFGFSM":
+        """Create a copy of the FSM."""
+        return CFGFSM(self.cfg_string, self.tokenizer)
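To make the reshaped interface concrete, here is a schematic decode loop written against the FSM protocol as it stands after this change (an illustration, not the server's sampling code): per-request state is now a plain FSMState plus a copy() of the FSM, instead of an idx into shared per-batch lists. sample_fn is a hypothetical callback that picks one token id from the allowed set.

# Schematic constrained-decoding loop over the FSM protocol above.
def constrained_decode(fsm, sample_fn, max_new_tokens: int):
    fsm = fsm.copy()  # per-request FSM instance
    state = 0         # FSMState(0) is the initial state
    output = []
    for _ in range(max_new_tokens):
        allowed = fsm.allowed_token_ids(state)
        token_id = sample_fn(allowed)  # pick one of the allowed token ids
        output.append(token_id)
        state = fsm.next_state(state, token_id)
        if fsm.is_final_state(state):
            break
    return output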
python/sglang/srt/constrained/fsm_cache.py  (+8, -32)

-import threading
-
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer


-def get_fsm(regex, tokenizer, fsm_cache_entry):
-    outlines_tokenizer = TransformerTokenizer(tokenizer)
-    fsm = RegexFSM(regex, outlines_tokenizer)
-    fsm_cache_entry.fsm = fsm
-    fsm_cache_entry.event.set()
-
-
-class FSMCacheEntry:
-    def __init__(self):
-        self.fsm = None
-        self.event = threading.Event()
-
-
 class FSMCache:
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer_path, tokenizer_args_dict):
         self.cache = {}
-        self.tokenizer = tokenizer
+        self.outlines_tokenizer = TransformerTokenizer(
+            tokenizer_path, **tokenizer_args_dict
+        )

-    def init_fsm_in_background(self, regex):
+    def init_fsm(self, regex):
         if regex not in self.cache:
-            self.cache[regex] = FSMCacheEntry()
-            threading.Thread(
-                target=get_fsm,
-                args=(
-                    regex,
-                    self.tokenizer,
-                    self.cache[regex],
-                ),
-            ).start()
-
-    def get_fsm(self, regex):
-        self.init_fsm_in_background(regex)
-        entry = self.cache[regex]
-        entry.event.wait()
-        return entry.fsm
+            fsm = RegexFSM(regex, self.outlines_tokenizer)
+            self.cache[regex] = fsm
+        return self.cache[regex]
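With this commit the cache owns a single outlines-style tokenizer built from a tokenizer path plus an args dict, and init_fsm builds the RegexFSM synchronously on first use and then serves later requests from the in-memory dict (the background-thread path is gone). A hedged usage sketch with placeholder paths and args:

# Hypothetical usage of the reworked FSMCache (path and args are placeholders).
from sglang.srt.constrained.fsm_cache import FSMCache

cache = FSMCache(
    "meta-llama/Llama-2-7b-chat-hf",
    {"trust_remote_code": False},
)

fsm = cache.init_fsm(r"(yes|no)")    # built once, synchronously
again = cache.init_fsm(r"(yes|no)")  # second call should hit the cached FSM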
python/sglang/srt/constrained/tokenizer.py  (+9, -132)

...
@@ -2,17 +2,7 @@
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
 from abc import abstractmethod
-from typing import (
-    TYPE_CHECKING,
-    Dict,
-    Hashable,
-    List,
-    Optional,
-    Protocol,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union

 import numpy as np
 import torch
...
@@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable):
         ...

-if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
-
-__all__ = ["transformers"]
-
-KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
-

 def get_llama_tokenizer_types():
     """Get all the Llama tokenizer types/classes that need work-arounds.
...
@@ -101,76 +82,17 @@ def get_llama_tokenizer_types():
     )


-class Transformer:
-    """Represents a `transformers` model."""
-
-    def __init__(
-        self,
-        model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
-    ):
-        self.device = model.device
-        self.model = model
-        self.tokenizer = tokenizer
-
-    @torch.inference_mode
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
-        """Compute a forward pass through the transformer model.
-
-        Parameters
-        ----------
-        input_ids
-            The input token ids. Must be one or two dimensional.
-        attention_mask
-            The attention mask. Must be one or two dimensional.
-        past_key_values
-            A tuple of tuples containing the cached key and value tensors for each
-            attention head.
-
-        Returns
-        -------
-        The computed logits and the new cached key and value tensors.
-
-        """
-        assert 0 < input_ids.ndim < 3
-
-        if past_key_values:
-            input_ids = input_ids[..., -1].unsqueeze(-1)
-
-        output = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            past_key_values=past_key_values,
-        )
-
-        return output.logits, output.past_key_values
-
-    def __call__(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> torch.FloatTensor:
-        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
-        next_token_logits = logits[..., -1, :]
-
-        return next_token_logits, kv_cache
-
-
 class TransformerTokenizer(Tokenizer):
     """Represents a tokenizer for models in the `transformers` library."""

-    def __init__(self, tokenizer):
+    def __init__(self, model_name: str, **kwargs):
+        from transformers import AutoTokenizer
+
+        kwargs.setdefault("padding_side", "left")
+        self.model_name = model_name
         # TODO: Do something to make this hashable?
-        self.tokenizer = tokenizer
+        self.kwargs = kwargs
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token
...
@@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer):
     def __eq__(self, other):
         if isinstance(other, type(self)):
-            return other.model_name == self.model_name and other.kwargs == self.kwargs
+            return False
+            # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
+            # return other.model_name == self.model_name and other.kwargs == self.kwargs
         return NotImplemented

     def __hash__(self):
         from datasets.fingerprint import Hasher

         return hash(Hasher.hash(self.tokenizer))
-
-
-def transformers(
-    model_name: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    tokenizer_kwargs: dict = {},
-):
-    """Instantiate a model from the `transformers` library and its tokenizer.
-
-    Parameters
-    ----------
-    model_name
-        The name of the model as listed on Hugging Face's model page.
-    device
-        The device(s) on which the model should be loaded. This overrides
-        the `device_map` entry in `model_kwargs` when provided.
-    model_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the model.
-    tokenizer_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the tokenizer.
-
-    Returns
-    -------
-    A `TransformersModel` model instance.
-
-    """
-    try:
-        from transformers import AutoModelForCausalLM
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
-
-    if device is not None:
-        model_kwargs["device_map"] = device
-
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
-    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
-
-    return Transformer(model, tokenizer)
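After this change TransformerTokenizer is constructed from a tokenizer name or path, with keyword arguments forwarded to AutoTokenizer.from_pretrained (padding_side defaults to "left"), rather than wrapping an already-built tokenizer object. A construction sketch with a placeholder model name:

# Hypothetical construction of the reworked TransformerTokenizer.
from sglang.srt.constrained.tokenizer import TransformerTokenizer

tok = TransformerTokenizer(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder tokenizer path
    trust_remote_code=False,
)
print(tok.eos_token, tok.eos_token_id)
print(len(tok.vocabulary))  # token-string to id map consumed by the FSMs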
python/sglang/srt/managers/router/infer_batch.py  (+1, -1)

...
@@ -45,7 +45,7 @@ class Req:

         # for constrained decoding
         self.regex_fsm = None
-        self.regex_fsm_state = None
+        self.regex_fsm_state = 0

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
...
python/sglang/srt/managers/router/model_rpc.py  (+12, -3)

...
@@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service):
         self.stream_interval = server_args.stream_interval

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(self.tokenizer)
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+        )

         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
...
@@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service):
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

+        # Init regex fsm
+        if req.sampling_params.regex is not None:
+            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
...
@@ -322,11 +332,10 @@ class ModelRpcServer(rpyc.Service):
                 self.model_config.vocab_size, self.int_token_logit_bias
             )

-            # init the regex fsm before first sampling
+            # Reset regex fsm state before first sampling due to retractions
             for req in batch.reqs:
                 if req.sampling_params.regex is not None:
                     req.regex_fsm_state = 0
-                    req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

         if batch.extend_num_tokens != 0:
             # Forward
...
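Taken together, the router now builds the FSM when the request arrives (init_fsm) and only resets regex_fsm_state before prefill, since retracted requests keep their FSM object. As a schematic illustration (not the server's actual code) of how these two request fields would gate sampling:

# Schematic sketch of per-request regex gating with the fields above.
import torch


def apply_regex_mask(logits: torch.Tensor, req) -> torch.Tensor:
    # req.regex_fsm / req.regex_fsm_state are the Req fields touched by this commit.
    if req.regex_fsm is None:
        return logits
    allowed = req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed] = 0.0
    return logits + mask


def advance_regex_state(req, next_token_id: int) -> None:
    if req.regex_fsm is not None:
        req.regex_fsm_state = req.regex_fsm.next_state(
            req.regex_fsm_state, next_token_id
        )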