Unverified Commit 073618d8 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] add sha1_store get function (#1027)


Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent 68af57d8
...@@ -182,7 +182,15 @@ class SHA1_Store: ...@@ -182,7 +182,15 @@ class SHA1_Store:
(Tensor or OrderedDict): (Tensor or OrderedDict):
In-memory object. In-memory object.
""" """
raise NotImplementedError() path = self._sha1_to_dir(sha1).joinpath(sha1)
if not path.exists():
# This is potentially valid case for the caller, we need to inform the
# the caller about it.
raise ValueError(f"Try to get SHA1 {sha1} but it is not found")
# Directly return the object after loading it. This could be throw an
# exception but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
return torch.load(path)
def delete(self, sha1: str) -> None: def delete(self, sha1: str) -> None:
"""Delete a SHA1 """Delete a SHA1
......
...@@ -11,6 +11,7 @@ import pytest ...@@ -11,6 +11,7 @@ import pytest
import torch import torch
from torch import nn from torch import nn
from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.sha1_store import SHA1_Store from fairscale.experimental.wgit.sha1_store import SHA1_Store
from fairscale.internal import torch_version from fairscale.internal import torch_version
...@@ -111,3 +112,30 @@ def test_sha1_add_tensor(sha1_store): ...@@ -111,3 +112,30 @@ def test_sha1_add_tensor(sha1_store):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict. # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8" key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
assert key in json_dict.keys() and json_dict[key] == 1, json_dict assert key in json_dict.keys() and json_dict[key] == 1, json_dict
def test_sha1_get(sha1_store):
os.chdir(PARENT_DIR)
# Add a file, a state dict and a tensor.
file = "test_get.pt"
torch.save(nn.Linear(100, 100).state_dict(), file)
state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
tensor = torch.ones(20, 30)
# Check that we can get them back.
file_sha1 = sha1_store.add(file)
sd = sha1_store.get(file_sha1)
assert objects_are_equal(sd, torch.load(file))
sd_sha1 = sha1_store.add(state_dict)
sd = sha1_store.get(sd_sha1)
assert objects_are_equal(sd, state_dict)
tensor_sha1 = sha1_store.add(tensor)
tensor_got = sha1_store.get(tensor_sha1)
assert objects_are_equal(tensor_got, tensor)
# Make sure invalid sha1 cause exceptions.
with pytest.raises(ValueError):
sha1_store.get(tensor_sha1[:-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment