Unverified Commit c75d1896 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[feat] add sha1_store delete function (#1028)


Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent 073618d8
...@@ -23,6 +23,20 @@ from .utils import ExitCode ...@@ -23,6 +23,20 @@ from .utils import ExitCode
# for backward compatibility reasons. # for backward compatibility reasons.
SHA1_STORE_DIR_NAME = "sha1_store" SHA1_STORE_DIR_NAME = "sha1_store"
# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY = "ref_count"
def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
"""Get a dict from a json entry.
This fills in any missing entries in case we load an older version
json file from the disk.
"""
if RF_KEY not in d.keys():
d[RF_KEY] = 0
return d
class SHA1_Store: class SHA1_Store:
""" """
...@@ -181,6 +195,9 @@ class SHA1_Store: ...@@ -181,6 +195,9 @@ class SHA1_Store:
Returns: Returns:
(Tensor or OrderedDict): (Tensor or OrderedDict):
In-memory object. In-memory object.
Throws:
ValueError if sha1 is not found.
""" """
path = self._sha1_to_dir(sha1).joinpath(sha1) path = self._sha1_to_dir(sha1).joinpath(sha1)
if not path.exists(): if not path.exists():
...@@ -190,6 +207,9 @@ class SHA1_Store: ...@@ -190,6 +207,9 @@ class SHA1_Store:
# Directly return the object after loading it. This could throw an
# exception, but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
#
# TODO (Min): we could also keep a stat in the metadata on how many
# times the object is read. Will add if that's needed.
#
# TODO (Min): we could also keep a stats in the meta data on how many
# times the object is read. Will add if that's needed.
return torch.load(path) return torch.load(path)
def delete(self, sha1: str) -> None:
    """Delete an object from the store, with reference counting.

    The object's ref count is decremented; the underlying file is only
    removed from the store once the count reaches zero.

    Args:
        sha1 (str):
            SHA1 of the object to delete.

    Throws:
        ValueError if sha1 is not found.
    """
    path = self._sha1_to_dir(sha1).joinpath(sha1)
    if not path.exists():
        # This is potentially a valid case for the caller, so we raise a
        # ValueError (not an assert) to inform the caller about it.
        raise ValueError(f"Try to delete SHA1 {sha1} but it is not found")

    self._load_json_dict()
    assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
    entry = _get_json_entry(self._json_dict[sha1])
    assert entry[RF_KEY] > 0, f"ref count {entry[RF_KEY]} should be positive"

    # Decrement the ref count; once it reaches zero, remove the object file.
    entry[RF_KEY] -= 1
    if entry[RF_KEY] == 0:
        path.unlink()  # We may leave behind an empty dir, which is OK.
        entry = {}  # Below, we remove the entry because of this.

    # Put the entry back and store it, or delete it entirely.
    if entry:
        self._json_dict[sha1] = entry
    else:
        # Empty entry means this sha1 is deleted.
        del self._json_dict[sha1]
    self._store_json_dict()
def _get_sha1_hash(self, file_path: Union[str, Path]) -> str: def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""Return the sha1 hash of a file """Return the sha1 hash of a file
...@@ -257,12 +303,16 @@ class SHA1_Store: ...@@ -257,12 +303,16 @@ class SHA1_Store:
# Init the entry if needed. # Init the entry if needed.
if current_sha1_hash not in self._json_dict: if current_sha1_hash not in self._json_dict:
self._json_dict[current_sha1_hash] = 0 entry = {}
else:
entry = self._json_dict[current_sha1_hash]
entry = _get_json_entry(entry)
# Update the ref count. # Update the ref count.
self._json_dict[current_sha1_hash] += 1 if inc else -1 entry[RF_KEY] += 1 if inc else -1
assert self._json_dict[current_sha1_hash] >= 0, "negative ref count" assert entry[RF_KEY] >= 0, "negative ref count"
self._json_dict[current_sha1_hash] = entry
self._store_json_dict() self._store_json_dict()
return self._json_dict[current_sha1_hash] return entry[RF_KEY]
...@@ -81,9 +81,9 @@ def test_sha1_add_file(sha1_store): ...@@ -81,9 +81,9 @@ def test_sha1_add_file(sha1_store):
if torch_version() >= (1, 9, 0): if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict. # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "da3e19590de8f77fcf7a09c888c526b0149863a0" key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
assert key in json_dict.keys() and json_dict[key] == 2, json_dict assert key in json_dict.keys() and json_dict[key]["ref_count"] == 2, json_dict
del json_dict["created_on"] del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 1, 1, 2], json_dict assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict
def test_sha1_add_state_dict(sha1_store): def test_sha1_add_state_dict(sha1_store):
...@@ -100,7 +100,7 @@ def test_sha1_add_state_dict(sha1_store): ...@@ -100,7 +100,7 @@ def test_sha1_add_state_dict(sha1_store):
sha1_store._load_json_dict() sha1_store._load_json_dict()
json_dict = sha1_store._json_dict json_dict = sha1_store._json_dict
del json_dict["created_on"] del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 2, 2, 2], json_dict assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict
def test_sha1_add_tensor(sha1_store): def test_sha1_add_tensor(sha1_store):
...@@ -111,10 +111,11 @@ def test_sha1_add_tensor(sha1_store): ...@@ -111,10 +111,11 @@ def test_sha1_add_tensor(sha1_store):
if torch_version() >= (1, 9, 0): if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict. # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8" key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
assert key in json_dict.keys() and json_dict[key] == 1, json_dict assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict
def test_sha1_get(sha1_store): def test_sha1_get(sha1_store):
"""Testing the get() API: normal and exception cases."""
os.chdir(PARENT_DIR) os.chdir(PARENT_DIR)
# Add a file, a state dict and a tensor. # Add a file, a state dict and a tensor.
...@@ -139,3 +140,26 @@ def test_sha1_get(sha1_store): ...@@ -139,3 +140,26 @@ def test_sha1_get(sha1_store):
# Make sure invalid sha1 cause exceptions. # Make sure invalid sha1 cause exceptions.
with pytest.raises(ValueError): with pytest.raises(ValueError):
sha1_store.get(tensor_sha1[:-1]) sha1_store.get(tensor_sha1[:-1])
def test_sha1_delete(sha1_store):
    """Testing the delete() API: with ref counting behavior."""
    os.chdir(PARENT_DIR)

    # A single add followed by a delete; deleting again must raise.
    sha1 = sha1_store.add(torch.ones(30, 50))
    sha1_store.delete(sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(sha1)

    # Adding the same object N times requires N matching deletes.
    state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
    sha1 = sha1_store.add(state_dict)
    for _ in range(3):
        new_sha1 = sha1_store.add(state_dict)
        assert sha1 == new_sha1, f"{sha1} vs. {new_sha1}"
    for _ in range(4):
        sha1_store.delete(sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(sha1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment