Unverified commit b0add2da, authored by Zhiqiang Xie and committed by GitHub

HiCache storage, style change and bug fix (#8719)

parent 0305c505
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo, translate tensor object access for different TP ranks
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool
 
     @abstractmethod
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location: Optional[Any] = None,
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            if target_location is not None:
-                # Load directly into target_location's memory buffer
-                with open(tensor_path, "rb") as f:
-                    target_location.set_(
-                        torch.frombuffer(f.read(), dtype=target_location.dtype)
-                        .reshape(target_location.shape)
-                        .storage()
-                    )
-                return target_location
-            else:
-                loaded_tensor = torch.load(tensor_path)
-                if isinstance(loaded_tensor, torch.Tensor):
-                    return loaded_tensor
-                else:
-                    logger.error(f"Loaded data for key {key} is not a tensor.")
-                    return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None
 
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
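Note on the new load path: get() no longer falls back to torch.load; it wraps the file's raw bytes with torch.frombuffer and rebinds the caller-provided target_location to that storage via set_(). Below is a minimal standalone sketch of the same pattern, assuming PyTorch >= 2.0 (for untyped_storage()); the path, shape, and dtype are placeholders invented for this illustration, not values from the repository.

import torch

# Hypothetical file and layout, for illustration only.
path = "/tmp/demo_kv_page.bin"
shape, dtype = (2, 4, 8), torch.float16

# Write a file with the same raw-bytes layout the save path (next hunk) produces.
src = torch.randn(shape, dtype=dtype)
src.contiguous().view(dtype=torch.uint8).numpy().tofile(path)

# Load without torch.load: wrap the bytes, give them the expected shape,
# and rebind the preallocated destination tensor to that storage.
dst = torch.empty(shape, dtype=dtype)
with open(path, "rb") as f:
    dst.set_(
        torch.frombuffer(f.read(), dtype=dst.dtype)
        .reshape(dst.shape)
        .untyped_storage()
    )

# set_() with only a storage leaves dst as a flat view over the loaded bytes.
assert torch.equal(dst.reshape(shape), src)

One caveat worth noting: set_() rebinds the destination to the newly created storage rather than copying into its existing buffer, which avoids an extra copy but replaces whatever memory target_location previously pointed at.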
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.save(value, tensor_path)
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
            return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
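On the save side, torch.save is replaced by a raw-bytes dump through numpy's tofile(), so the file holds exactly numel * element_size bytes with no pickle framing, which is what the frombuffer-based reader above expects. A hedged round-trip sketch, again with a made-up path and tensor:

import os

import numpy as np
import torch

# Illustrative values only; not taken from the repository.
path = "/tmp/demo_hicache_page.bin"
value = torch.randn(4, 8, dtype=torch.float16)

# Dump raw bytes: reinterpret the contiguous tensor as uint8 and write it out.
value.contiguous().view(dtype=torch.uint8).numpy().tofile(path)
assert os.path.getsize(path) == value.numel() * value.element_size()

# Any reader that knows the dtype and shape can reinterpret the bytes.
raw = torch.from_numpy(np.fromfile(path, dtype=np.uint8))
restored = raw.view(dtype=value.dtype).reshape(value.shape)
assert torch.equal(restored, value)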