sglang / Commits / 9c902b19

Commit 9c902b19 (unverified), authored Jun 12, 2024 by Liangsheng Yin, committed by GitHub on Jun 11, 2024

Decode Incrementally (#517)

parent 111991fe

Showing 8 changed files with 345 additions and 135 deletions (+345, -135)
  examples/usage/chinese_regex.py                        +53   -0
  python/sglang/srt/constrained/__init__.py               +4   -3
  python/sglang/srt/constrained/fsm_cache.py               +2   -2
  python/sglang/srt/constrained/jump_forward.py          +106  -25
  python/sglang/srt/managers/controller/infer_batch.py   +136  -67
  python/sglang/srt/managers/controller/tp_worker.py      +28  -18
  python/sglang/srt/managers/detokenizer_manager.py       +12  -18
  python/sglang/srt/managers/io_struct.py                  +4   -2
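
Editor's note (not part of the commit page): the change replaces per-request full re-decoding (the old prev_output_str plus a partial decode of the newest tokens) with incremental detokenization. Each request keeps a small "surrounding" token window (surr_offset) and a read offset, decodes only that window, and emits just the newly completed text, holding back output while a multi-byte UTF-8 character is still incomplete. A minimal, self-contained sketch of that idea follows; the toy tokenizer, token table, and variable names are illustrative assumptions, not code from the diff.

# Sketch only: a toy byte-level "tokenizer" whose tokens can split a multi-byte
# UTF-8 character, showing why infer_batch.py keeps a surrounding window and
# checks for "�" before emitting new text.
TOKEN_BYTES = {
    0: "Hel".encode(),
    1: "lo ".encode(),
    2: "霍".encode()[:1],   # first byte of a 3-byte character
    3: "霍".encode()[1:],   # its two continuation bytes
    4: "格".encode(),
}

def decode(ids):
    # errors="replace" mimics a real tokenizer emitting "�" for a partial character
    return b"".join(TOKEN_BYTES[i] for i in ids).decode("utf-8", errors="replace")

stream = [0, 1, 2, 3, 4]
decoded_text, surr_offset, read_offset = "", 0, 0
for n in range(1, len(stream) + 1):          # one new token arrives per step
    all_ids = stream[:n]
    surr_text = decode(all_ids[surr_offset:read_offset])
    new_text = decode(all_ids[surr_offset:])
    if len(new_text) > len(surr_text) and not new_text.endswith("�"):
        decoded_text += new_text[len(surr_text):]   # emit only the delta
        surr_offset = read_offset                   # slide the window forward
        read_offset = n
    # otherwise: hold both offsets until the partial character completes

print(decoded_text)   # "Hello 霍格", with no replacement characters ever emitted

In the real code the initial surrounding window is read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET (5 tokens), which this sketch omits because it starts from an empty sequence.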
examples/usage/chinese_regex.py  (new file, mode 0 → 100644)

import sglang as sgl

character_regex = (
    r"""\{\n"""
    + r"""    "姓名": "[^"]{1,32}",\n"""
    + r"""    "学院": "(格兰芬多|赫奇帕奇|拉文克劳|斯莱特林)",\n"""
    + r"""    "血型": "(纯血|混血|麻瓜)",\n"""
    + r"""    "职业": "(学生|教师|傲罗|魔法部|食死徒|凤凰社成员)",\n"""
    + r"""    "魔杖": \{\n"""
    + r"""        "材质": "[^"]{1,32}",\n"""
    + r"""        "杖芯": "[^"]{1,32}",\n"""
    + r"""        "长度": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r"""    \},\n"""
    + r"""    "存活": "(存活|死亡)",\n"""
    + r"""    "守护神": "[^"]{1,32}",\n"""
    + r"""    "博格特": "[^"]{1,32}"\n"""
    + r"""\}"""
)


@sgl.function
def character_gen(s, name):
    s += name + " 是一名哈利波特系列小说中的角色。请填写以下关于这个角色的信息。"
    s += """\
这是一个例子
{
    "姓名": "哈利波特",
    "学院": "格兰芬多",
    "血型": "混血",
    "职业": "学生",
    "魔杖": {
        "材质": "冬青木",
        "杖芯": "凤凰尾羽",
        "长度": 11.0
    },
    "存活": "存活",
    "守护神": "麋鹿",
    "博格特": "摄魂怪"
}
"""
    s += f"现在请你填写{name}的信息:\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)


def main():
    backend = sgl.RuntimeEndpoint("http://localhost:30000")
    sgl.set_default_backend(backend)

    ret = character_gen.run(name="赫敏格兰杰", temperature=0)
    print(ret.text())


if __name__ == "__main__":
    main()
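
As a reading aid (not part of the commit): character_regex constrains the model to emit a Chinese JSON profile of a Harry Potter character. The keys are 姓名 (name), 学院 (house), 血型 (blood status), 职业 (occupation), 魔杖 (wand: 材质 wood, 杖芯 core, 长度 length), 存活 (alive/dead), 守护神 (Patronus), and 博格特 (boggart); the prompt asks the model to fill in this information for the named character, here 赫敏格兰杰 (Hermione Granger). A quick way to sanity-check a completion against the same pattern is Python's re module. The sample below is hand-written for illustration, not real model output, and the snippet assumes it is run from examples/usage/ so the module above is importable.

# Sketch only: validates a hand-written sample against the same pattern that is
# passed to sgl.gen(..., regex=character_regex) above.
import re

from chinese_regex import character_regex  # the pattern defined in the file above

sample = (
    '{\n'
    '    "姓名": "赫敏格兰杰",\n'
    '    "学院": "格兰芬多",\n'
    '    "血型": "麻瓜",\n'
    '    "职业": "学生",\n'
    '    "魔杖": {\n'
    '        "材质": "葡萄藤木",\n'
    '        "杖芯": "龙心弦",\n'
    '        "长度": 10.75\n'
    '    },\n'
    '    "存活": "存活",\n'
    '    "守护神": "水獭",\n'
    '    "博格特": "不及格的考试成绩"\n'
    '}'
)

# fullmatch succeeds only if the whole string follows the constrained format
assert re.fullmatch(character_regex, sample) is not None
print("sample conforms to character_regex")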

python/sglang/srt/constrained/__init__.py

@@ -3,8 +3,8 @@ from typing import Dict, Optional, Union
 from outlines.caching import cache as disk_cache
 from outlines.caching import disable_cache
-from outlines.fsm.fsm import RegexFSM
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
+from outlines.fsm.guide import RegexGuide
+from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel

@@ -28,11 +28,12 @@ except ImportError:
 __all__ = [
-    "RegexFSM",
+    "RegexGuide",
     "FSMInfo",
     "make_deterministic_fsm",
     "build_regex_from_object",
     "TransformerTokenizer",
     "disk_cache",
     "disable_cache",
+    "make_byte_level_fsm",
 ]

python/sglang/srt/constrained/fsm_cache.py

"""Cache for the compressed finite state machine."""
from
sglang.srt.constrained
import
Regex
FSM
,
TransformerTokenizer
from
sglang.srt.constrained
import
Regex
Guide
,
TransformerTokenizer
from
sglang.srt.constrained.base_cache
import
BaseCache
...
...
@@ -26,4 +26,4 @@ class FSMCache(BaseCache):
)
def
init_value
(
self
,
regex
):
return
Regex
FSM
(
regex
,
self
.
outlines_tokenizer
)
return
Regex
Guide
(
regex
,
self
.
outlines_tokenizer
)

python/sglang/srt/constrained/jump_forward.py

@@ -2,20 +2,41 @@
 Faster constrained decoding.
 Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """
-import interegular
-from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
+import interegular
+import dataclasses
+from collections import defaultdict
+import outlines.caching
+
+from sglang.srt.constrained import (
+    FSMInfo,
+    disk_cache,
+    make_deterministic_fsm,
+    make_byte_level_fsm,
+)
 from sglang.srt.constrained.base_cache import BaseCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
 class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
             regex_pattern = interegular.parse_pattern(regex_string)
-            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+
+            byte_fsm = make_byte_level_fsm(
+                regex_pattern.to_fsm().reduce(), keep_utf8=True
+            )
+            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+
             fsm_info: FSMInfo = regex_fsm.fsm_info

@@ -25,40 +46,91 @@ class JumpForwardMap:
                 id_to_symbol.setdefault(id_, []).append(symbol)

             transitions = fsm_info.transitions
-            dirty_states = set()
+
+            outgoings_ct = defaultdict(int)
             state_to_jump_forward = {}

             for (state, id_), next_state in transitions.items():
-                if state in dirty_states:
-                    continue
-                if state in state_to_jump_forward:
-                    dirty_states.add(state)
-                    del state_to_jump_forward[state]
-                    continue
-                if len(id_to_symbol[id_]) > 1:
-                    dirty_states.add(state)
+                if id_ == fsm_info.alphabet_anything_value:
                     continue
-                state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    if len(c) > 1:
+                        # Skip byte level transitions
+                        continue
+                    outgoings_ct[state] += 1
+                    if outgoings_ct[state] > 1:
+                        if state in state_to_jump_forward:
+                            del state_to_jump_forward[state]
+                        break
+                    state_to_jump_forward[state] = JumpEdge(
+                        symbol=c,
+                        symbol_next_state=next_state,
+                    )
+
+            # Process the byte level jump forward
+            outgoings_ct = defaultdict(int)
+            for (state, id_), next_state in transitions.items():
+                if id_ == fsm_info.alphabet_anything_value:
+                    continue
+                symbols = id_to_symbol[id_]
+                for c in symbols:
+                    byte_ = None
+                    if len(c) == 1 and ord(c) < 0x80:
+                        # ASCII character
+                        byte_ = ord(c)
+                    elif len(c) == 2:
+                        byte_ = int(symbols[0], 16)
+                    if byte_ is not None:
+                        outgoings_ct[state] += 1
+                        if outgoings_ct[state] > 1:
+                            if state in state_to_jump_forward:
+                                del state_to_jump_forward[state]
+                            break
+                        e = state_to_jump_forward.get(state, JumpEdge())
+                        e.byte = byte_
+                        e.byte_next_state = next_state
+                        state_to_jump_forward[state] = e

             return state_to_jump_forward

         self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)

     def valid_states(self):
         return self.state_to_jump_forward.keys()

+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
+
+        return jump_forward_str, next_state
+
-    def jump_forward(self, state):
-        jump_forward_str = ""
+    def jump_forward_byte(self, state):
+        if state not in self.state_to_jump_forward:
+            return None
+
+        jump_forward_bytes = []
         next_state = None
         while state in self.state_to_jump_forward:
-            symbol, next_state = self.state_to_jump_forward[state]
-            jump_forward_str += symbol
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
             state = next_state
-        return jump_forward_str, next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )


 class JumpForwardCache(BaseCache):

@@ -69,12 +141,21 @@ class JumpForwardCache(BaseCache):
         return JumpForwardMap(regex)


-def test_main():
-    regex_string = r"The google's DNS sever address is " + IP_REGEX
+def test_main(regex_string):
     jump_forward_map = JumpForwardMap(regex_string)
-    for state in jump_forward_map.valid_states():
-        print(state, f'"{jump_forward_map.jump_forward(state)}"')
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])


 if __name__ == "__main__":
-    test_main()
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
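
For intuition (not part of the diff): a jump-forward map only keeps states with exactly one outgoing symbol edge, so decoding can skip ahead through characters the grammar forces anyway. The snippet below applies the same filtering idea to a hand-written transition table, without the outlines FSM machinery; the toy DFA and names are assumptions for illustration.

# Sketch only: a toy DFA for the literal prefix '"名":' followed by a digit,
# showing how single-outgoing-edge states can be "jumped" without the model.
toy_transitions = {
    (0, '"'): 1,
    (1, "名"): 2,
    (2, '"'): 3,
    (3, ":"): 4,
    (4, "0"): 5, (4, "1"): 5, (4, "2"): 5,   # several choices: no jump from state 4
}

def build_jump_map(transitions):
    out, fan_out = {}, {}
    for (state, ch), nxt in transitions.items():
        fan_out[state] = fan_out.get(state, 0) + 1
        out[state] = (ch, nxt)
    # Keep only states with exactly one outgoing edge
    return {s: e for s, e in out.items() if fan_out[s] == 1}

def jump_forward(jump_map, state):
    s = ""
    while state in jump_map:
        ch, state = jump_map[state]
        s += ch
    return s, state

jump_map = build_jump_map(toy_transitions)
print(jump_forward(jump_map, 0))   # ('"名":', 4): four forced characters emitted in one step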

python/sglang/srt/managers/controller/infer_batch.py

@@ -3,12 +3,17 @@
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
+import warnings

 import numpy as np
 import torch

 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained import RegexGuide
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


 class ForwardMode(IntEnum):

@@ -64,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
-        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None

         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.

@@ -109,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0

         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None

     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None

-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens

@@ -131,18 +173,17 @@ class Req:
         if self.finished():
             return

-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(
-                len(self.prev_output_ids) + len(self.output_ids)
-            )
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return

         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return

         if len(self.sampling_params.stop_strs) > 0:

@@ -151,61 +192,59 @@ class Req:
             )
             for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
+                if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )

-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrouding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break

         self.regex_fsm_state = next_state

         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+            self.last_update_decode_tokens = len(self.output_ids) - k

         # print("=" * 100)
         # print(f"Catch jump forward:\n{jump_forward_str}")
         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)

+        return True
+
     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "

@@ -381,7 +420,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]

         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )

@@ -403,14 +445,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)

-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []

             # For incremental logprobs
             req.last_update_decode_tokens = 0

@@ -428,18 +465,53 @@ class Batch:
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
-                        continue
-
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
+                        continue
+
+                    jump_forward_str, next_state = (
+                        req.jump_forward_map.jump_forward_symbol(cur_state)
+                    )
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue

                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )

@@ -447,9 +519,6 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)

-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                     # re-applying image padding
                     if req.pixel_values is not None:
                         (

@@ -583,7 +652,7 @@ class Batch:
                 if req.regex_fsm is not None:
                     allowed_mask.zero_()
-                    allowed_mask[req.regex_fsm.allowed_token_ids(req.regex_fsm_state)] = 1
+                    allowed_mask[
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
                     logits[i].masked_fill_(~allowed_mask, float("-inf"))

@@ -602,7 +671,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
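
A note on the byte-level fallback in check_for_jump_forward above (illustration, not commit code): when the symbol-level jump-forward string is at most one character, the code falls back to byte edges. Any leading UTF-8 continuation bytes (0x80 to 0xBF) belong to a character whose first byte was already sampled, so they are converted to <0xXX> byte tokens and appended to output_ids before re-tokenizing. The helper name and the edge list below are assumptions made for this sketch.

# Sketch only: the continuation-byte splitting step, on a hand-written edge list.
def split_leading_continuation(jump_forward_bytes):
    """Pop leading continuation bytes (0x80-0xBF); they complete a character whose
    first byte was already emitted, so they become <0xXX> byte tokens."""
    suffix_bytes = []
    while jump_forward_bytes and 0x80 <= jump_forward_bytes[0][0] < 0xC0:
        byte_, _next_state = jump_forward_bytes.pop(0)
        suffix_bytes.append(byte_)
    # Same formatting as the diff's f"<0x{hex(b)[2:].upper()}>"
    suffix_tokens = [f"<0x{b:02X}>" for b in suffix_bytes]
    return suffix_tokens, jump_forward_bytes

# "霍" is E9 9C 8D and "格" is E6 A0 BC (see the comments in jump_forward.py).
# Pretend the model already emitted the E9 byte, so the map's byte edges start
# with the two continuation bytes, then the next character "格".
edges = [(0x9C, 11), (0x8D, 12), (0xE6, 13), (0xA0, 14), (0xBC, 15)]
print(split_leading_continuation(edges))
# (['<0x9C>', '<0x8D>'], [(230, 13), (160, 14), (188, 15)])  i.e. 0xE6, 0xA0, 0xBC remain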

python/sglang/srt/managers/controller/tp_worker.py

@@ -21,7 +21,13 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+from sglang.srt.managers.controller.infer_batch import (
+    BaseFinishReason,
+    Batch,
+    FINISH_ABORT,
+    ForwardMode,
+    Req,
+)
 from sglang.srt.managers.controller.model_runner import ModelRunner
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic

@@ -98,8 +104,11 @@ class ModelTpServer:
                 else server_args.max_prefill_tokens
             ),
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2 if server_args.max_running_requests is None else server_args.max_running_requests)
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )

@@ -314,10 +323,7 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.forward_queue:
-            assert (
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]

@@ -464,7 +470,7 @@ class ModelTpServer:
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:

@@ -524,7 +530,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,

@@ -596,8 +602,9 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-        prev_output_strs = []
-        output_tokens = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []

@@ -620,8 +627,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                prev_output_strs.append(req.prev_output_str)
-                output_tokens.append(req.output_ids)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )

@@ -631,7 +640,7 @@ class ModelTpServer:
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }

@@ -657,8 +666,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-                    prev_output_strs,
-                    output_tokens,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,

@@ -673,7 +683,7 @@ class ModelTpServer:
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
-                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
                )

@@ -790,4 +800,4 @@ class ModelTpClient:
             return _func

-        self.step = async_wrap("step")
\ No newline at end of file
+        self.step = async_wrap("step")

python/sglang/srt/managers/detokenizer_manager.py

@@ -39,30 +39,24 @@ class DetokenizerManager:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)

-            output_tokens = recv_obj.output_tokens
-
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-            output_strs = self.tokenizer.batch_decode(
-                output_tokens,
-                skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
-            )
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )

             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
-            for i in range(len(output_strs)):
-                if len(output_tokens[i]) > 0:
-                    first_token = self.tokenizer.convert_ids_to_tokens(
-                        int(output_tokens[i][0])
-                    )
-                    if not isinstance(first_token, str):
-                        first_token = first_token.decode("utf-8", errors="ignore")
-                    if first_token.startswith("▁"):
-                        output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
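
As a reading aid (not from the commit): the detokenizer now receives, per request, the text decoded so far plus the surrounding and read token windows, and recovers only the new text by stripping the decoded surrounding prefix. A stand-alone sketch of that trimming step, with a stub batch_decode in place of self.tokenizer.batch_decode and made-up fragments:

# Sketch only: mirrors the surr/read trimming above; batch_decode is a stub.
def batch_decode(list_of_token_lists):
    # Stub: pretend every "token" is already a text fragment.
    return ["".join(tokens) for tokens in list_of_token_lists]

decoded_texts = ["Hello wor"]            # everything decoded so far (includes the window)
surr_output_ids = [["wor"]]              # surrounding window (already emitted)
read_output_ids = [["wor", "ld", "!"]]   # window plus newly generated tokens

surr_texts = batch_decode(surr_output_ids)
read_texts = batch_decode(read_output_ids)

output_strs = []
for i in range(len(decoded_texts)):
    new_text = read_texts[i][len(surr_texts[i]):]   # strip the already-seen prefix
    output_strs.append(decoded_texts[i] + new_text)

print(output_strs)   # ['Hello world!']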

python/sglang/srt/managers/io_struct.py

@@ -111,13 +111,15 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    prev_output_strs: List[str]
-    output_tokens: List[List[int]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]


 @dataclass
 class BatchStrOut:
     rids: List[str]
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment