sglang / Commits / a9ef49c1

Commit a9ef49c1 (Unverified)
Authored Jul 18, 2024 by Liangsheng Yin, committed by GitHub on Jul 18, 2024
Parent: 21ba3a88

Detokenize incrementally when streaming (#653)

Showing 5 changed files with 101 additions and 33 deletions (+101 / -33):

    python/sglang/srt/layers/radix_attention.py            +30   -4
    python/sglang/srt/managers/controller/infer_batch.py   +15   -15
    python/sglang/srt/managers/controller/tp_worker.py     +7    -7
    python/sglang/srt/managers/detokenizer_manager.py      +47   -5
    python/sglang/srt/managers/io_struct.py                 +2    -2
python/sglang/srt/layers/radix_attention.py  (View file @ a9ef49c1)

@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        key_buffer[input_metadata.out_cache_loc] = cache_k
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        value_buffer[input_metadata.out_cache_loc] = cache_v
+        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
+        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+
+
+try:
+    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
+
+    @_store_kv_cache.register_fake
+    def _(k, v, kv_cache, cache_loc):
+        pass
+
+except:
+    def _store_kv_cache(
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        cache_loc: torch.Tensor,
+    ) -> None:
+        kv_cache[cache_loc, 0] = k
+        kv_cache[cache_loc, 1] = v
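
For reference, the try/except above wraps the in-place KV-cache write in torch.library.custom_op (available in recent PyTorch releases) with register_fake supplying the meta implementation, and falls back to a plain function when that API is missing. A minimal standalone sketch of the same pattern; the demo:: namespace, tensor shapes, and the usage at the bottom are illustrative assumptions, not part of this commit:

    import torch

    try:
        # torch.library.custom_op exists in recent PyTorch releases; on older
        # versions the attribute lookup fails and we fall back to a plain function.
        @torch.library.custom_op("demo::store_kv_cache", mutates_args={"kv_cache"})
        def store_kv_cache(
            k: torch.Tensor, v: torch.Tensor, kv_cache: torch.Tensor, loc: torch.Tensor
        ) -> None:
            # Write K into slot 0 and V into slot 1 of the selected cache rows.
            kv_cache[loc, 0] = k
            kv_cache[loc, 1] = v

        @store_kv_cache.register_fake
        def _(k, v, kv_cache, loc):
            # Meta implementation: no outputs, kv_cache is mutated in place.
            pass
    except Exception:
        def store_kv_cache(k, v, kv_cache, loc) -> None:
            kv_cache[loc, 0] = k
            kv_cache[loc, 1] = v

    # Toy usage: a cache with 8 token slots, 2 (K/V) entries, 1 head, head_dim 4.
    kv_cache = torch.zeros(8, 2, 1, 4)
    k = torch.ones(3, 1, 4)
    v = torch.full((3, 1, 4), 2.0)
    loc = torch.tensor([0, 1, 5])
    store_kv_cache(k, v, kv_cache, loc)
    assert torch.equal(kv_cache[loc, 0], k) and torch.equal(kv_cache[loc, 1], v)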

python/sglang/srt/managers/controller/infer_batch.py  (View file @ a9ef49c1)

@@ -82,6 +82,14 @@ class Req:
         self.input_ids = None  # input_ids = origin_input_ids + output_ids

         # For incremental decoding
+        # ----- | --------- read_ids -------|
+        # ----- |   surr_ids  |
+        # xxxxx | xxxxxxxxxxx | xxxxxxxxxxx |
+        # ----- ^ ----------- ^ ----------- ^
+        # ----- 1 ----------- 2 ----------- 3
+        # 1: surr_offset
+        # 2: read_offset
+        # 3: last token
         self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None

@@ -132,7 +140,7 @@ class Req:
         return self.finished_reason is not None

     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
-    def init_detokenize_incrementally(self):
+    def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None

         if first_iter:

@@ -142,13 +150,11 @@ class Req:
             )

         all_ids = self.origin_input_ids_unpadded + self.output_ids
-        surr_ids = all_ids[self.surr_offset : self.read_offset]
-        read_ids = all_ids[self.surr_offset :]
-        return surr_ids, read_ids, len(all_ids)
+        return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

-    def detokenize_incrementally(self, inplace: bool = True):
-        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+    def get_next_inc_detokenization(self):
+        read_ids, read_offset = self.init_incremental_detokenize()
+        surr_ids = read_ids[:read_offset]
         surr_text = self.tokenizer.decode(
             surr_ids,

@@ -162,13 +168,7 @@ class Req:
         )

         if len(new_text) > len(surr_text) and not new_text.endswith("�"):
-            new_text = new_text[len(surr_text) :]
-            if inplace:
-                self.decoded_text += new_text
-                self.surr_offset = self.read_offset
-                self.read_offset = num_all_tokens
-
-            return True, new_text
+            return True, new_text[len(surr_text) :]

         return False, ""

@@ -501,7 +501,7 @@ class Batch:
                 cur_output_ids = req.output_ids

                 req.output_ids.extend(suffix_ids)
-                decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                decode_res, new_text = req.get_next_inc_detokenization()
                 if not decode_res:
                     req.output_ids = cur_output_ids
                     continue
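
The new Req helpers implement the windowed detokenization trick sketched in the comment diagram above: keep a small window of already-emitted context tokens (surr_offset..read_offset), decode both the window and the window plus the new tokens, and emit only the suffix once it no longer ends in the UTF-8 replacement character. A minimal standalone sketch of that idea with a Hugging Face tokenizer; the function name, initial offsets, and the gpt2 model id are illustrative assumptions, not this repository's API:

    from transformers import AutoTokenizer

    def stream_detokenize(tokenizer, all_ids, surr_offset, read_offset):
        """Return (new_text, new_surr_offset, new_read_offset)."""
        # The surrounding window gives the decoder enough left context; the
        # emitted chunk is whatever the full tail adds beyond that context.
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        new_text = tokenizer.decode(all_ids[surr_offset:])
        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
            # A complete piece of text was produced: emit it and advance offsets.
            return new_text[len(surr_text):], read_offset, len(all_ids)
        # Otherwise hold the chunk back (e.g. we are mid multi-byte character).
        return "", surr_offset, read_offset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer.encode("Hello world, streaming!")
    surr, read, out = 0, 0, ""
    for n in range(1, len(ids) + 1):      # pretend tokens arrive one at a time
        chunk, surr, read = stream_detokenize(tokenizer, ids[:n], surr, read)
        out += chunk
    print(out)                            # -> Hello world, streaming!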

python/sglang/srt/managers/controller/tp_worker.py  (View file @ a9ef49c1)

@@ -590,8 +590,8 @@ class ModelTpServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         decoded_texts = []
-        surr_output_ids = []
-        read_output_ids = []
+        output_read_ids = []
+        output_read_offsets = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []

@@ -615,9 +615,9 @@ class ModelTpServer:
             ):
                 output_rids.append(req.rid)
                 decoded_texts.append(req.decoded_text)
-                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
-                surr_output_ids.append(surr_ids)
-                read_output_ids.append(read_ids)
+                read_ids, read_offset = req.init_incremental_detokenize()
+                output_read_ids.append(read_ids)
+                output_read_offsets.append(read_offset)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )

@@ -654,8 +654,8 @@ class ModelTpServer:
                 BatchTokenIDOut(
                     output_rids,
                     decoded_texts,
-                    surr_output_ids,
-                    read_output_ids,
+                    output_read_ids,
+                    output_read_offsets,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
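
On the worker side, handle_finished_requests now ships each request's read_ids and read_offset instead of pre-sliced surrounding/read token lists. A rough sketch of how those parallel lists line up with the BatchTokenIDOut fields, using a stand-in Req and dataclass with made-up values; only the field names mirror the commit:

    import dataclasses
    from typing import Dict, List

    @dataclasses.dataclass
    class BatchTokenIDOut:
        rids: List[str]
        decoded_texts: List[str]
        decode_ids: List[List[int]]   # per request: token ids from surr_offset onward
        read_offsets: List[int]       # per request: read_offset - surr_offset
        skip_special_tokens: List[bool]
        spaces_between_special_tokens: List[bool]
        meta_info: List[Dict]

    @dataclasses.dataclass
    class FakeReq:
        rid: str
        decoded_text: str
        origin_input_ids_unpadded: List[int]
        output_ids: List[int]
        surr_offset: int
        read_offset: int

        def init_incremental_detokenize(self):
            all_ids = self.origin_input_ids_unpadded + self.output_ids
            return all_ids[self.surr_offset:], self.read_offset - self.surr_offset

    reqs = [FakeReq("r0", "", [11, 12, 13], [21, 22], surr_offset=1, read_offset=3)]
    rids, texts, read_ids, read_offsets = [], [], [], []
    for req in reqs:
        ids, offset = req.init_incremental_detokenize()
        rids.append(req.rid)
        texts.append(req.decoded_text)
        read_ids.append(ids)
        read_offsets.append(offset)

    out = BatchTokenIDOut(rids, texts, read_ids, read_offsets, [True], [True], [{}])
    print(out.decode_ids, out.read_offsets)   # [[12, 13, 21, 22]] [2]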

python/sglang/srt/managers/detokenizer_manager.py  (View file @ a9ef49c1)

 """DetokenizerManager is a process that detokenizes the token ids."""

 import asyncio
+import dataclasses
 import inspect
+from typing import List

 import uvloop
 import zmq

@@ -16,6 +18,14 @@ from sglang.utils import find_printable_text, get_exception_traceback, graceful_
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


+@dataclasses.dataclass
+class DecodeStatus:
+    decoded_text: str
+    decode_ids: List[int]
+    surr_offset: int
+    read_offset: int
+
+
 class DetokenizerManager:
     def __init__(
         self,

@@ -35,19 +45,42 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )
+        self.decode_status = {}

     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)
+            bs = len(recv_obj.rids)
+
+            # FIXME: incremental detokenize is not compatible with jump forward
+            # Initialize decode status
+            read_ids, surr_ids = [], []
+            for i in range(bs):
+                rid = recv_obj.rids[i]
+                if rid not in self.decode_status:
+                    s = DecodeStatus(
+                        decoded_text=recv_obj.decoded_texts[i],
+                        decode_ids=recv_obj.decode_ids[i],
+                        surr_offset=0,
+                        read_offset=recv_obj.read_offsets[i],
+                    )
+                    self.decode_status[rid] = s
+                else:
+                    s = self.decode_status[rid]
+                    s.decode_ids = recv_obj.decode_ids[i]
+
+                read_ids.append(s.decode_ids[s.surr_offset :])
+                surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
             surr_texts = self.tokenizer.batch_decode(
-                recv_obj.surr_output_ids,
+                surr_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )
             read_texts = self.tokenizer.batch_decode(
-                recv_obj.read_output_ids,
+                read_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
                 spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

@@ -55,11 +88,20 @@ class DetokenizerManager:
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             output_strs = []
-            for i in range(len(recv_obj.rids)):
+            for i in range(bs):
+                s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
                 if recv_obj.finished_reason[i] is None:
-                    new_text = find_printable_text(new_text)
-                output_strs.append(recv_obj.decoded_texts[i] + new_text)
+                    # Streaming chunk: update the decode status
+                    if len(new_text) > 0 and not new_text.endswith("�"):
+                        s.decoded_text = s.decoded_text + new_text
+                        s.surr_offset = s.read_offset
+                        s.read_offset = len(s.decode_ids)
+                        new_text = ""
+                    else:
+                        new_text = find_printable_text(new_text)
+                output_strs.append(s.decoded_text + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
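
With this change the detokenizer owns the incremental state: one DecodeStatus per request id, updated on every incoming batch. A condensed, self-contained sketch of that bookkeeping; the toy tokenizer, request ids, and token values are made up for illustration and the "�" check is simplified away:

    import dataclasses
    from typing import Dict, List

    @dataclasses.dataclass
    class DecodeStatus:
        decoded_text: str
        decode_ids: List[int]
        surr_offset: int
        read_offset: int

    class ToyTokenizer:
        def batch_decode(self, batch_ids):
            # Pretend each token id decodes to "<id>".
            return ["".join(f"<{t}>" for t in ids) for ids in batch_ids]

    def step(status: Dict[str, DecodeStatus], tokenizer, rids, decode_ids, read_offsets):
        """One handle_loop iteration: update per-request status, return new chunks."""
        read_ids, surr_ids = [], []
        for rid, ids, read_off in zip(rids, decode_ids, read_offsets):
            s = status.setdefault(rid, DecodeStatus("", ids, 0, read_off))
            s.decode_ids = ids
            read_ids.append(s.decode_ids[s.surr_offset:])
            surr_ids.append(s.decode_ids[s.surr_offset:s.read_offset])
        surr_texts = tokenizer.batch_decode(surr_ids)
        read_texts = tokenizer.batch_decode(read_ids)
        chunks = []
        for rid, surr_text, read_text in zip(rids, surr_texts, read_texts):
            s = status[rid]
            new_text = read_text[len(surr_text):]
            if new_text:                  # real code also holds back text ending in "�"
                s.decoded_text += new_text
                s.surr_offset = s.read_offset
                s.read_offset = len(s.decode_ids)
            chunks.append(new_text)
        return chunks

    status, tok = {}, ToyTokenizer()
    print(step(status, tok, ["r0"], [[1, 2, 3, 4, 5]], [3]))     # ['<4><5>']
    print(step(status, tok, ["r0"], [[1, 2, 3, 4, 5, 6]], [3]))  # ['<6>']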

python/sglang/srt/managers/io_struct.py  (View file @ a9ef49c1)

@@ -111,8 +111,8 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     decoded_texts: List[str]
-    surr_output_ids: List[List[int]]
-    read_output_ids: List[List[int]]
+    decode_ids: List[int]
+    read_offsets: List[int]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]