Unverified commit dc965db0 authored by Alex Chi Z, committed by GitHub

make radix cache deterministic (#10721)


Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
parent 817e46f4
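
This commit threads an `enable_deterministic_inference` flag from `ServerArgs` through the `Scheduler` into `RadixCache`, so that prefix-cache matches are truncated down to a multiple of the attention backend's prefill split/tile size (4096 by default). The intent, as the hunks below suggest, is that a cached prefix whose length is not tile-aligned would change how the prefill kernels tile the sequence, and with it the floating-point reduction order, breaking run-to-run determinism. A minimal sketch of the alignment arithmetic (the helper name is hypothetical; the 4096 constant comes from this commit):

```python
SPLIT_SIZE = 4096  # DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE, defined below

def aligned_prefix_len(matched_len: int, split_size: int = SPLIT_SIZE) -> int:
    # Truncate a radix-cache match down to the nearest multiple of split_size.
    return matched_len // split_size * split_size

# A 10000-token match is served as an 8192-token cached prefix (2 * 4096);
# the remaining tokens are recomputed so the prefill tiling stays identical.
assert aligned_prefix_len(10000) == 8192
assert aligned_prefix_len(4095) == 0  # matches shorter than one tile hit the fast path
```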
@@ -163,6 +163,7 @@ from sglang.srt.tracing.trace import (
 )
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
+    DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -705,11 +706,7 @@ class Scheduler(
             self.truncation_align_size = None
             return
-        backend_sizes = {
-            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
-            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
-        }
-        env_var, default_size = backend_sizes.get(
+        env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get(
             self.server_args.attention_backend, (None, None)
         )
         self.truncation_align_size = (
@@ -849,6 +846,7 @@ class Scheduler(
                 disable=server_args.disable_radix_cache,
                 enable_kv_cache_events=self.enable_kv_cache_events,
                 eviction_policy=server_args.radix_eviction_policy,
+                enable_deterministic_inference=server_args.enable_deterministic_inference,
                 is_eagle=self.spec_algorithm.is_eagle(),
             )
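
The tail of the `truncation_align_size` assignment is collapsed in the view above. A plausible reconstruction of the lookup it performs, with the environment-variable override expressed as an assumption rather than taken from the diff:

```python
import os

# Mirrors DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG from this commit:
# attention backend -> (override env var, default size).
BACKEND_SIZE_CONFIG = {
    "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
    "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}

def resolve_truncation_align_size(attention_backend: str) -> int | None:
    env_var, default_size = BACKEND_SIZE_CONFIG.get(attention_backend, (None, None))
    if env_var is None:
        return None  # backend has no deterministic prefill size configured
    return int(os.environ.get(env_var, default_size))

print(resolve_truncation_align_size("triton"))  # 4096 unless the env var overrides it
print(resolve_truncation_align_size("fa3"))     # None
```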
......
 from __future__ import annotations
+from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -185,6 +187,7 @@ class RadixCache(BasePrefixCache):
         disable: bool = False,
         enable_kv_cache_events: bool = False,
         eviction_policy: str = "lru",
+        enable_deterministic_inference: bool = False,
         is_eagle: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -193,6 +196,8 @@ class RadixCache(BasePrefixCache):
         self.disable = disable
         self.enable_kv_cache_events = enable_kv_cache_events
         self.kv_event_queue = []
+        self.enable_deterministic_inference = enable_deterministic_inference
+        self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
         self.is_eagle = is_eagle
 
         if self.token_to_kv_pool_allocator:
@@ -234,7 +239,9 @@ class RadixCache(BasePrefixCache):
             self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
+    def match_prefix(
+        self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs
+    ) -> MatchResult:
         """Find the longest cached prefix of ``key`` in the radix tree.
 
         The logical namespace for prefix matching is determined by both the
@@ -295,7 +302,9 @@ class RadixCache(BasePrefixCache):
         if len(key) == 0:
             return empty_match_result()
 
-        value, last_node = self._match_prefix_helper(self.root_node, key)
+        value, last_node = self._match_prefix_helper(
+            self.root_node, key, is_cache_unfinished=is_cache_unfinished
+        )
         if value:
             value = torch.cat(value)
         else:
@@ -418,7 +427,8 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node, _, _ = self.match_prefix(
-            RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
+            RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key),
+            is_cache_unfinished=True,
         )
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
@@ -534,16 +544,58 @@ class RadixCache(BasePrefixCache):
 
     ##### Internal Helper Functions #####
 
-    def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool
+    ):
         node.last_access_time = time.monotonic()
 
         child_key = self.get_child_key_fn(key)
 
         value = []
+        align_split_size = (
+            not is_cache_unfinished and self.enable_deterministic_inference
+        )
+        match_history = [node] if align_split_size else None
+
+        if align_split_size and len(key) < self.split_size:
+            # fast path: directly return the root node if the split point is 0
+            return value, node
+
+        # use the access history to first find a split point aligned to
+        # split_size, then return the value and node at that point
+        def reconstruct_at_split_point(match_history, value_len):
+            # reverse the search process to find the last node right above
+            # the split point, and split there
+            split_point = value_len // self.split_size * self.split_size
+            # rebuild value from the history
+            value = []
+            current_value_len = 0
+            node = match_history[0]  # this is the root node
+            for idx, node in enumerate(match_history):
+                match_len = len(node.value)
+                if current_value_len + match_len > split_point:
+                    # split the node at the desired split point
+                    node = self._split_node(
+                        node.key, node, split_point - current_value_len
+                    )
+                    value.append(node.value)
+                    return value, node
+                elif current_value_len + match_len == split_point:
+                    if idx != 0:
+                        value.append(node.value)
+                    return value, node
+                current_value_len += match_len
+                if idx != 0:
+                    # the root node always has an empty value, skip it
+                    value.append(node.value)
+            # return the root node, as the node at the split boundary doesn't
+            # exist yet and the previously matched nodes all end before it
+            return [], match_history[0]
+
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
             child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
+            if align_split_size:
+                match_history.append(child)
 
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -557,6 +609,13 @@ class RadixCache(BasePrefixCache):
             if len(key):
                 child_key = self.get_child_key_fn(key)
 
+        if align_split_size:
+            value_len = sum(map(len, value))
+            value, node = reconstruct_at_split_point(match_history, value_len)
+            assert (
+                sum(map(len, value)) % self.split_size == 0
+            ), "The value length is not aligned with the split size"
+
         return value, node
 
     def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
......
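
The heart of the new `_match_prefix_helper` logic above is `reconstruct_at_split_point`: replay the matched path recorded in `match_history` until the cumulative value length crosses the largest multiple of `split_size`, splitting one node if the boundary falls inside it. Below is a tree-free model of the same walk over plain per-node lengths (a hypothetical helper; node splitting is simulated by shortening the last kept length):

```python
def split_point_walk(match_lens: list[int], split_size: int) -> list[int]:
    # match_lens: value lengths along the matched path, root first (root is 0).
    # Returns the per-node lengths kept after truncating the total match down
    # to a multiple of split_size, mirroring reconstruct_at_split_point.
    value_len = sum(match_lens)
    split_point = value_len // split_size * split_size
    kept, current = [], 0
    for idx, match_len in enumerate(match_lens):
        if current + match_len > split_point:
            kept.append(split_point - current)  # boundary inside node: "split" it
            return kept
        elif current + match_len == split_point:
            if idx != 0:
                kept.append(match_len)  # node ends exactly on the boundary
            return kept
        current += match_len
        if idx != 0:  # the root node always has an empty value; skip it
            kept.append(match_len)
    return []  # fall back to the root: no recorded node reaches the boundary

# Two matched nodes of 3000 tokens each (6000 total) with split_size=4096:
# keep all of the first node plus the first 1096 tokens of a split second node.
assert split_point_walk([0, 3000, 3000], 4096) == [3000, 1096]
assert sum(split_point_walk([0, 3000, 3000], 4096)) % 4096 == 0
```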
@@ -1381,13 +1381,6 @@ class ServerArgs:
                 f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
             )
 
-            # Currently, only FA3 supports radix cache. Support for other backends is in progress
-            if self.attention_backend != "fa3":
-                self.disable_radix_cache = True
-                logger.warning(
-                    f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
-                )
-
             # Check TP size
             if self.tp_size > 1:
                 os.environ["NCCL_ALGO"] = "allreduce:tree"
......
@@ -3441,3 +3441,16 @@ def cached_triton_kernel(key_fn=None):
         return CachedKernel(fn, key_fn)
 
     return decorator
+
+
+DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096
+DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = {
+    "flashinfer": (
+        "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE",
+        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
+    ),
+    "triton": (
+        "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE",
+        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
+    ),
+}
@@ -277,9 +277,10 @@ def test_deterministic(args):
     elif args.test_mode == "prefix":
         # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
-        len_prefix = [1, 511, 2048, 4097]
+        len_prefix = [1, 8000, 10000, 12500]
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
+        assert all(i <= len(LONG_PROMPT) for i in len_prefix)
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
         for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i