Unverified Commit dc965db0 authored by Alex Chi Z, committed by GitHub

make radix cache deterministic (#10721)


Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
parent 817e46f4
@@ -163,6 +163,7 @@ from sglang.srt.tracing.trace import (
 )
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
+    DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -705,11 +706,7 @@ class Scheduler(
             self.truncation_align_size = None
             return
 
-        backend_sizes = {
-            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
-            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
-        }
-        env_var, default_size = backend_sizes.get(
+        env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get(
             self.server_args.attention_backend, (None, None)
         )
         self.truncation_align_size = (
@@ -849,6 +846,7 @@ class Scheduler(
                 disable=server_args.disable_radix_cache,
                 enable_kv_cache_events=self.enable_kv_cache_events,
                 eviction_policy=server_args.radix_eviction_policy,
+                enable_deterministic_inference=server_args.enable_deterministic_inference,
                 is_eagle=self.spec_algorithm.is_eagle(),
             )
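The Scheduler now pulls its prefill split/truncation sizes from the shared DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG table added to sglang.srt.utils at the end of this diff. A minimal sketch of the resulting lookup, where resolve_truncation_align_size is a hypothetical stand-in for the (truncated) Scheduler code and the env-var override is assumed from the variable names:

import os
from typing import Optional

# Backend -> (env var, default split size), as added to sglang.srt.utils.
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096
DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = {
    "flashinfer": (
        "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE",
        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
    ),
    "triton": (
        "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE",
        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
    ),
}

def resolve_truncation_align_size(attention_backend: str) -> Optional[int]:
    # Look up the backend's env var and default; unknown backends get no
    # alignment, matching the (None, None) fallback in the hunk above.
    env_var, default_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG.get(
        attention_backend, (None, None)
    )
    if env_var is None:
        return None
    # Assumption: the env var overrides the default when set.
    return int(os.environ.get(env_var, default_size))

print(resolve_truncation_align_size("triton"))  # 4096 unless overridden
print(resolve_truncation_align_size("fa3"))     # None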
...
 from __future__ import annotations
+from sglang.srt.utils import DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
 
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -185,6 +187,7 @@ class RadixCache(BasePrefixCache):
         disable: bool = False,
         enable_kv_cache_events: bool = False,
         eviction_policy: str = "lru",
+        enable_deterministic_inference: bool = False,
         is_eagle: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -193,6 +196,8 @@ class RadixCache(BasePrefixCache):
         self.disable = disable
         self.enable_kv_cache_events = enable_kv_cache_events
         self.kv_event_queue = []
+        self.enable_deterministic_inference = enable_deterministic_inference
+        self.split_size = DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE
         self.is_eagle = is_eagle
 
         if self.token_to_kv_pool_allocator:
@@ -234,7 +239,9 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
+    def match_prefix(
+        self, key: RadixKey, is_cache_unfinished: bool = False, **kwargs
+    ) -> MatchResult:
         """Find the longest cached prefix of ``key`` in the radix tree.
 
         The logical namespace for prefix matching is determined by both the
@@ -295,7 +302,9 @@ class RadixCache(BasePrefixCache):
         if len(key) == 0:
             return empty_match_result()
 
-        value, last_node = self._match_prefix_helper(self.root_node, key)
+        value, last_node = self._match_prefix_helper(
+            self.root_node, key, is_cache_unfinished=is_cache_unfinished
+        )
         if value:
             value = torch.cat(value)
         else:
@@ -418,7 +427,8 @@ class RadixCache(BasePrefixCache):
             # The prefix indices could be updated, reuse it
             new_indices, new_last_node, _, _ = self.match_prefix(
-                RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
+                RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key),
+                is_cache_unfinished=True,
             )
             self.req_to_token_pool.write(
                 (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
@@ -534,16 +544,58 @@ class RadixCache(BasePrefixCache):
 
     ##### Internal Helper Functions #####
 
-    def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: RadixKey, is_cache_unfinished: bool
+    ):
         node.last_access_time = time.monotonic()
 
         child_key = self.get_child_key_fn(key)
 
         value = []
+        align_split_size = (
+            not is_cache_unfinished and self.enable_deterministic_inference
+        )
+        match_history = [node] if align_split_size else None
+
+        if align_split_size and len(key) < self.split_size:
+            # fast path: directly return the root node if the split point is 0
+            return value, node
+
+        # use the access history to first find a split point at split_size and
+        # then return the value and node at that point
+        def reconstruct_at_split_point(match_history, value_len):
+            # walk the search history back to the last node right above the
+            # split point, and split there
+            split_point = value_len // self.split_size * self.split_size
+            # rebuild value from history
+            value = []
+            current_value_len = 0
+            node = match_history[0]  # this is the root node
+            for idx, node in enumerate(match_history):
+                match_len = len(node.value)
+                if current_value_len + match_len > split_point:
+                    # split the node at the desired split point
+                    node = self._split_node(
+                        node.key, node, split_point - current_value_len
+                    )
+                    value.append(node.value)
+                    return value, node
+                elif current_value_len + match_len == split_point:
+                    if idx != 0:
+                        value.append(node.value)
+                    return value, node
+                current_value_len += match_len
+                if idx != 0:
+                    # the root node always has empty value, skip
+                    value.append(node.value)
+            # return the root node as the corresponding node doesn't exist yet
+            # and the previously computed node is not at the split boundary
+            return [], match_history[0]
+
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
             child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
+            if align_split_size:
+                match_history.append(child)
 
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -557,6 +609,13 @@ class RadixCache(BasePrefixCache):
             if len(key):
                 child_key = self.get_child_key_fn(key)
 
+        if align_split_size:
+            value_len = sum(map(len, value))
+            value, node = reconstruct_at_split_point(match_history, value_len)
+            assert (
+                sum(map(len, value)) % self.split_size == 0
+            ), "The value length is not aligned with the split size"
+
         return value, node
 
     def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
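The heart of the new helper is the split-point rule: when deterministic inference is enabled and the request is not re-caching an unfinished prefill, the matched prefix is rounded down to a multiple of self.split_size, splitting a tree node at the boundary if needed. A toy illustration of the rounding (not SGLang code; 4096 is the default size from the utils hunk below):

def aligned_prefix_len(matched_len: int, split_size: int = 4096) -> int:
    # Round the cache hit down to the largest tile-aligned length, mirroring
    # split_point = value_len // self.split_size * self.split_size above.
    return matched_len // split_size * split_size

assert aligned_prefix_len(12_500) == 12_288  # 3 * 4096
assert aligned_prefix_len(8_192) == 8_192    # already on a tile boundary
assert aligned_prefix_len(4_095) == 0        # shorter than one tile: no reuse

Truncating cache hits to tile boundaries means the prefill kernels see the same split points whether or not a prefix was cached, which is what the commit title means by making the radix cache deterministic.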
...
@@ -1381,13 +1381,6 @@ class ServerArgs:
                 f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
             )
 
-        # Currently, only FA3 supports radix cache. Support for other backends is in progress
-        if self.attention_backend != "fa3":
-            self.disable_radix_cache = True
-            logger.warning(
-                f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
-            )
-
         # Check TP size
         if self.tp_size > 1:
             os.environ["NCCL_ALGO"] = "allreduce:tree"
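With the cache-hit alignment added in radix_cache.py, the FA3-only carve-out removed above is no longer needed: deterministic inference now keeps the radix cache enabled for the flashinfer and triton backends as well, using the split sizes configured in the utils hunk below.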
...
@@ -3441,3 +3441,16 @@ def cached_triton_kernel(key_fn=None):
         return CachedKernel(fn, key_fn)
 
     return decorator
+
+
+DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE = 4096
+DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE_CONFIG = {
+    "flashinfer": (
+        "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE",
+        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
+    ),
+    "triton": (
+        "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE",
+        DEFAULT_DETERMINISTIC_INFERENCE_BACKEND_SIZE,
+    ),
+}
@@ -277,9 +277,10 @@ def test_deterministic(args):
     elif args.test_mode == "prefix":
         # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
-        len_prefix = [1, 511, 2048, 4097]
+        len_prefix = [1, 8000, 10000, 12500]
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
+        assert all(i <= len(LONG_PROMPT) for i in len_prefix)
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
         for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i
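With a 4096-token split size, three of the four old prefix lengths [1, 511, 2048, 4097] sat inside the first tile, so aligned matching truncated them all to zero reuse; the new lengths land in distinct tiles. A quick check of where each prefix would be truncated, assuming the default size:

SPLIT_SIZE = 4096
for n in [1, 8000, 10000, 12500]:
    print(f"prefix {n:>5} -> aligned cache hit {n // SPLIT_SIZE * SPLIT_SIZE}")
# prefix     1 -> aligned cache hit 0
# prefix  8000 -> aligned cache hit 4096
# prefix 10000 -> aligned cache hit 8192
# prefix 12500 -> aligned cache hit 12288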
...