sglang / Commits / a9ef49c1

Detokenize incrementally when streaming (#653)

Unverified commit a9ef49c1, authored Jul 18, 2024 by Liangsheng Yin and committed via GitHub on Jul 18, 2024. Parent: 21ba3a88.

Showing 5 changed files with 101 additions and 33 deletions (+101 / -33):

- python/sglang/srt/layers/radix_attention.py (+30 / -4)
- python/sglang/srt/managers/controller/infer_batch.py (+15 / -15)
- python/sglang/srt/managers/controller/tp_worker.py (+7 / -7)
- python/sglang/srt/managers/detokenizer_manager.py (+47 / -5)
- python/sglang/srt/managers/io_struct.py (+2 / -2)
python/sglang/srt/layers/radix_attention.py (+30 / -4)

```diff
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
```
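The try/except above guards torch.library.custom_op, which only exists in recent PyTorch releases (around 2.4); wrapping the indexed in-place write as a custom op with a registered fake (meta) kernel presumably lets tracing and compilation treat the KV-cache mutation as one opaque operation, with a plain eager function as the fallback. A minimal standalone sketch of the same pattern (the op name, tensor shapes, and variable names are made up for illustration):

```python
import torch

try:
    # PyTorch >= 2.4 path: register an opaque custom op that mutates kv_cache.
    @torch.library.custom_op("demo::store_kv", mutates_args={"kv_cache"})
    def store_kv(
        k: torch.Tensor, v: torch.Tensor, kv_cache: torch.Tensor, loc: torch.Tensor
    ) -> None:
        kv_cache[loc, 0] = k
        kv_cache[loc, 1] = v

    # The fake (meta) kernel only declares "no outputs, no new tensors", so the
    # compiler never has to trace the real indexed write.
    @store_kv.register_fake
    def _(k, v, kv_cache, loc):
        pass

except Exception:
    # Older PyTorch: fall back to a plain eager function, as the commit does.
    def store_kv(k, v, kv_cache, loc):
        kv_cache[loc, 0] = k
        kv_cache[loc, 1] = v


# Toy usage: a pool of 8 cache slots, each holding one K row and one V row.
kv_cache = torch.zeros(8, 2, 64)
k, v = torch.randn(3, 64), torch.randn(3, 64)
loc = torch.tensor([5, 1, 7])
store_kv(k, v, kv_cache, loc)
assert torch.equal(kv_cache[loc, 0], k) and torch.equal(kv_cache[loc, 1], v)
```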
python/sglang/srt/managers/controller/infer_batch.py (+15 / -15)

```diff
@@ -82,6 +82,14 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids
 
         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
@@ -132,7 +140,7 @@ class Req:
         return self.finished_reason is not None
 
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
         if first_iter:
@@ -142,13 +150,11 @@ class Req:
             )
 
         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
 
         surr_text = self.tokenizer.decode(
             surr_ids,
@@ -162,13 +168,7 @@ class Req:
         )
 
         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]
 
         return False, ""
@@ -501,7 +501,7 @@ class Batch:
                 cur_output_ids = req.output_ids
 
                 req.output_ids.extend(suffix_ids)
-                decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                decode_res, new_text = req.get_next_inc_detokenization()
                 if not decode_res:
                     req.output_ids = cur_output_ids
                     continue
```
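The diagram and the renamed helpers above are the heart of the change: a request keeps decoded_text plus two offsets into its token ids, decodes the window starting at surr_offset both with and without the newest tokens, and only emits the difference once it no longer ends in a partial character. A self-contained sketch of the idea, assuming the Hugging Face gpt2 tokenizer is available (the loop structure and names are illustrative, not the commit's code):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_ids = tokenizer.encode("Hello, incremental decoding!")

decoded_text, surr_offset, read_offset = "", 0, 0
for n in range(1, len(all_ids) + 1):  # pretend tokens arrive one at a time
    # Decode the window from surr_offset, with and without the new tokens.
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    new_text = tokenizer.decode(all_ids[surr_offset:n])
    # Emit only the suffix, and only if it does not end in a partial character.
    if len(new_text) > len(surr_text) and not new_text.endswith("\ufffd"):
        decoded_text += new_text[len(surr_text):]
        surr_offset, read_offset = read_offset, n

print(decoded_text)  # -> "Hello, incremental decoding!"
```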
python/sglang/srt/managers/controller/tp_worker.py (+7 / -7)

```diff
@@ -590,8 +590,8 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -615,9 +615,9 @@ class ModelTpServer:
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -654,8 +654,8 @@ class ModelTpServer:
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
```
python/sglang/srt/managers/detokenizer_manager.py (+47 / -5)

```diff
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import asyncio
+import dataclasses
 import inspect
+from typing import List
 
 import uvloop
 import zmq
@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,
@@ -35,19 +45,42 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
 
+        self.decode_status = {}
+
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
@@ -55,11 +88,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+
+                output_strs.append(s.decoded_text + new_text)
 
                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
```
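The streaming branch above applies the same rule as Req.get_next_inc_detokenization: if the freshly decoded chunk ends with the replacement character "�", the underlying bytes are still incomplete, so the chunk is not committed to DecodeStatus and is passed through find_printable_text instead. A tiny illustration of why the "�" check is a reliable signal, using raw UTF-8 bytes rather than a tokenizer:

```python
emoji = "🤖".encode("utf-8")  # 4 bytes; a streamed token may cover only some of them

partial = emoji[:2].decode("utf-8", errors="replace")
assert partial.endswith("\ufffd")  # incomplete sequence -> replacement char, hold it back

complete = emoji.decode("utf-8", errors="replace")
assert complete == "🤖"            # all bytes present -> safe to emit
```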
python/sglang/srt/managers/io_struct.py (+2 / -2)

```diff
@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
    meta_info: List[Dict]
```
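For orientation, here is a hypothetical, trimmed-down stand-in for the new payload covering only the fields visible in this hunk (the real BatchTokenIDOut carries more, e.g. the finished_reason list the detokenizer reads). Since the detokenizer indexes recv_obj.decode_ids[i] per request, each batch entry is in practice a list of token ids:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class MiniBatchTokenIDOut:
    """Illustrative subset only; not the actual sglang class."""

    rids: List[str]
    decoded_texts: List[str]       # text already finalized per request
    decode_ids: List[List[int]]    # per request: token ids from surr_offset onward
    read_offsets: List[int]        # per request: where the not-yet-emitted tokens start


out = MiniBatchTokenIDOut(
    rids=["req-0"],
    decoded_texts=[""],
    decode_ids=[[9906, 11, 1917]],  # arbitrary example token ids
    read_offsets=[0],
)
```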