Unverified Commit 91847e38 authored by Ke Bao, committed by GitHub

Fix eagle radix cache (#10846)

parent 5a290a56
......@@ -547,6 +547,8 @@ class Req:
self.host_hit_length = 0
# The node to lock until for swa radix tree lock ref
self.swa_uuid_for_lock: Optional[int] = None
# The prefix length from the last prefix match
self.last_matched_prefix_len: int = 0
# Whether or not it is chunked. It increments whenever
# it is chunked, and decrements whenever the chunked request is
......@@ -701,6 +703,7 @@ class Req:
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):
......
......@@ -756,6 +756,7 @@ class Scheduler(
disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events,
eviction_policy=server_args.radix_eviction_policy,
is_eagle=self.spec_algorithm.is_eagle(),
)
if (
......
......@@ -23,7 +23,7 @@ import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
import torch
......@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
return (key.extra_key, plain_key)
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
# EAGLE uses bigram keys in the radix tree since the draft sequence is the one-token-shifted version of the target:
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
if len(tokens) < 2:
return []
if isinstance(tokens[0], tuple):
return tokens
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
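
To see the key transformation in isolation, here is a minimal standalone sketch of the same conversion (a mirror of `_convert_to_bigram_key`, not an import from sglang):

```python
from typing import List, Tuple

def convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
    # Standalone mirror of _convert_to_bigram_key above.
    if len(tokens) < 2:
        return []
    if isinstance(tokens[0], tuple):
        return tokens  # already converted; the conversion is idempotent
    return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]

# A sequence of n tokens yields n - 1 bigrams, so the key is one entry shorter.
assert convert_to_bigram_key([1, 2, 3, 4]) == [(1, 2), (2, 3), (3, 4)]
assert convert_to_bigram_key([7]) == []  # a single token cannot form a bigram
assert convert_to_bigram_key([(1, 2), (2, 3)]) == [(1, 2), (2, 3)]
```
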
class RadixCache(BasePrefixCache):
def __init__(
self,
......@@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache):
disable: bool = False,
enable_kv_cache_events: bool = False,
eviction_policy: str = "lru",
is_eagle: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
......@@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache):
self.disable = disable
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue = []
self.is_eagle = is_eagle
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
......@@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache):
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
if is_eagle:
self.key_convert_fn = _convert_to_bigram_key
else:
self.key_convert_fn = lambda key: key
if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy()
elif eviction_policy.lower() == "lfu":
......@@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache):
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.
"""
key.token_ids = self.key_convert_fn(key.token_ids)
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
......@@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache):
if self.disable:
return 0
key.token_ids = self.key_convert_fn(key.token_ids)
if value is None:
value = torch.tensor(key.token_ids, dtype=torch.int64)
if self.is_eagle:
# Make sure the value length equals the EAGLE bigram key length
value = value[: len(key)]
return self._insert_helper(self.root_node, key, value)
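
The trimming matters because a bigram key is one entry shorter than the token sequence, so the value tensor carries one surplus KV index. A minimal sketch of the invariant (the `token_ids`/`kv` names are illustrative):

```python
import torch

token_ids = [1, 2, 3, 4]                # 4 tokens
kv = torch.tensor([10, 20, 30, 40])     # one KV slot per token

bigram_key = [(1, 2), (2, 3), (3, 4)]   # len(token_ids) - 1 == 3 entries
value = kv[: len(bigram_key)]           # drop the slot of the final token
assert value.tolist() == [10, 20, 30]   # now len(value) == len(key)
```
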
def cache_finished_req(self, req: Req):
......@@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache):
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
all_token_len = len(token_ids)
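# The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses bigram keys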
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices has the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
page_aligned_kv_indices,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
# Removing the req slot releases the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
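
The length accounting above is easy to trip over, so here is a worked sketch with made-up numbers showing how the EAGLE bigram offset and page alignment interact:

```python
# Illustrative numbers only; they are not taken from a real request.
page_size = 4
is_eagle = True
all_token_len = 11  # len(origin_input_ids + output_ids) - 1

# Bigram keys cover one fewer entry than the token sequence.
actual_kv_len = all_token_len - 1 if is_eagle else all_token_len   # 10
# Only whole pages of bigram entries go into the tree.
page_aligned_len = actual_kv_len // page_size * page_size          # 8
# 8 bigram entries correspond to 9 tokens, hence the +1 on the token side.
page_aligned_token_len = page_aligned_len + 1 if is_eagle else page_aligned_len  # 9

assert (actual_kv_len, page_aligned_len, page_aligned_token_len) == (10, 8, 9)
# KV slots past page_aligned_len are returned to the allocator.
```
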
......@@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache):
return
token_ids = req.fill_ids
all_token_len = len(token_ids)
# The actual kv len for EAGLE is len(token_ids) - 1, since EAGLE uses bigram keys
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
req.req_pool_idx, :all_token_len
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = len(kv_indices)
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]
# For EAGLE, page_aligned_len counts bigram keys, so the normal token key length is one greater
page_aligned_token_len = (
page_aligned_len + 1 if self.is_eagle else page_aligned_len
)
page_aligned_token_ids = token_ids[:page_aligned_token_len]
old_prefix_len = len(req.prefix_indices)
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
# prefix_indices has the partial page (for page_size > 1) and one unmatched token (for EAGLE) attached
old_prefix_len -= 1
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
......@@ -346,27 +396,38 @@ class RadixCache(BasePrefixCache):
page_aligned_kv_indices,
chunked=chunked,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
# The prefix indices could have been updated; match again to reuse them
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
new_indices[old_prefix_len:],
)
# The last_matched_prefix_len is not always equal to len(req.prefix_indices):
# for page_size > 1, the partial page is appended to req.prefix_indices, but its kv indices are not added to the tree.
# That part must be freed in the next cache_unfinished_req and the final cache_finished_req to avoid a memory leak,
# so we introduce the `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
req.last_matched_prefix_len = len(new_indices)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
# Handle the partial page; it should be freed in the next cache_unfinished_req and the final cache_finished_req.
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
if self.is_eagle:
# Attach the kv index of the last token for EAGLE; it can be used in chunked prefill
req.prefix_indices = torch.cat(
[new_indices, kv_indices[actual_kv_len:]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
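
A sketch of the free-range bookkeeping that `last_matched_prefix_len` enables; all numbers are invented for illustration:

```python
import torch

# Invented state for one chunked EAGLE request (page_size == 1).
prefix_indices_len = 6       # len(req.prefix_indices), incl. 1 trailing EAGLE slot
last_matched_prefix_len = 5  # what match_prefix returned for the previous chunk
new_prefix_len = 9           # returned by insert() for the current chunk
kv_indices = torch.arange(100, 112)

old_prefix_len = prefix_indices_len
if old_prefix_len > last_matched_prefix_len:
    # The trailing entry is the unmatched last-token slot attached for EAGLE;
    # it was never inserted into the tree, so it must stay eligible for freeing.
    old_prefix_len -= 1

# The tree now holds a reference on slots [old_prefix_len, new_prefix_len),
# so the request releases its own reference on exactly that range.
freed = kv_indices[old_prefix_len:new_prefix_len]
assert freed.tolist() == [105, 106, 107, 108]
```
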
......
......@@ -77,7 +77,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
# EAGLE
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
......
......@@ -9,6 +9,8 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
......@@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase):
}
NUM_CONFIGS = 2
THRESHOLDS = {
"batch_avg_accept_len": 1.9,
"accept_len": 3.6,
}
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
......@@ -63,6 +70,7 @@ class TestEAGLEEngine(CustomTestCase):
self._test_eos_token(engine)
self._test_acc_length(engine)
finally:
engine.flush_cache() # check engine alive
engine.shutdown()
print("=" * 100)
......@@ -92,7 +100,9 @@ class TestEAGLEEngine(CustomTestCase):
"avg_spec_accept_length"
]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.9)
self.assertGreater(
avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"]
)
def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
......@@ -131,10 +141,7 @@ class TestEAGLEEngine(CustomTestCase):
)
print(f"{acc_length=:.4f}, {speed=}")
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
self.assertGreater(acc_length, 3.6)
else:
self.assertGreater(acc_length, 2.5)
self.assertGreater(acc_length, self.THRESHOLDS["accept_len"])
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
......@@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine):
"dtype": "float16",
}
NUM_CONFIGS = 1
THRESHOLDS = {
"batch_avg_accept_len": 1.9,
"accept_len": 2.5,
}
class TestEAGLE3Engine(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Llama-3.1-8B-Instruct",
"speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
"speculative_algorithm": "EAGLE3",
"speculative_num_steps": 5,
"speculative_eagle_topk": 16,
......@@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine):
"dtype": "float16",
}
NUM_CONFIGS = 1
THRESHOLDS = {
"batch_avg_accept_len": 1.75,
"accept_len": 3.1,
}
class TestEAGLERadixCache(CustomTestCase):
BASE_CONFIG = {
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
"speculative_algorithm": "EAGLE3",
"speculative_num_steps": 2,
"speculative_eagle_topk": 1,
"speculative_num_draft_tokens": 3,
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
"dtype": "float16",
}
def test_correctness(self):
configs = [
# Basic config
self.BASE_CONFIG,
# Chunked prefill
{**self.BASE_CONFIG, "chunked_prefill_size": 64},
# Chunked prefill & Page Size > 1
{**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4},
]
for i, config in enumerate(configs):
with self.subTest(i=i):
print(f"{config=}")
engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
try:
self._test_acc_length(engine)
finally:
engine.shutdown()
print("=" * 100)
def _test_acc_length(self, engine):
warmup_prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
]
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(warmup_prompt, sampling_params)
test_prompt = [
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
]
output = engine.generate(test_prompt, sampling_params)
output = output[0]
if "spec_verify_ct" in output["meta_info"]:
acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=:.4f}, {speed=}")
self.assertGreater(acc_length, 2.5)
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
......
......@@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase):
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
)
def test_insert_and_match_eagle(self):
"""Test insert and match operations for EAGLE."""
cache = RadixCache(
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
is_eagle=True,
)
key = RadixKey([1, 2, 3, 4])
value = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
prefix_len = cache.insert(key, value)
self.assertEqual(prefix_len, 0) # No existing prefix
self.assertEqual(
cache.total_size(), 3
) # The last token is ignored in the bigram key
self.assertEqual(cache.evictable_size(), 3)
# Test match_prefix
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
self.assertEqual(len(result.device_indices), 3)
torch.testing.assert_close(
result.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64)
)
# Test partial match
result = cache.match_prefix(RadixKey([1, 2]))
self.assertEqual(len(result.device_indices), 1)
torch.testing.assert_close(
result.device_indices, torch.tensor([10], dtype=torch.int64)
)
def test_insert_and_match_eagle_page_size(self):
"""Test insert and match operations for EAGLE and page_size > 1."""
cache = RadixCache(
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=2,
disable=False,
is_eagle=True,
)
key = RadixKey([1, 2, 3])
value = torch.tensor([10, 20, 30], dtype=torch.int64)
prefix_len = cache.insert(key, value)
self.assertEqual(prefix_len, 0) # No existing prefix
self.assertEqual(cache.total_size(), 2) # only one page is inserted
self.assertEqual(cache.evictable_size(), 2)
# Test match_prefix
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
self.assertEqual(len(result.device_indices), 2)
torch.testing.assert_close(
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
)
# Test unmatched
result = cache.match_prefix(RadixKey([1, 2]))
self.assertEqual(len(result.device_indices), 0)
torch.testing.assert_close(
result.device_indices, torch.tensor([], dtype=torch.int64)
)
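
The empty match in the last assertion follows directly from the bigram-plus-paging arithmetic; a quick sketch (assuming the same conversion as `_convert_to_bigram_key`):

```python
tokens = [1, 2]
bigram = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]  # [(1, 2)]
page_size = 2
# One bigram entry cannot fill a 2-entry page, so no page-aligned prefix exists.
assert len(bigram) // page_size * page_size == 0
```
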
def test_insert_with_none_value(self):
"""Test insert with None value (should use token_ids as list)."""
cache = RadixCache(
......