Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c020f9ce
Unverified
Commit
c020f9ce
authored
Aug 01, 2024
by
Liangsheng Yin
Committed by
GitHub
Aug 01, 2024
Browse files
Support chunked prefill when radix cache is disabled (#811)
parent
ca600e8c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
164 additions
and
27 deletions
+164
-27
python/sglang/srt/constrained/base_tool_cache.py
python/sglang/srt/constrained/base_tool_cache.py
+2
-2
python/sglang/srt/constrained/fsm_cache.py
python/sglang/srt/constrained/fsm_cache.py
+2
-2
python/sglang/srt/constrained/jump_forward.py
python/sglang/srt/constrained/jump_forward.py
+2
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+29
-9
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+21
-6
python/sglang/srt/mem_cache/base_cache.py
python/sglang/srt/mem_cache/base_cache.py
+43
-0
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+60
-0
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+5
-2
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+0
-4
No files found.
python/sglang/srt/constrained/base_cache.py
→
python/sglang/srt/constrained/base_
tool_
cache.py
View file @
c020f9ce
...
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
...
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
"""
"""
"""Base
cache clas
s."""
"""Base
tool cache for constrained decoding tool
s."""
import
time
import
time
class
BaseCache
:
class
Base
Tool
Cache
:
def
__init__
(
self
,
enable
=
True
):
def
__init__
(
self
,
enable
=
True
):
self
.
enable
=
enable
self
.
enable
=
enable
self
.
reset
()
self
.
reset
()
...
...
python/sglang/srt/constrained/fsm_cache.py
View file @
c020f9ce
...
@@ -16,10 +16,10 @@ limitations under the License.
...
@@ -16,10 +16,10 @@ limitations under the License.
"""Cache for the compressed finite state machine."""
"""Cache for the compressed finite state machine."""
from
sglang.srt.constrained
import
RegexGuide
,
TransformerTokenizer
from
sglang.srt.constrained
import
RegexGuide
,
TransformerTokenizer
from
sglang.srt.constrained.base_cache
import
BaseCache
from
sglang.srt.constrained.base_
tool_
cache
import
Base
Tool
Cache
class
FSMCache
(
BaseCache
):
class
FSMCache
(
Base
Tool
Cache
):
def
__init__
(
self
,
tokenizer_path
,
tokenizer_args_dict
,
enable
=
True
):
def
__init__
(
self
,
tokenizer_path
,
tokenizer_args_dict
,
enable
=
True
):
super
().
__init__
(
enable
=
enable
)
super
().
__init__
(
enable
=
enable
)
...
...
python/sglang/srt/constrained/jump_forward.py
View file @
c020f9ce
...
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
...
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
make_byte_level_fsm
,
make_byte_level_fsm
,
make_deterministic_fsm
,
make_deterministic_fsm
,
)
)
from
sglang.srt.constrained.base_cache
import
BaseCache
from
sglang.srt.constrained.base_
tool_
cache
import
Base
Tool
Cache
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
...
@@ -151,7 +151,7 @@ class JumpForwardMap:
...
@@ -151,7 +151,7 @@ class JumpForwardMap:
)
)
class
JumpForwardCache
(
BaseCache
):
class
JumpForwardCache
(
Base
Tool
Cache
):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
c020f9ce
...
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
...
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from
sglang.global_config
import
global_config
from
sglang.global_config
import
global_config
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained
import
RegexGuide
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.constrained.jump_forward
import
JumpForwardMap
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.mem_cache.memory_pool
import
ReqToTokenPool
,
TokenToKVPool
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
...
@@ -486,15 +487,33 @@ class Batch:
...
@@ -486,15 +487,33 @@ class Batch:
req
=
self
.
reqs
[
idx
]
req
=
self
.
reqs
[
idx
]
retracted_reqs
.
append
(
req
)
retracted_reqs
.
append
(
req
)
# TODO: apply more fine-grained retraction
if
isinstance
(
self
.
tree_cache
,
ChunkCache
):
last_uncached_pos
=
len
(
req
.
prefix_indices
)
# ChunkCache does not have eviction
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
req_pool_indices_cpu
[
idx
]
][
last_uncached_pos
:
seq_lens_cpu
[
idx
]]
][:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
req_to_token_pool
.
free
(
int
(
req_pool_indices_cpu
[
idx
]))
# release the last node
del
self
.
tree_cache
.
entries
[
req
.
rid
]
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
else
:
# TODO: apply more fine-grained retraction
last_uncached_pos
=
len
(
req
.
prefix_indices
)
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_cpu
[
idx
]
][
last_uncached_pos
:
seq_lens_cpu
[
idx
]]
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
req_to_token_pool
.
free
(
int
(
req_pool_indices_cpu
[
idx
]))
# release the last node
self
.
tree_cache
.
dec_lock_ref
(
req
.
last_node
)
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size
=
(
len
(
sorted_indices
)
*
global_config
.
retract_decode_steps
-
self
.
token_to_kv_pool
.
available_size
()
)
residual_size
=
max
(
0
,
residual_size
)
self
.
tree_cache
.
evict
(
residual_size
,
self
.
token_to_kv_pool
.
free
)
req
.
prefix_indices
=
None
req
.
prefix_indices
=
None
req
.
last_node
=
None
req
.
last_node
=
None
...
@@ -575,6 +594,7 @@ class Batch:
...
@@ -575,6 +594,7 @@ class Batch:
if
req_pool_indices_cpu
is
None
:
if
req_pool_indices_cpu
is
None
:
req_pool_indices_cpu
=
self
.
req_pool_indices
.
tolist
()
req_pool_indices_cpu
=
self
.
req_pool_indices
.
tolist
()
self
.
tree_cache
.
cache_req
(
self
.
tree_cache
.
cache_req
(
rid
=
req
.
rid
,
token_ids
=
cur_all_ids
,
token_ids
=
cur_all_ids
,
last_uncached_pos
=
len
(
req
.
prefix_indices
),
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
req_pool_idx
=
req_pool_indices_cpu
[
i
],
...
...
python/sglang/srt/managers/tp_worker.py
View file @
c020f9ce
...
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
...
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
ForwardMode
,
ForwardMode
,
Req
,
Req
,
)
)
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.mem_cache.radix_cache
import
RadixCache
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
...
@@ -144,11 +145,20 @@ class ModelTpServer:
...
@@ -144,11 +145,20 @@ class ModelTpServer:
)
)
# Init cache
# Init cache
self
.
tree_cache
=
RadixCache
(
if
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
server_args
.
chunked_prefill_size
is
not
None
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
and
server_args
.
disable_radix_cache
disable
=
server_args
.
disable_radix_cache
,
):
)
self
.
tree_cache
=
ChunkCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
)
else
:
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
model_runner
.
req_to_token_pool
,
token_to_kv_pool
=
self
.
model_runner
.
token_to_kv_pool
,
disable
=
server_args
.
disable_radix_cache
,
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
PolicyScheduler
(
self
.
scheduler
=
PolicyScheduler
(
self
.
schedule_policy
,
self
.
schedule_policy
,
...
@@ -354,7 +364,10 @@ class ModelTpServer:
...
@@ -354,7 +364,10 @@ class ModelTpServer:
# Compute matched prefix length
# Compute matched prefix length
for
req
in
self
.
waiting_queue
:
for
req
in
self
.
waiting_queue
:
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
req
.
input_ids
=
req
.
origin_input_ids
+
req
.
output_ids
prefix_indices
,
last_node
=
self
.
tree_cache
.
match_prefix
(
req
.
input_ids
)
prefix_indices
,
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
req
.
rid
,
key
=
req
.
input_ids
,
)
if
req
.
return_logprob
:
if
req
.
return_logprob
:
prefix_indices
=
prefix_indices
[:
req
.
logprob_start_len
]
prefix_indices
=
prefix_indices
[:
req
.
logprob_start_len
]
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
prefix_indices
)
req
.
extend_input_len
=
len
(
req
.
input_ids
)
-
len
(
prefix_indices
)
...
@@ -614,6 +627,7 @@ class ModelTpServer:
...
@@ -614,6 +627,7 @@ class ModelTpServer:
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
cpu
().
numpy
()
req_pool_indices_cpu
=
batch
.
req_pool_indices
.
cpu
().
numpy
()
for
i
,
req
in
enumerate
(
batch
.
reqs
):
for
i
,
req
in
enumerate
(
batch
.
reqs
):
new_prefix_indices
,
new_last_node
=
self
.
tree_cache
.
cache_req
(
new_prefix_indices
,
new_last_node
=
self
.
tree_cache
.
cache_req
(
rid
=
req
.
rid
,
token_ids
=
tuple
(
req
.
input_ids
),
token_ids
=
tuple
(
req
.
input_ids
),
last_uncached_pos
=
len
(
req
.
prefix_indices
),
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
req_pool_idx
=
req_pool_indices_cpu
[
i
],
...
@@ -771,6 +785,7 @@ class ModelTpServer:
...
@@ -771,6 +785,7 @@ class ModelTpServer:
for
i
in
finished_indices
:
for
i
in
finished_indices
:
req
=
batch
.
reqs
[
i
]
req
=
batch
.
reqs
[
i
]
self
.
tree_cache
.
cache_req
(
self
.
tree_cache
.
cache_req
(
rid
=
req
.
rid
,
token_ids
=
tuple
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
],
token_ids
=
tuple
(
req
.
origin_input_ids
+
req
.
output_ids
)[:
-
1
],
last_uncached_pos
=
len
(
req
.
prefix_indices
),
last_uncached_pos
=
len
(
req
.
prefix_indices
),
req_pool_idx
=
req_pool_indices_cpu
[
i
],
req_pool_idx
=
req_pool_indices_cpu
[
i
],
...
...
python/sglang/srt/mem_cache/base_cache.py
0 → 100644
View file @
c020f9ce
from
abc
import
ABC
,
abstractmethod
class
BasePrefixCache
(
ABC
):
"""Cache can be indexed by either rid or key."""
@
abstractmethod
def
reset
(
self
):
pass
@
abstractmethod
def
match_prefix
(
self
,
**
kwargs
):
pass
@
abstractmethod
def
insert
(
self
,
**
kwargs
):
pass
@
abstractmethod
def
cache_req
(
self
,
**
kwargs
):
pass
@
abstractmethod
def
evict
(
self
,
num_tokens
,
evict_callback
):
pass
@
abstractmethod
def
inc_lock_ref
(
self
,
node
):
pass
@
abstractmethod
def
dec_lock_ref
(
self
,
node
):
pass
@
abstractmethod
def
evictable_size
(
self
):
pass
def
total_size
(
self
):
raise
NotImplementedError
def
pretty_print
(
self
):
raise
NotImplementedError
python/sglang/srt/mem_cache/chunk_cache.py
0 → 100644
View file @
c020f9ce
"""Cache for chunked prefill, used when RadixCache is disabled."""
from
sglang.srt.mem_cache.base_cache
import
BasePrefixCache
class
ChunkCacheEntry
:
def
__init__
(
self
,
rid
,
value
):
self
.
rid
=
rid
self
.
value
=
value
class
ChunkCache
(
BasePrefixCache
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
):
self
.
disable
=
True
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
reset
()
def
reset
(
self
):
self
.
entries
=
{}
def
match_prefix
(
self
,
rid
,
**
kwargs
):
if
rid
not
in
self
.
entries
:
return
[],
None
entry
=
self
.
entries
[
rid
]
return
entry
.
value
,
entry
def
cache_req
(
self
,
rid
,
token_ids
,
req_pool_idx
,
del_in_memory_pool
=
True
,
**
kwargs
):
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids
)]
if
del_in_memory_pool
:
assert
rid
in
self
.
entries
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
token_to_kv_pool
.
free
(
indices
)
return
if
rid
not
in
self
.
entries
:
self
.
entries
[
rid
]
=
ChunkCacheEntry
(
rid
,
indices
)
entry
=
self
.
entries
[
rid
]
entry
.
value
=
indices
return
indices
,
entry
def
insert
(
self
):
raise
NotImplementedError
def
evict
(
self
,
num_tokens
,
evict_callback
):
pass
def
inc_lock_ref
(
self
,
node
):
return
0
def
dec_lock_ref
(
self
,
node
):
return
0
def
evictable_size
(
self
):
return
0
python/sglang/srt/mem_cache/radix_cache.py
View file @
c020f9ce
...
@@ -23,6 +23,8 @@ from collections import defaultdict
...
@@ -23,6 +23,8 @@ from collections import defaultdict
import
torch
import
torch
from
sglang.srt.mem_cache.base_cache
import
BasePrefixCache
class
TreeNode
:
class
TreeNode
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
...
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
return
i
return
i
class
RadixCache
:
class
RadixCache
(
BasePrefixCache
)
:
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
def
__init__
(
self
,
req_to_token_pool
,
token_to_kv_pool
,
disable
:
bool
=
False
):
self
.
req_to_token_pool
=
req_to_token_pool
self
.
req_to_token_pool
=
req_to_token_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
self
.
token_to_kv_pool
=
token_to_kv_pool
...
@@ -62,7 +64,7 @@ class RadixCache:
...
@@ -62,7 +64,7 @@ class RadixCache:
self
.
root_node
.
lock_ref
=
1
self
.
root_node
.
lock_ref
=
1
self
.
evictable_size_
=
0
self
.
evictable_size_
=
0
def
match_prefix
(
self
,
key
):
def
match_prefix
(
self
,
key
,
**
kwargs
):
if
self
.
disable
:
if
self
.
disable
:
return
[],
self
.
root_node
return
[],
self
.
root_node
...
@@ -90,6 +92,7 @@ class RadixCache:
...
@@ -90,6 +92,7 @@ class RadixCache:
req_pool_idx
,
req_pool_idx
,
del_in_memory_pool
=
True
,
del_in_memory_pool
=
True
,
old_last_node
=
None
,
old_last_node
=
None
,
**
kwargs
,
):
):
# Insert the request into radix cache
# Insert the request into radix cache
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids
)]
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
len
(
token_ids
)]
...
...
python/sglang/srt/server_args.py
View file @
c020f9ce
...
@@ -419,10 +419,6 @@ class ServerArgs:
...
@@ -419,10 +419,6 @@ class ServerArgs:
self
.
dp_size
>
1
and
self
.
node_rank
is
not
None
self
.
dp_size
>
1
and
self
.
node_rank
is
not
None
),
"multi-node data parallel is not supported"
),
"multi-node data parallel is not supported"
assert
not
(
self
.
chunked_prefill_size
is
not
None
and
self
.
disable_radix_cache
),
"chunked prefill is not supported with radix cache disabled currently"
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
PortArgs
:
class
PortArgs
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment