Unverified commit a98290ae, authored by Zhiqiang Xie and committed by GitHub
Browse files

Unit test for Hierarchical Caching (#4486)

parent 9b81f9bd
...@@ -445,6 +445,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -445,6 +445,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(), tp_cache_group=self.tp_worker.get_tp_cpu_group(),
page_size=self.page_size, page_size=self.page_size,
hicache_ratio=server_args.hicache_ratio,
) )
else: else:
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
......
...@@ -29,6 +29,7 @@ class HiRadixCache(RadixCache): ...@@ -29,6 +29,7 @@ class HiRadixCache(RadixCache):
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup, tp_cache_group: torch.distributed.ProcessGroup,
page_size: int, page_size: int,
hicache_ratio: float,
): ):
if page_size != 1: if page_size != 1:
raise ValueError( raise ValueError(
...@@ -36,9 +37,13 @@ class HiRadixCache(RadixCache): ...@@ -36,9 +37,13 @@ class HiRadixCache(RadixCache):
) )
self.kv_cache = token_to_kv_pool_allocator.get_kvcache() self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool): if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache) self.token_to_kv_pool_host = MHATokenToKVPoolHost(
self.kv_cache, hicache_ratio
)
elif isinstance(self.kv_cache, MLATokenToKVPool): elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache) self.token_to_kv_pool_host = MLATokenToKVPoolHost(
self.kv_cache, hicache_ratio
)
else: else:
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.") raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
......
...@@ -581,7 +581,7 @@ class HostKVCache(abc.ABC): ...@@ -581,7 +581,7 @@ class HostKVCache(abc.ABC):
def __init__( def __init__(
self, self,
device_pool: MHATokenToKVPool, device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 3.0, host_to_device_ratio: float,
pin_memory: bool = False, # no need to use pin memory with the double buffering pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu", device: str = "cpu",
): ):
...@@ -747,7 +747,7 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -747,7 +747,7 @@ class MHATokenToKVPoolHost(HostKVCache):
def __init__( def __init__(
self, self,
device_pool: MHATokenToKVPool, device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 3.0, host_to_device_ratio: float,
pin_memory: bool = False, # no need to use pin memory with the double buffering pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu", device: str = "cpu",
): ):
...@@ -789,7 +789,7 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -789,7 +789,7 @@ class MLATokenToKVPoolHost(HostKVCache):
def __init__( def __init__(
self, self,
device_pool: MLATokenToKVPool, device_pool: MLATokenToKVPool,
host_to_device_ratio: float = 4.0, host_to_device_ratio: float,
pin_memory: bool = False, # no need to use pin memory with the double buffering pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu", device: str = "cpu",
): ):
......
...@@ -173,6 +173,7 @@ class ServerArgs: ...@@ -173,6 +173,7 @@ class ServerArgs:
enable_custom_logit_processor: bool = False enable_custom_logit_processor: bool = False
tool_call_parser: str = None tool_call_parser: str = None
enable_hierarchical_cache: bool = False enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False enable_flashinfer_mla: bool = False
enable_flashmla: bool = False enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False flashinfer_mla_disable_ragged: bool = False
...@@ -1007,6 +1008,13 @@ class ServerArgs: ...@@ -1007,6 +1008,13 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable hierarchical cache", help="Enable hierarchical cache",
) )
parser.add_argument(
"--hicache-ratio",
type=float,
required=False,
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
# Server warmups # Server warmups
parser.add_argument( parser.add_argument(
......
...@@ -74,6 +74,8 @@ suites = { ...@@ -74,6 +74,8 @@ suites = {
TestFile("test_w8a8_quantization.py", 46), TestFile("test_w8a8_quantization.py", 46),
TestFile("test_eval_fp8_accuracy.py", 172), TestFile("test_eval_fp8_accuracy.py", 172),
TestFile("test_create_kvindices.py", 2), TestFile("test_create_kvindices.py", 2),
TestFile("test_hicache.py", 60),
TestFile("test_hicache_mla.py", 90),
], ],
"nightly": [ "nightly": [
TestFile("test_nightly_gsm8k_eval.py"), TestFile("test_nightly_gsm8k_eval.py"),
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestPageSize(unittest.TestCase):
    """MMLU evaluation against a server launched with hierarchical caching.

    A single server process is started for the whole class and shared by
    every test method.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared server with ``--enable-hierarchical-cache``."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = ["--enable-hierarchical-cache"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Hand the launched server's pid to ``kill_process_tree``."""
        server_pid = cls.process.pid
        kill_process_tree(server_pid)

    def test_mmlu(self):
        """Run the ``mmlu`` eval and require a score of at least 0.65."""
        eval_config = {
            "base_url": self.base_url,
            "model": self.model,
            "eval_name": "mmlu",
            "num_examples": 64,
            "num_threads": 32,
        }
        metrics = run_eval(SimpleNamespace(**eval_config))
        score = metrics["score"]
        self.assertGreaterEqual(score, 0.65)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment