Unverified Commit 4d89389c authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

Fix the key passing issue in page first layout. (#9929)

parent 9491d6e5
...@@ -407,6 +407,7 @@ class HiCacheController: ...@@ -407,6 +407,7 @@ class HiCacheController:
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size, tp_size=self.tp_size,
is_mla_model=is_mla_backend, is_mla_model=is_mla_backend,
is_page_first_layout=self.mem_pool_host.layout == "page_first",
model_name=model_name, model_name=model_name,
extra_config=extra_config, extra_config=extra_config,
) )
......
...@@ -27,6 +27,7 @@ class HiCacheStorageConfig: ...@@ -27,6 +27,7 @@ class HiCacheStorageConfig:
tp_rank: int tp_rank: int
tp_size: int tp_size: int
is_mla_model: bool is_mla_model: bool
is_page_first_layout: bool
model_name: Optional[str] model_name: Optional[str]
extra_config: Optional[dict] = None extra_config: Optional[dict] = None
......
...@@ -128,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -128,6 +128,7 @@ class HiCacheHF3FS(HiCacheStorage):
dtype: torch.dtype, dtype: torch.dtype,
metadata_client: Hf3fsMetadataInterface, metadata_client: Hf3fsMetadataInterface,
is_mla_model: bool = False, is_mla_model: bool = False,
is_page_first_layout: bool = False,
): ):
self.rank = rank self.rank = rank
self.file_path = file_path self.file_path = file_path
...@@ -138,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -138,6 +139,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.dtype = dtype self.dtype = dtype
self.metadata_client = metadata_client self.metadata_client = metadata_client
self.is_mla_model = is_mla_model self.is_mla_model = is_mla_model
self.is_page_first_layout = is_page_first_layout
self.numel = self.bytes_per_page // self.dtype.itemsize self.numel = self.bytes_per_page // self.dtype.itemsize
self.num_pages = self.file_size // self.bytes_per_page self.num_pages = self.file_size // self.bytes_per_page
self.skip_backup = False self.skip_backup = False
...@@ -193,9 +195,13 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -193,9 +195,13 @@ class HiCacheHF3FS(HiCacheStorage):
) )
if storage_config is not None: if storage_config is not None:
rank, is_mla_model = storage_config.tp_rank, storage_config.is_mla_model rank, is_mla_model, is_page_first_layout = (
storage_config.tp_rank,
storage_config.is_mla_model,
storage_config.is_page_first_layout,
)
else: else:
rank, is_mla_model = 0, False rank, is_mla_model, is_page_first_layout = 0, False, False
mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md" mla_unsupported_msg = f"MLA model is not supported without global metadata server, please refer to https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/mem_cache/storage/hf3fs/docs/deploy_sglang_3fs_multinode.md"
...@@ -213,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -213,6 +219,7 @@ class HiCacheHF3FS(HiCacheStorage):
entries=8, entries=8,
dtype=dtype, dtype=dtype,
metadata_client=Hf3fsLocalMetadataClient(), metadata_client=Hf3fsLocalMetadataClient(),
is_page_first_layout=is_page_first_layout,
) )
try: try:
...@@ -261,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -261,6 +268,7 @@ class HiCacheHF3FS(HiCacheStorage):
dtype=dtype, dtype=dtype,
metadata_client=metadata_client, metadata_client=metadata_client,
is_mla_model=is_mla_model, is_mla_model=is_mla_model,
is_page_first_layout=is_page_first_layout,
) )
def get( def get(
...@@ -407,12 +415,22 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -407,12 +415,22 @@ class HiCacheHF3FS(HiCacheStorage):
return result[0] if result else False return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int: def batch_exists(self, keys: List[str]) -> int:
results = self.metadata_client.exists(self.rank, keys) if self.is_page_first_layout and not self.is_mla_model:
for i in range(len(keys)): query_keys = []
if not results[i]: # Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash
return i for key in keys:
query_keys.append(f"{key}-k")
return len(keys) query_keys.append(f"{key}-v")
key_multiplier = 2
else:
query_keys = keys
key_multiplier = 1
exist_result = self.metadata_client.exists(self.rank, query_keys)
for i in range(len(query_keys)):
if not exist_result[i]:
return i // key_multiplier
return len(query_keys) // key_multiplier
def clear(self) -> bool: def clear(self) -> bool:
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment