Unverified Commit c020f9ce authored by Liangsheng Yin, committed by GitHub

Support chunked prefill when radix cache is disabled (#811)
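Removes the server-args assertion that rejected --chunked-prefill-size together with --disable-radix-cache. The tool caches are renamed (BaseCache -> BaseToolCache), a BasePrefixCache interface is extracted for the KV prefix caches, and a new ChunkCache keyed by request id tracks each request's KV indices across prefill chunks when the radix cache is off.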

parent ca600e8c
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""

-"""Base cache class."""
+"""Base tool cache for constrained decoding tools."""

import time

-class BaseCache:
+class BaseToolCache:
    def __init__(self, enable=True):
        self.enable = enable
        self.reset()
...
@@ -16,10 +16,10 @@ limitations under the License.
"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
-from sglang.srt.constrained.base_cache import BaseCache
+from sglang.srt.constrained.base_tool_cache import BaseToolCache

-class FSMCache(BaseCache):
+class FSMCache(BaseToolCache):
    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
        super().__init__(enable=enable)
...
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
    make_byte_level_fsm,
    make_deterministic_fsm,
)
-from sglang.srt.constrained.base_cache import BaseCache
+from sglang.srt.constrained.base_tool_cache import BaseToolCache

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

@@ -151,7 +151,7 @@ class JumpForwardMap:
)

-class JumpForwardCache(BaseCache):
+class JumpForwardCache(BaseToolCache):
    def __init__(self):
        super().__init__()
...
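Note: the BaseCache -> BaseToolCache rename is presumably to free the generic cache name for the prefix caches introduced below; FSMCache and JumpForwardCache cache constrained-decoding tools (compiled FSMs), not KV prefixes.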
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -486,16 +487,34 @@ class Batch:
            req = self.reqs[idx]
            retracted_reqs.append(req)

+            if isinstance(self.tree_cache, ChunkCache):
+                # ChunkCache does not have eviction
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][: seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                del self.tree_cache.entries[req.rid]
+            else:
                # TODO: apply more fine-grained retraction
                last_uncached_pos = len(req.prefix_indices)
                token_indices = self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[idx]
                ][last_uncached_pos : seq_lens_cpu[idx]]
                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))

                # release the last node
                self.tree_cache.dec_lock_ref(req.last_node)

+                # NOTE(lsyin): we should use the newly evictable memory instantly.
+                residual_size = (
+                    len(sorted_indices) * global_config.retract_decode_steps
+                    - self.token_to_kv_pool.available_size()
+                )
+                residual_size = max(0, residual_size)
+                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

            req.prefix_indices = None
            req.last_node = None
            req.extend_input_len = 0
@@ -575,6 +594,7 @@ class Batch:
                if req_pool_indices_cpu is None:
                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                self.tree_cache.cache_req(
+                    rid=req.rid,
                    token_ids=cur_all_ids,
                    last_uncached_pos=len(req.prefix_indices),
                    req_pool_idx=req_pool_indices_cpu[i],
...
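Retraction now branches on the cache type: with a ChunkCache there is nothing to evict or unlock, so the retracted request's token slots are freed directly and its entry dropped; on the radix path, dec_lock_ref releases the last node and the new residual_size logic eagerly evicts enough memory for the remaining requests to decode retract_decode_steps further tokens.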
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
    ForwardMode,
    Req,
)
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -144,6 +145,15 @@ class ModelTpServer:
        )

        # Init cache
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            )
+        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.model_runner.req_to_token_pool,
                token_to_kv_pool=self.model_runner.token_to_kv_pool,
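The ChunkCache is selected only when both options are set; in every other configuration the RadixCache path is unchanged.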
@@ -354,7 +364,10 @@ class ModelTpServer:
        # Compute matched prefix length
        for req in self.waiting_queue:
            req.input_ids = req.origin_input_ids + req.output_ids
-            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+            prefix_indices, last_node = self.tree_cache.match_prefix(
+                rid=req.rid,
+                key=req.input_ids,
+            )
            if req.return_logprob:
                prefix_indices = prefix_indices[: req.logprob_start_len]
            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -614,6 +627,7 @@ class ModelTpServer:
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                rid=req.rid,
                token_ids=tuple(req.input_ids),
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
@@ -771,6 +785,7 @@ class ModelTpServer:
        for i in finished_indices:
            req = batch.reqs[i]
            self.tree_cache.cache_req(
+                rid=req.rid,
                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
...
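Every cache_req and match_prefix call site now passes rid: RadixCache absorbs it via the new **kwargs parameters, while ChunkCache uses it as the lookup key. The shared interface both classes implement is the new BasePrefixCache (new file below):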
+from abc import ABC, abstractmethod
+
+
+class BasePrefixCache(ABC):
+    """Cache can be indexed by either rid or key."""
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def match_prefix(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def insert(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def evict(self, num_tokens, evict_callback):
+        pass
+
+    @abstractmethod
+    def inc_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def dec_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def evictable_size(self):
+        pass
+
+    def total_size(self):
+        raise NotImplementedError
+
+    def pretty_print(self):
+        raise NotImplementedError
"""Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache
class ChunkCacheEntry:
def __init__(self, rid, value):
self.rid = rid
self.value = value
class ChunkCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.reset()
def reset(self):
self.entries = {}
def match_prefix(self, rid, **kwargs):
if rid not in self.entries:
return [], None
entry = self.entries[rid]
return entry.value, entry
def cache_req(
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
):
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return
if rid not in self.entries:
self.entries[rid] = ChunkCacheEntry(rid, indices)
entry = self.entries[rid]
entry.value = indices
return indices, entry
def insert(self):
raise NotImplementedError
def evict(self, num_tokens, evict_callback):
pass
def inc_lock_ref(self, node):
return 0
def dec_lock_ref(self, node):
return 0
def evictable_size(self):
return 0
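A minimal usage sketch, not part of the commit, tracing the ChunkCache lifecycle across prefill chunks. The stub pools are hypothetical stand-ins for ReqToTokenPool and TokenToKVPool, reduced to the two members ChunkCache touches:

import torch

from sglang.srt.mem_cache.chunk_cache import ChunkCache


class StubReqToTokenPool:
    """Hypothetical stand-in: fixed mapping from request slots to KV indices."""

    def __init__(self, size, max_context_len):
        self.req_to_token = torch.arange(size * max_context_len).reshape(
            size, max_context_len
        )

    def free(self, req_pool_idx):
        pass  # a real pool would return the slot to its free list


class StubTokenToKVPool:
    """Hypothetical stand-in for the KV memory pool."""

    def free(self, indices):
        pass  # a real pool would mark these KV slots reusable


cache = ChunkCache(StubReqToTokenPool(4, 16), StubTokenToKVPool())

# First chunk: nothing is cached yet for a new request id.
assert cache.match_prefix(rid="req-0") == ([], None)

# After each prefill chunk, record the KV indices computed so far.
cache.cache_req(rid="req-0", token_ids=[1, 2, 3, 4], req_pool_idx=0,
                del_in_memory_pool=False)

# The next chunk resumes from these indices instead of recomputing them.
prefix_indices, entry = cache.match_prefix(rid="req-0")
assert len(prefix_indices) == 4

# On request finish, the default del_in_memory_pool=True frees the request
# slot and its KV indices.
cache.cache_req(rid="req-0", token_ids=[1, 2, 3, 4, 5], req_pool_idx=0)

Note that the finishing call frees pool memory but leaves the entry in self.entries; in the retraction path above, schedule_batch deletes the entry explicitly.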
@@ -23,6 +23,8 @@ from collections import defaultdict

import torch

+from sglang.srt.mem_cache.base_cache import BasePrefixCache
+

class TreeNode:
    def __init__(self):

@@ -46,7 +48,7 @@ def _key_match(key0, key1):
    return i

-class RadixCache:
+class RadixCache(BasePrefixCache):
    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool = token_to_kv_pool

@@ -62,7 +64,7 @@ class RadixCache:
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0

-    def match_prefix(self, key):
+    def match_prefix(self, key, **kwargs):
        if self.disable:
            return [], self.root_node

@@ -90,6 +92,7 @@ class RadixCache:
        req_pool_idx,
        del_in_memory_pool=True,
        old_last_node=None,
+        **kwargs,
    ):
        # Insert the request into radix cache
        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
...
@@ -419,10 +419,6 @@ class ServerArgs:
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"

-        assert not (
-            self.chunked_prefill_size is not None and self.disable_radix_cache
-        ), "chunked prefill is not supported with radix cache disabled currently"

@dataclasses.dataclass
class PortArgs:
...
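With the assertion gone, the two options can be combined. A hedged sketch of the now-permitted configuration (the model path and chunk size are illustrative; the check previously fired during server-args validation):

from sglang.srt.server_args import ServerArgs

# Previously rejected with "chunked prefill is not supported with radix
# cache disabled currently"; now served by ChunkCache.
args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-hf",  # illustrative model path
    chunked_prefill_size=4096,  # assumed chunk size; any positive value
    disable_radix_cache=True,
)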