Unverified Commit c020f9ce authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support chunked prefill when radix cache is disabled (#811)

parent ca600e8c
......@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Base cache class."""
"""Base tool cache for constrained decoding tools."""
import time
class BaseCache:
class BaseToolCache:
def __init__(self, enable=True):
self.enable = enable
self.reset()
......
......@@ -16,10 +16,10 @@ limitations under the License.
"""Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache
class FSMCache(BaseCache):
class FSMCache(BaseToolCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
super().__init__(enable=enable)
......
......@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
from sglang.srt.constrained.base_tool_cache import BaseToolCache
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
......@@ -151,7 +151,7 @@ class JumpForwardMap:
)
class JumpForwardCache(BaseCache):
class JumpForwardCache(BaseToolCache):
def __init__(self):
super().__init__()
......
......@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache
......@@ -486,15 +487,33 @@ class Batch:
req = self.reqs[idx]
retracted_reqs.append(req)
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = None
req.last_node = None
......@@ -575,6 +594,7 @@ class Batch:
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
......
......@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
ForwardMode,
Req,
)
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -144,11 +145,20 @@ class ModelTpServer:
)
# Init cache
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = PolicyScheduler(
self.schedule_policy,
......@@ -354,7 +364,10 @@ class ModelTpServer:
# Compute matched prefix length
for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
prefix_indices, last_node = self.tree_cache.match_prefix(
rid=req.rid,
key=req.input_ids,
)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
......@@ -614,6 +627,7 @@ class ModelTpServer:
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
......@@ -771,6 +785,7 @@ class ModelTpServer:
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
......
from abc import ABC, abstractmethod
class BasePrefixCache(ABC):
"""Cache can be indexed by either rid or key."""
@abstractmethod
def reset(self):
pass
@abstractmethod
def match_prefix(self, **kwargs):
pass
@abstractmethod
def insert(self, **kwargs):
pass
@abstractmethod
def cache_req(self, **kwargs):
pass
@abstractmethod
def evict(self, num_tokens, evict_callback):
pass
@abstractmethod
def inc_lock_ref(self, node):
pass
@abstractmethod
def dec_lock_ref(self, node):
pass
@abstractmethod
def evictable_size(self):
pass
def total_size(self):
raise NotImplementedError
def pretty_print(self):
raise NotImplementedError
"""Cache for chunked prefill, used when RadixCache is disabled."""
from sglang.srt.mem_cache.base_cache import BasePrefixCache
class ChunkCacheEntry:
def __init__(self, rid, value):
self.rid = rid
self.value = value
class ChunkCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.reset()
def reset(self):
self.entries = {}
def match_prefix(self, rid, **kwargs):
if rid not in self.entries:
return [], None
entry = self.entries[rid]
return entry.value, entry
def cache_req(
self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
):
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
if del_in_memory_pool:
assert rid in self.entries
self.req_to_token_pool.free(req_pool_idx)
self.token_to_kv_pool.free(indices)
return
if rid not in self.entries:
self.entries[rid] = ChunkCacheEntry(rid, indices)
entry = self.entries[rid]
entry.value = indices
return indices, entry
def insert(self):
raise NotImplementedError
def evict(self, num_tokens, evict_callback):
pass
def inc_lock_ref(self, node):
return 0
def dec_lock_ref(self, node):
return 0
def evictable_size(self):
return 0
......@@ -23,6 +23,8 @@ from collections import defaultdict
import torch
from sglang.srt.mem_cache.base_cache import BasePrefixCache
class TreeNode:
def __init__(self):
......@@ -46,7 +48,7 @@ def _key_match(key0, key1):
return i
class RadixCache:
class RadixCache(BasePrefixCache):
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
......@@ -62,7 +64,7 @@ class RadixCache:
self.root_node.lock_ref = 1
self.evictable_size_ = 0
def match_prefix(self, key):
def match_prefix(self, key, **kwargs):
if self.disable:
return [], self.root_node
......@@ -90,6 +92,7 @@ class RadixCache:
req_pool_idx,
del_in_memory_pool=True,
old_last_node=None,
**kwargs,
):
# Insert the request into radix cache
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
......
......@@ -419,10 +419,6 @@ class ServerArgs:
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
assert not (
self.chunked_prefill_size is not None and self.disable_radix_cache
), "chunked prefill is not supported with radix cache disabled currently"
@dataclasses.dataclass
class PortArgs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment