Unverified Commit 6078d5fc authored by huangtingwei, committed by GitHub

[HiCacheStorage] backup optimization for MLA model (#8865)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 70cf4abc
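In an MLA model the latent KV cache is replicated across tensor-parallel ranks, so having every rank back the same pages up to host storage writes identical data tp_size times. This commit makes TP rank 0 the only writer for the file and mooncake backends, drops the rank from the Mooncake hash so all ranks agree on keys, and has the skipping ranks report full progress so the cross-rank reduction over completed tokens (a minimum, judging by the variable name) still settles on the writer's count. A minimal sketch of that gating logic, using stand-in names (Operation, tp_rank, backend) rather than the actual SGLang classes:

from dataclasses import dataclass
from typing import List


@dataclass
class Operation:
    token_ids: List[int]
    completed_tokens: int = 0


def backup_skip(is_mla: bool, tp_rank: int, backend: str) -> bool:
    # MLA KV caches are identical on every TP rank, so ranks other than 0
    # can skip the write; only the file and mooncake backends support this.
    return is_mla and tp_rank != 0 and backend in ("file", "mooncake")


def backup_progress(op: Operation, is_mla: bool, tp_rank: int, backend: str) -> int:
    if not backup_skip(is_mla, tp_rank, backend):
        op.completed_tokens = len(op.token_ids)  # stand-in for the real page backup
        return op.completed_tokens
    # A skipping rank claims full progress, so the later cross-rank
    # reduction over completed tokens is decided by the rank that wrote.
    return len(op.token_ids)


# With TP > 1, rank 1 skips the write but still reports full progress:
op = Operation(token_ids=list(range(8)))
assert backup_progress(op, is_mla=True, tp_rank=1, backend="mooncake") == 8
assert op.completed_tokens == 0  # nothing was written on this rank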
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost

 logger = logging.getLogger(__name__)
@@ -238,13 +240,14 @@ class HiCacheController:
         self.io_backend = io_backend

         self.enable_storage = False
+        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

             if storage_backend == "file":
-                self.storage_backend = HiCacheFile()
+                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,12 +260,11 @@ class HiCacheController:
                     get_hash_str_mooncake,
                 )

-                self.storage_backend = MooncakeStore()
+                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
-                from sglang.srt.distributed import get_tensor_model_parallel_rank
                 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                     HiCacheHF3FS,
                 )
@@ -399,6 +401,15 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()

+    @property
+    def backup_skip(self):
+        return (
+            self.is_mla
+            and get_tensor_model_parallel_rank() != 0
+            # todo: only support file and mooncake
+            and self.storage_backend_type in ["file", "mooncake"]
+        )
+
     def write(
         self,
         device_indices: torch.Tensor,
@@ -809,6 +820,7 @@ class HiCacheController:
             if operation is None:
                 continue

+            if not self.backup_skip:
                 if self.is_mooncake_backend():
                     self.mooncake_page_backup(operation)
                 elif self.storage_backend_type == "hf3fs":
@@ -818,8 +830,10 @@ class HiCacheController:
                     self.generic_page_backup(operation, batch_size=128)
                 else:
                     self.generic_page_backup(operation)
                 min_completed_tokens = operation.completed_tokens
+            else:
+                min_completed_tokens = len(operation.token_ids)

             if self.tp_world_size > 1:
                 completed_tokens_tensor = torch.tensor(
                     min_completed_tokens, dtype=torch.int
......
@@ -101,11 +101,11 @@ class HiCacheStorage(ABC):
 class HiCacheFile(HiCacheStorage):

-    def __init__(self, file_path: str = "/tmp/hicache"):
+    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
......
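Note on the file backend: with MHA each rank stores a different KV shard, so file names keep a _{tp_rank}_{tp_size} suffix; with MLA the suffix is dropped, so every rank resolves a key to the same file, which rank 0 alone writes. A sketch of the resulting naming, where file_key is a hypothetical helper mirroring the tp_suffix expression above and "abc123" is a made-up key:

def file_key(key: str, tp_rank: int, tp_size: int, is_mla: bool) -> str:
    # Same condition as self.tp_suffix in HiCacheFile.__init__.
    suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
    return f"{key}{suffix}"


assert file_key("abc123", tp_rank=1, tp_size=4, is_mla=False) == "abc123_1_4"
# MLA: all ranks map to one shared entry, written once by rank 0.
assert file_key("abc123", tp_rank=1, tp_size=4, is_mla=True) == "abc123"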
@@ -7,6 +7,7 @@ from functools import wraps
 import psutil
 import torch

+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
@@ -487,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
                 ptr_list.append(k_ptr)
                 ptr_list.append(v_ptr)
                 key_ = keys[index // self.page_size]
-                key_list.append(f"{key_}_k")
-                key_list.append(f"{key_}_v")
+                key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
+                key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
......
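Note on the host pool keys: since the rank tag is removed from the Mooncake hash (next file), the MHA host pool, whose K/V shards genuinely differ per rank, re-attaches the rank to each per-page tensor key instead. A sketch of that fan-out, with mha_kv_keys, page_hashes, and tp_rank as illustrative stand-ins rather than SGLang API:

from typing import List


def mha_kv_keys(page_hashes: List[str], tp_rank: int) -> List[str]:
    # Each page contributes a rank-tagged key-tensor and value-tensor entry.
    keys = []
    for h in page_hashes:
        keys.append(f"{h}_{tp_rank}_k")
        keys.append(f"{h}_{tp_rank}_v")
    return keys


assert mha_kv_keys(["h0", "h1"], tp_rank=2) == ["h0_2_k", "h0_2_v", "h1_2_k", "h1_2_v"]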
@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)

 def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
     if prior_hash:
         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
     current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"


 @dataclass
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:

 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, is_mla: bool = False):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
+            self.is_mla = is_mla
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
     def exists(self, keys) -> bool | dict:
         _keys = []
+        local_rank = get_tensor_model_parallel_rank()
         for key in keys:
             if key is None:
                 return None
+            if self.is_mla:
                 _keys.append(f"{key}_k")
+            else:
+                _keys.append(f"{key}_{local_rank}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
......
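Note on the Mooncake backend: the rank no longer participates in the hash, so the same token sequence hashes to the same key on every TP rank, which is what lets rank 0's backup serve all ranks for MLA; exists() then probes "{key}_k" for MLA and "{key}_{rank}_k" otherwise. A sketch mirroring the new hash, where hash_key is a stand-in for get_hash_str_mooncake minus the old _{local_rank} tail:

import hashlib
from typing import List, Optional

import numpy as np


def hash_key(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
    # Chain the prior hash, then hash the current chunk of token ids.
    prefix = hashlib.sha256(prior_hash.encode()).hexdigest() if prior_hash else ""
    digest = hashlib.sha256(np.array(token_ids).tobytes()).hexdigest()
    return f"{prefix}_{int(digest[:16], 16)}"


# Rank-independent: every TP rank derives the same key for the same tokens,
# so rank 0's backup is visible to all ranks. Non-MLA lookups re-attach the
# rank at probe time instead: f"{key}_{rank}_k".
assert hash_key([1, 2, 3]) == hash_key([1, 2, 3])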