Unverified Commit b0add2da authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

HiCache storage, style change and bug fix (#8719)

parent 0305c505
......@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
It abstracts the underlying storage mechanism, allowing different implementations to be used.
"""
# todo, translate tensor object access for different TP ranks
# potentially pass model and TP configs into storage backend
# todo, potentially pass model and TP configs into storage backend
# todo, the page size of storage backend does not have to be the same as host memory pool
@abstractmethod
......@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_location: torch.Tensor,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
key = self._get_suffixed_key(key)
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
if target_location is not None:
# Load directly into target_location's memory buffer
with open(tensor_path, "rb") as f:
target_location.set_(
torch.frombuffer(f.read(), dtype=target_location.dtype)
.reshape(target_location.shape)
.storage()
)
return target_location
else:
loaded_tensor = torch.load(tensor_path)
if isinstance(loaded_tensor, torch.Tensor):
return loaded_tensor
else:
logger.error(f"Loaded data for key {key} is not a tensor.")
return None
# Load directly into target_location's memory buffer
with open(tensor_path, "rb") as f:
target_location.set_(
torch.frombuffer(f.read(), dtype=target_location.dtype)
.reshape(target_location.shape)
.untyped_storage()
)
return target_location
except FileNotFoundError:
logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
return None
def batch_get(
self,
keys: List[str],
target_locations: Optional[Any] = None,
target_locations: List[torch.Tensor],
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
return [
......@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
logger.debug(f"Key {key} already exists. Skipped.")
return True
try:
torch.save(value, tensor_path)
value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
return True
except Exception as e:
logger.error(f"Failed to save tensor {key}: {e}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment