Unverified Commit 161e9dc5 authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

feat(hicache-3fs): 3FS-Store Backup Optimizations For MLA Model. (#9692)

parent 54e872d3
...@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -125,6 +125,7 @@ class HiCacheHF3FS(HiCacheStorage):
entries: int, entries: int,
dtype: torch.dtype, dtype: torch.dtype,
metadata_client: Hf3fsMetadataInterface, metadata_client: Hf3fsMetadataInterface,
is_mla_model: bool = False,
): ):
self.rank = rank self.rank = rank
self.file_path = file_path self.file_path = file_path
...@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -134,9 +135,13 @@ class HiCacheHF3FS(HiCacheStorage):
self.entries = entries self.entries = entries
self.dtype = dtype self.dtype = dtype
self.metadata_client = metadata_client self.metadata_client = metadata_client
self.is_mla_model = is_mla_model
self.numel = self.bytes_per_page // self.dtype.itemsize self.numel = self.bytes_per_page // self.dtype.itemsize
self.num_pages = self.file_size // self.bytes_per_page self.num_pages = self.file_size // self.bytes_per_page
self.skip_backup = False
if self.is_mla_model and self.rank != 0:
self.skip_backup = True
self.rank = 0
logger.info( logger.info(
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: " f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
...@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -209,10 +214,14 @@ class HiCacheHF3FS(HiCacheStorage):
raise ValueError(f"Missing required keys in config: {missing_keys}") raise ValueError(f"Missing required keys in config: {missing_keys}")
# Choose metadata client based on configuration # Choose metadata client based on configuration
is_mla_model = False
if "metadata_server_url" in config and config["metadata_server_url"]: if "metadata_server_url" in config and config["metadata_server_url"]:
# Use global metadata client to connect to metadata server # Use global metadata client to connect to metadata server
metadata_server_url = config["metadata_server_url"] metadata_server_url = config["metadata_server_url"]
metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url) metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
# Enable MLA optimization only when using the global metadata client
is_mla_model = storage_config.is_mla_model if storage_config else False
logger.info( logger.info(
f"Using global metadata client with server url: {metadata_server_url}" f"Using global metadata client with server url: {metadata_server_url}"
) )
...@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -222,13 +231,15 @@ class HiCacheHF3FS(HiCacheStorage):
return HiCacheHF3FS( return HiCacheHF3FS(
rank=rank, rank=rank,
file_path=f"{config['file_path_prefix']}.{rank}.bin", # Let all ranks use the same file path for MLA model
file_path=f"{config['file_path_prefix']}.{rank if not is_mla_model else 0}.bin",
file_size=int(config["file_size"]), file_size=int(config["file_size"]),
numjobs=int(config["numjobs"]), numjobs=int(config["numjobs"]),
bytes_per_page=bytes_per_page, bytes_per_page=bytes_per_page,
entries=int(config["entries"]), entries=int(config["entries"]),
dtype=dtype, dtype=dtype,
metadata_client=metadata_client, metadata_client=metadata_client,
is_mla_model=is_mla_model,
) )
def get( def get(
...@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -312,6 +323,10 @@ class HiCacheHF3FS(HiCacheStorage):
target_locations: Optional[Any] = None, target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None, target_sizes: Optional[Any] = None,
) -> bool: ) -> bool:
# In MLA backend, only one rank needs to backup the KV cache
if self.skip_backup:
return True
# Todo: Add prefix block's hash key # Todo: Add prefix block's hash key
key_with_prefix = [(key, "") for key in keys] key_with_prefix = [(key, "") for key in keys]
indices = self.metadata_client.reserve_and_allocate_page_indices( indices = self.metadata_client.reserve_and_allocate_page_indices(
...@@ -363,16 +378,21 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -363,16 +378,21 @@ class HiCacheHF3FS(HiCacheStorage):
return all(results) return all(results)
@synchronized()
def delete(self, key: str) -> None: def delete(self, key: str) -> None:
self.metadata_client.delete_keys(self.rank, [key]) self.metadata_client.delete_keys(self.rank, [key])
@synchronized()
def exists(self, key: str) -> bool: def exists(self, key: str) -> bool:
result = self.metadata_client.exists(self.rank, [key]) result = self.metadata_client.exists(self.rank, [key])
return result[0] if result else False return result[0] if result else False
@synchronized() def batch_exists(self, keys: List[str]) -> int:
results = self.metadata_client.exists(self.rank, keys)
for i in range(len(keys)):
if not results[i]:
return i
return len(keys)
def clear(self) -> None: def clear(self) -> None:
self.metadata_client.clear(self.rank) self.metadata_client.clear(self.rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment