Unverified Commit 24bc3fb0 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

EAGLE cache fix for SWARadixCache (#11231)


Co-authored-by: default avatarHanming Lu <69857889+hanming-lu@users.noreply.github.com>
parent 8a8a608a
......@@ -777,6 +777,7 @@ class Scheduler(
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
......
......@@ -274,10 +274,15 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.full_to_swa_index_mapping[free_index] = 0
def backup_state(self):
    """Snapshot the state of both underlying allocators.

    Returns:
        A two-element list ``[full_state, swa_state]`` holding the backup
        of the full-attention allocator and the sliding-window-attention
        allocator, in that order. Pass the list unchanged to
        ``restore_state`` to roll back.
    """
    # The dead `raise NotImplementedError` placeholder was removed: backup
    # is now delegated to the two wrapped allocators.
    return [
        self.full_attn_allocator.backup_state(),
        self.swa_attn_allocator.backup_state(),
    ]
def restore_state(self, state):
    """Restore allocator state previously captured by ``backup_state``.

    Args:
        state: two-element sequence ``[full_state, swa_state]`` as
            produced by ``backup_state``; element 0 is applied to the
            full-attention allocator and element 1 to the
            sliding-window-attention allocator.
    """
    # The dead `raise NotImplementedError` placeholder was removed:
    # restore is now delegated to the two wrapped allocators.
    assert len(state) == 2
    self.full_attn_allocator.restore_state(state[0])
    self.swa_attn_allocator.restore_state(state[1])
def clear(self):
self.swa_attn_allocator.clear()
......
......@@ -749,6 +749,7 @@ class SWAKVPool(KVCache):
self,
size: int,
size_swa: int,
dtype: torch.dtype,
swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
......@@ -757,6 +758,7 @@ class SWAKVPool(KVCache):
):
self.size = size
self.size_swa = size_swa
self.dtype = dtype
self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids)
kwargs["page_size"] = 1
......@@ -766,11 +768,13 @@ class SWAKVPool(KVCache):
self.swa_kv_pool = token_to_kv_pool_class(
size=size_swa,
dtype=dtype,
layer_num=self.swa_layer_nums,
**kwargs,
)
self.full_kv_pool = token_to_kv_pool_class(
size=size,
dtype=dtype,
layer_num=self.full_layer_nums,
**kwargs,
)
......
......@@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache):
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
# The corresponding KV length must therefore also be reduced by 1; we compute actual_kv_len and use it for the calculations and slicing below.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :all_token_len
......@@ -349,7 +351,8 @@ class RadixCache(BasePrefixCache):
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
# Subtract 1 here so that the KV entry of that unmatched token is freed correctly, avoiding a memory leak.
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
......@@ -370,7 +373,8 @@ class RadixCache(BasePrefixCache):
token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
# The corresponding KV length must therefore also be reduced by 1; we compute actual_kv_len and use it for the calculations and slicing below.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :all_token_len
......@@ -393,7 +397,8 @@ class RadixCache(BasePrefixCache):
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
# Subtract 1 here so that the KV entry of that unmatched token is freed correctly, avoiding a memory leak.
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
......
......@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
_convert_to_bigram_key,
_key_match_page_size1,
_key_match_paged,
get_child_key,
......@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
sliding_window_size: int,
page_size: int,
disable: bool = False,
is_eagle: bool = False,
):
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.is_eagle = is_eagle
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
......@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key
self.sliding_window_size = sliding_window_size
self.reset()
......@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
The last node create a new child if the prefix is shorter
than the last node's value.
"""
key.token_ids = self.key_convert_fn(key.token_ids)
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
......@@ -406,8 +416,15 @@ class SWARadixCache(BasePrefixCache):
if self.disable:
return 0
key.token_ids = self.key_convert_fn(key.token_ids)
if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
if self.is_eagle:
# Make sure the value len equal to the EAGLE bigram key len
value = value[: len(key)]
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
def cache_finished_req(self, req: Req) -> None:
......@@ -422,25 +439,41 @@ class SWARadixCache(BasePrefixCache):
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
# The corresponding KV length must therefore also be reduced by 1; we compute actual_kv_len and use it for the calculations and slicing below.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
# Subtract 1 here so that the KV entry of that unmatched token is freed correctly, avoiding a memory leak.
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)
# Remove req slot release the cache lock
......@@ -459,39 +492,56 @@ class SWARadixCache(BasePrefixCache):
return
token_ids = req.fill_ids
all_token_len = len(token_ids)
# For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], so the key length decreases by 1 (len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1).
# The corresponding KV length must therefore also be reduced by 1; we compute actual_kv_len and use it for the calculations and slicing below.
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# In the EAGLE chunked-prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
# Subtract 1 here so that the KV entry of that unmatched token is freed correctly, avoiding a memory leak.
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
old_prefix_len,
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
assert len(req.prefix_indices) <= len(
assert old_prefix_len <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)
req.last_matched_prefix_len = len(new_indices)
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
......@@ -501,7 +551,13 @@ class SWARadixCache(BasePrefixCache):
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
if self.is_eagle:
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
req.swa_uuid_for_lock = swa_uuid_for_lock
......
......@@ -27,7 +27,11 @@ if _is_cuda:
def enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
    """Enable fused set_kv_buffer only on CUDA with a bfloat16 KV cache.

    The ``hasattr`` guard is required because some KV pools (e.g. a
    SWA-style composite pool) may not expose a ``dtype`` attribute; for
    those pools the fused path is disabled instead of raising
    AttributeError.
    """
    # Removed the unconditional duplicate `return` that shadowed this
    # guarded version (unreachable-code diff residue).
    return (
        _is_cuda
        and hasattr(forward_batch.token_to_kv_pool, "dtype")
        and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
    )
def create_fused_set_kv_buffer_arg(
......
......@@ -113,6 +113,7 @@ suites = {
TestFile("test_srt_engine.py", 261),
TestFile("test_srt_endpoint.py", 130),
TestFile("test_start_profile.py", 60),
TestFile("test_swa_unittest.py", 1),
TestFile("test_torch_compile.py", 76),
TestFile("test_torch_compile_moe.py", 172),
TestFile("test_torch_native_attention_backend.py", 123),
......
......@@ -50,9 +50,14 @@ class TestSWA(unittest.TestCase):
kvcache=pool,
need_sort=False,
)
assert alloc.available_size() == size + size_swa
self.assertEqual(
alloc.full_available_size() + alloc.swa_available_size(), size + size_swa
)
index = alloc.alloc(1)
assert alloc.available_size() == size_swa + size_swa - 2
self.assertEqual(
alloc.full_available_size() + alloc.swa_available_size(),
size_swa + size_swa - 2,
)
alloc.free_swa(index)
result = alloc.translate_loc_from_full_to_swa(index)
print(result)
......@@ -117,38 +122,174 @@ class TestSWA(unittest.TestCase):
f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
assert len(req1_token_ids) == len(req1_kv_indices)
self.assertEqual(len(req1_token_ids), len(req1_kv_indices))
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
print(
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
self.assertEqual(len(req2_token_ids), len(req2_kv_indices))
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
print(
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
self.assertEqual(len(req3_token_ids), len(req3_kv_indices))
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
print(
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
self.assertEqual(len(req4_token_ids), len(req4_kv_indices))
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
print(
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 0
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 0, 1
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 2
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
req5_token_ids = [1, 2, 3, 4, 5]
result = tree.match_prefix(RadixKey(req5_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
self.assertEqual(len(kv_indices), 0)
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req6_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
self.assertEqual(len(kv_indices), 7)
self.assertEqual(len(last_node.key), 2)
self.assertEqual(last_node.key.token_ids[0], 60)
self.assertEqual(last_node.key.token_ids[1], 70)
def test_swa_radix_cache_eagle(self):
# args
req_size = 10
max_context_len = 128
kv_size = 128
kv_size_swa = 64
sliding_window_size = 4
head_num = 8
head_dim = 128
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)]
full_attention_layer_ids_set = set(full_attention_layer_ids)
swa_attention_layer_ids = [
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
# setup req to token pool
req_to_token_pool = ReqToTokenPool(
size=req_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
)
# setup kv pool
kv_pool = SWAKVPool(
size=kv_size,
size_swa=kv_size_swa,
dtype=dtype,
head_num=head_num,
head_dim=head_dim,
swa_attention_layer_ids=swa_attention_layer_ids,
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
)
# setup token to kv pool allocator
allocator = SWATokenToKVPoolAllocator(
size=kv_size,
size_swa=kv_size_swa,
dtype=dtype,
device=device,
kvcache=kv_pool,
need_sort=False,
)
# setup radix cache
tree = SWARadixCache(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=allocator,
sliding_window_size=sliding_window_size,
page_size=1,
disable=False,
is_eagle=True,
)
# test
print(
f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
self.assertEqual(len(req1_token_ids), len(req1_kv_indices))
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
self.assertEqual(prefix_len, 0)
print(
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
assert len(req2_token_ids) == len(req2_kv_indices)
self.assertEqual(len(req2_token_ids), len(req2_kv_indices))
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
self.assertEqual(prefix_len, 2)
print(
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
assert len(req3_token_ids) == len(req3_kv_indices)
self.assertEqual(len(req3_token_ids), len(req3_kv_indices))
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
self.assertEqual(prefix_len, 0)
print(
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
assert len(req4_token_ids) == len(req4_kv_indices)
self.assertEqual(len(req4_token_ids), len(req4_kv_indices))
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
self.assertEqual(prefix_len, 4)
print(
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
......@@ -175,7 +316,7 @@ class TestSWA(unittest.TestCase):
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
self.assertEqual(len(kv_indices), 0) # no swa prefix matched
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req6_token_ids))
......@@ -183,10 +324,10 @@ class TestSWA(unittest.TestCase):
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
assert last_node.key.token_ids[0] == 60
assert last_node.key.token_ids[1] == 70
self.assertEqual(len(kv_indices), 6)
self.assertEqual(len(last_node.key), 2)
self.assertEqual(last_node.key.token_ids[0], (5, 60))
self.assertEqual(last_node.key.token_ids[1], (60, 70))
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment