Unverified Commit 2e544bd7 authored by Min Xu, committed by GitHub

[feat]: add size and names metadata to sha1 store (#1036)



* additional metadata, step 1

* add gzip option to repo::add

* add repo:add's return value and some refactoring and todo

* added size metadata to sha1_store

* added names metadata to sha1_store
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 4d58a294
......@@ -21,6 +21,40 @@ from .sha1_store import SHA1_Store
SHA1_STORE_DIR_NAME = "sha1_store"
# These are on-disk keys. Don't modify for backward compatibility.
SHA1_KEY = "SHA1"
LAST_MODIFIED_TS_KEY = "last_modified_time_stamp"
REL_PATH_KEY = "file_path" # this will be removed from the json since it is redundant.
class RepoStatus(Enum):
"""Collections of Repo Statuses"""
CLEAN = 1
CHANGES_NOT_ADDED = 2
CHANGES_ADDED_NOT_COMMITED = 3
@dataclass
class SizeInfo:
"""Size info for a file or the repo in bytes.
Dedup can't be disabled, so the deduped size is always present.
Both sparsified and gzipped are optional. They are applied in the following
order if both are enabled:
sparsify -> gzip
Therefore, original >= deduped >= sparsified >= gzipped
"""
original: int
deduped: int
sparsified: int
gzipped: int
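# Illustrative sketch (not part of this diff): the ordering invariant from the
# docstring above can be checked like this. The helper name and byte counts are
# made up for the example.
def _check_size_info(info: SizeInfo) -> None:
    """Assert the documented ordering: original >= deduped >= sparsified >= gzipped."""
    assert info.original >= info.deduped >= info.sparsified >= info.gzipped

_check_size_info(SizeInfo(original=1000, deduped=800, sparsified=500, gzipped=200))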
@dataclass
class _SHA1_Tensor:
"""Representing a tensor using sha1(s) from SHA1 store.
......@@ -155,61 +189,130 @@ class Repo:
sys.stderr.write("fatal: no wgit repo exists!\n")
sys.exit(1)
def add(self, in_file_path: str, per_tensor: bool = False) -> None:
def add(
self,
in_file_path: str,
per_tensor: bool = False,
gzip: bool = True,
sparsify: bool = False,
sparsify_policy: Any = None,
) -> Optional[Dict[Any, Any]]:
"""Add a file to the wgit repo.
This could be a new file or a modified file. Adding an unmodified, existing file
is allowed but is a no-op.
Args:
in_file_path (str):
Path to the file to be added.
per_tensor (bool):
Add a file in a per-tensor fashion.
per_tensor (bool, optional):
Add a file in a per-tensor fashion. This enables more deduplication
when tensors are identical. Deduplication cannot be disabled
completely because we use a content-addressable SHA1_Store class.
Default: False
gzip (bool, optional):
Enable gzip based lossless compression on the object being added.
Default: True
sparsify (bool, optional):
Enable sparsification of the tensors, which modifies the values
of some or all tensors, i.e. lossy compression.
Default: False
sparsify_policy (Any):
TODO (Min): need to add a callback function to control which tensors
and how to sparsify.
Default: None
Returns:
(Dict, optional)
None if the content was added without lossy modification.
Otherwise, returns a state_dict that contains the modified tensors to
be loaded back into the model; these tensors are dense, not SST or
DST tensors.
"""
self._sanity_check()
# create the corresponding metadata file
if sparsify and not per_tensor:
raise ValueError("Only support sparsity when per_tensor is true")
# Create the corresponding metadata file or load it if the file is
# not a newly added file.
file_path = Path(in_file_path)
rel_file_path = self._rel_file_path(file_path)
metadata_file = self._process_metadata_file(rel_file_path)
# add the file to the sha1_store
# Add the file to the sha1_store.
ret_state_dict = None
file_path_or_state_dict: Union[Path, Dict] = file_path
# TODO (Min): We don't add parent sha1 tracking to the sha1 store because
# de-duplication & dependency tracking can create cycles.
# We need to figure out a way to handle deletion.
sha1_dict = {}
# TODO (Min): We don't detect changes and compute delta on a modified file
# yet. Need to figure out a method for delta tracking.
if per_tensor:
def fn(element: Any) -> Any:
"""Callback on each leaf object for _recursive_apply_to_elements below."""
if isinstance(element, Tensor):
if sparsify:
# TODO (Min): here we will optionally do SST/DST and add those
# tensors with sparsity.
sha1 = self._sha1_store.add(element, compress=True)
# Remember to update ret_state_dict
raise NotImplementedError()
sha1 = self._sha1_store.add(element, compress=gzip)
return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
else:
return element
state_dict = torch.load(file_path)
_recursive_apply_to_elements(state_dict, fn)
sha1_dict = {"__sha1_full__": self._sha1_store.add(state_dict)}
else:
sha1_dict = {"__sha1_full__": self._sha1_store.add(file_path)}
file_path_or_state_dict = state_dict
# Add this top-level object.
sha1 = self._sha1_store.add(file_path_or_state_dict, compress=gzip)
# write metadata to the metadata-file
self._write_metadata(metadata_file, file_path, sha1_dict)
self._write_metadata(metadata_file, file_path, sha1)
self._pygit.add() # add to the .wgit/.git repo
return ret_state_dict
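# Illustrative usage (not part of this diff), assuming `repo` is an already
# initialized Repo object as in this commit's tests: a lossless add returns
# None, and only a lossy (sparsified) add would return a state_dict to load back.
#
#     ret = repo.add("checkpoint_0.pt", per_tensor=True, gzip=True)
#     assert ret is None  # lossless: nothing needs to be loaded back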
def commit(self, message: str) -> None:
"""Commits staged changes to the repo.
Args:
message (str):
The commit message
The commit message to be added.
"""
self._sanity_check()
# TODO (Min): make commit message a json for better handling of metadata like step count,
# LR, sparsity level, etc.
self._pygit.commit(message)
def status(self) -> Dict:
def size_info(self, path: Optional[str] = None) -> SizeInfo:
"""Get size info for a file or the whole repo.
For the whole repo, just call size_info on the sha1_store.
For a file, we need to open the metadata, find the sha1 and, for a
per_tensor state_dict, collect size_info on all of its objects.
TODO (Min): with delta encoding and deduplication between objects, it
is not clear this can be computed precisely.
Args:
path (str, optional):
File path for the query. If None, return whole repo's info.
Default: None
Returns:
(SizeInfo):
The dataclass that contains the size info.
"""
raise NotImplementedError()
def status(self) -> Dict[str, RepoStatus]:
"""Show the state of the weigit working tree.
State can be
......@@ -218,6 +321,8 @@ class Repo:
3. clean and tracking files after a change has been committed,
or clean with an empty repo.
TODO (Min): this needs to return repo status and dirty files and untracked
files too.
Returns:
(dict):
A dict keyed with files and their status.
......@@ -250,6 +355,8 @@ class Repo:
"""
self._sanity_check()
# TODO (Min): this should return a list of sha1 for the history as well as
# each commit's message, which could be a dict from json commit msg.
if file:
print(f"wgit log of the file: {file}")
else:
......@@ -263,25 +370,22 @@ class Repo:
The sha1 hash of the file version to checkout.
"""
self._sanity_check()
raise NotImplementedError
def compression(self) -> None:
"""Not Implemented: Compression functionalities"""
self._sanity_check()
raise NotImplementedError
raise NotImplementedError()
def checkout_by_steps(self) -> None:
"""Not Implemented: Checkout by steps"""
"""Not Implemented: Checkout by step count of the train process"""
self._sanity_check()
raise NotImplementedError
raise NotImplementedError()
def _get_metdata_files(self) -> Dict:
def _get_metdata_files(self) -> Dict[str, bool]:
"""Walk the directories that contain the metadata files and check the
status of those files, whether they have been modified or not.
Dict[str, bool] is a path in string and whether the file is_modified.
"""
metadata_d = dict()
for file in self._dot_wgit_dir_path.iterdir(): # iterate over the .wgit directory
# exlude all the .wgit files and directory
# exclude all the .wgit files and directory
if file.name not in {"sha1_store", ".git", ".gitignore"}:
# perform a directory walk on the metadata_file directories to find the metadata files
for path in file.rglob("*"):
......@@ -297,11 +401,7 @@ class Repo:
try:
with open(file) as f:
metadata = json.load(f)
is_metadata = set(metadata.keys()) == {
"SHA1",
"file_path",
"last_modified_time_stamp",
} # TODO: Consider storing the keys as a class attribute, instead of hard coding.
is_metadata = set(metadata.keys()) == {SHA1_KEY, LAST_MODIFIED_TS_KEY, REL_PATH_KEY}
except json.JSONDecodeError:
return False # not a json file, so not valid metadata file
return is_metadata
......@@ -315,8 +415,8 @@ class Repo:
# Get the last modified timestamp recorded by weigit and the current modified
# timestamp. If not the same, then file has been modified since last weigit
# updated metadata.
last_mod_timestamp = data["last_modified_time_stamp"]
curr_mod_timestamp = Path(data["file_path"]).stat().st_mtime
last_mod_timestamp = data[LAST_MODIFIED_TS_KEY]
curr_mod_timestamp = Path(data[REL_PATH_KEY]).stat().st_mtime
return not curr_mod_timestamp == last_mod_timestamp
def _process_metadata_file(self, metadata_fname: Path) -> Path:
......@@ -334,13 +434,13 @@ class Repo:
ref_data = json.load(f)
return metadata_file
def _write_metadata(self, metadata_file: Path, file_path: Path, sha1_dict: Dict) -> None:
def _write_metadata(self, metadata_file: Path, file_path: Path, sha1: str) -> None:
"""Write metadata to the metadata file"""
change_time = Path(file_path).stat().st_mtime
metadata = {
"SHA1": sha1_dict,
"file_path": str(file_path),
"last_modified_time_stamp": change_time,
SHA1_KEY: sha1,
LAST_MODIFIED_TS_KEY: change_time,
REL_PATH_KEY: str(file_path),
}
with open(metadata_file, "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)
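# For illustration (not part of this diff), a metadata file written above looks
# roughly like the following, with a hypothetical sha1, timestamp, and path:
#
#     {
#         "SHA1": "da3e19590de8f77fcf7a09c888c526b0149863a0",
#         "last_modified_time_stamp": 1659459200.0,
#         "file_path": "checkpoint_0.pt"
#     }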
......@@ -356,11 +456,3 @@ class Repo:
pass
# return the relative part (path not common to cwd)
return Path(*filepath.parts[i:])
class RepoStatus(Enum):
"""Collections of Repo Statuses"""
CLEAN = 1
CHANGES_NOT_ADDED = 2
CHANGES_ADDED_NOT_COMMITED = 3
......@@ -11,7 +11,7 @@ import shutil
import sys
import tempfile
import time
from typing import Any, Dict, Optional, Union, cast
from typing import Any, Dict, Optional, Tuple, Union, cast
import pgzip
import torch
......@@ -19,9 +19,23 @@ from torch import Tensor
from .utils import ExitCode
#
# Const string keys for the json file. Do not change, for backward compatibility.
RF_KEY = "ref_count"
COMP_KEY = "compressed"
#
# For each object entry in the metadata json file.
ENTRY_RF_KEY = "ref_count" # int, reference count for this object.
ENTRY_COMP_KEY = "compressed" # bool, is compressed or not.
ENTRY_OS_KEY = "original_size" # int, original size for all identical objects mapped to this object.
ENTRY_DS_KEY = "deduped_size" # int, size after deduplication (always enabled).
ENTRY_CS_KEY = "compressed_size" # int, size after gzip compression, if enabled.
ENTRY_NAMES_KEY = "names" # dict, names of objects and their count mapped to this object.
# For the entire store in the metadata json file.
STORE_CREATE_DATE_KEY = "created_on" # str, when the store was created.
STORE_OS_KEY = "original_size" # int, original size for all objects added.
STORE_DS_KEY = "deduped_size" # int, size after deduplication (always enabled).
STORE_CS_KEY = "compressed_size" # int, size after gzip compression, if enabled on any object within the store.
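# For illustration (not part of this diff), the store-level metadata json is
# expected to look roughly like this, using the keys above; the sizes, counts,
# and names are made up:
#
#     {
#         "created_on": "Mon Aug  1 10:00:00 2022",
#         "original_size": 1200000,
#         "deduped_size": 600000,
#         "compressed_size": 230000,
#         "da3e19590de8f77fcf7a09c888c526b0149863a0": {
#             "ref_count": 2,
#             "compressed": true,
#             "original_size": 1200000,
#             "deduped_size": 600000,
#             "compressed_size": 230000,
#             "names": {"checkpoint_0.pt": 2}
#         }
#     }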
def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
......@@ -30,13 +44,28 @@ def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
This fills in any missing entries in case we load an older version
json file from the disk.
"""
if RF_KEY not in d.keys():
d[RF_KEY] = 0
for int_key_init_zero in [ENTRY_RF_KEY, ENTRY_OS_KEY, ENTRY_DS_KEY, ENTRY_CS_KEY]:
if int_key_init_zero not in d.keys():
d[int_key_init_zero] = 0
for bool_key_init_false in [ENTRY_COMP_KEY]:
if bool_key_init_false not in d.keys():
d[bool_key_init_false] = False
for dict_key_init_empty in [ENTRY_NAMES_KEY]:
if dict_key_init_empty not in d.keys():
d[dict_key_init_empty] = {}
return d
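# Illustrative check (not part of this diff): an entry loaded from an older
# json that only carries a ref_count gets neutral defaults for the new fields.
assert _get_json_entry({"ref_count": 1}) == {
    "ref_count": 1,
    "original_size": 0,
    "deduped_size": 0,
    "compressed_size": 0,
    "compressed": False,
    "names": {},
}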
def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
"""Helper to copy a file and compress it at the same time."""
def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> Tuple[int, int]:
"""Helper to copy a file and compress it at the same time.
Returns:
(int, int):
original size and compressed size in bytes.
"""
with open(str(src), "rb") as srcf:
with pgzip.open(str(dest), "wb", compresslevel=5, thread=thread, blocksize=blocksize) as destf:
while True:
......@@ -44,6 +73,9 @@ def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: in
if len(buf) == 0:
break
destf.write(buf)
orig, comp = Path(src).stat().st_size, Path(dest).stat().st_size
assert orig >= comp, f"Compressed size {comp} > original {orig}"
return orig, comp
def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
......@@ -118,7 +150,12 @@ class SHA1_Store:
self._sha1_buf_size = sha1_buf_size
self._pgzip_threads = pgzip_threads
self._pgzip_block_size = pgzip_block_size
self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}
self._json_dict: Dict[str, Any] = {
STORE_CREATE_DATE_KEY: time.ctime(),
STORE_OS_KEY: 0,
STORE_DS_KEY: 0,
STORE_CS_KEY: 0,
}
# Initialize the store if not exist and if init is True.
if init and not self._path.exists():
......@@ -137,7 +174,7 @@ class SHA1_Store:
# Make sure there is a valid metadata file.
self._load_json_dict()
assert "created_on" in self._json_dict, f"Invalid SHA1 Store in {self._path}"
assert STORE_CREATE_DATE_KEY in self._json_dict, f"Invalid SHA1 Store in {self._path}"
# Init temp dir.
if tmp_dir:
......@@ -160,7 +197,7 @@ class SHA1_Store:
with open(self._metadata_file_path, "w", encoding="utf-8") as f:
json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = False) -> str:
def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, name: Optional[str] = None) -> str:
"""Adds a file/object to this store and the sha1 references accordingly.
First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
......@@ -177,6 +214,12 @@ class SHA1_Store:
you call `state_dict()` on a nn.Module, and it is an instance
of a Dict too. A model's state_dict can be a simple dict because
it may contain both model state_dict and other non-tensor info.
compress (bool, optional):
Use gzip compression on this object or not.
Default: True
name (str, optional):
Optional name for this object.
Default: None
"""
# Use `isinstance` not type() == Path since pathlib returns OS specific
# Path types, which inherit from the Path class.
......@@ -200,9 +243,7 @@ class SHA1_Store:
# Add reference.
ref_count = self._add_ref(sha1_hash, True, compress)
if ref_count == 1:
# First time adding
if ref_count == 1: # First time adding.
# Create the file dir, if needed.
repo_fdir = self._sha1_to_dir(sha1_hash)
if not repo_fdir.exists():
......@@ -216,15 +257,41 @@ class SHA1_Store:
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
if compress:
_copy_compressed(file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size)
orig_size, comp_size = _copy_compressed(
file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size
)
else:
shutil.copy2(file_path, repo_fpath)
orig_size = comp_size = file_path.stat().st_size
except BaseException as error:
# Something went wrong, perhaps out of space, or race condition due to lack of locking.
# TODO (Min): properly handle the error and recover when we learn more here.
sys.stderr.write(f"An exception occurred: {repr(error)}\n")
ref_count = self._add_ref(sha1_hash, False, compress)
self._load_json_dict()
# Update the sizes for this entry.
entry = _get_json_entry(self._json_dict[sha1_hash])
o_diff = orig_size if ref_count == 1 else entry[ENTRY_OS_KEY]
d_diff = orig_size if ref_count == 1 else 0
c_diff = comp_size if ref_count == 1 else 0
entry[ENTRY_OS_KEY] += o_diff
entry[ENTRY_DS_KEY] += d_diff
entry[ENTRY_CS_KEY] += c_diff
self._json_dict[STORE_OS_KEY] += o_diff
self._json_dict[STORE_DS_KEY] += d_diff
self._json_dict[STORE_CS_KEY] += c_diff
# Update the name list for this entry.
if name:
if name not in entry[ENTRY_NAMES_KEY].keys():
entry[ENTRY_NAMES_KEY][name] = 1
else:
entry[ENTRY_NAMES_KEY][name] += 1
self._store_json_dict()
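# Worked example (not part of this diff): add a 100-byte object that gzips to
# 40 bytes. First add (ref_count == 1): entry and store get original_size=100,
# deduped_size=100, compressed_size=40. Adding the identical object again
# (ref_count == 2): o_diff reuses the entry's original_size (100) while d_diff
# and c_diff are 0, so original_size grows to 200 and the deduped/compressed
# sizes stay at 100/40, which matches the size_info test added in this commit.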
# Clean up if needed.
if remove_tmp:
file_path.unlink()
......@@ -257,7 +324,7 @@ class SHA1_Store:
# TODO (Min): we could also keep stats in the metadata on how many
# times the object is read. Will add if that's needed.
self._load_json_dict()
if self._json_dict[sha1][COMP_KEY]:
if self._json_dict[sha1][ENTRY_COMP_KEY]:
# Compressed. Because pgzip doesn't support tell() yet, we need to
# uncompress into a temp file and return it.
tmp = self._get_tmp_file_path()
......@@ -290,9 +357,9 @@ class SHA1_Store:
assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
entry = _get_json_entry(self._json_dict[sha1])
assert entry[RF_KEY] > 0, f"ref count {entry[RF_KEY]} should be positive"
entry[RF_KEY] -= 1
if entry[RF_KEY] == 0:
assert entry[ENTRY_RF_KEY] > 0, f"ref count {entry[ENTRY_RF_KEY]} should be positive"
entry[ENTRY_RF_KEY] -= 1
if entry[ENTRY_RF_KEY] == 0:
# The ref count is now 0, so delete the object.
path.unlink() # We may leave behind an empty dir, which is OK.
entry = {} # Below, we remove the entry because of this.
......@@ -305,6 +372,24 @@ class SHA1_Store:
del self._json_dict[sha1]
self._store_json_dict()
def size_info(self, sha1: Optional[str] = None) -> Tuple[int, int, int]:
"""Return original, deduped, gzipped sizes for an entry or the store."""
self._load_json_dict()
if sha1:
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_OS_KEY], entry[ENTRY_DS_KEY], entry[ENTRY_CS_KEY]
return self._json_dict[STORE_OS_KEY], self._json_dict[STORE_DS_KEY], self._json_dict[STORE_CS_KEY]
def names(self, sha1: str) -> Dict[str, int]:
"""Return the names dict for an object."""
self._load_json_dict()
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_NAMES_KEY]
def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""Return the sha1 hash of a file
......@@ -366,13 +451,13 @@ class SHA1_Store:
entry = _get_json_entry(entry)
# Update the ref count.
entry[RF_KEY] += 1 if inc else -1
assert entry[RF_KEY] >= 0, "negative ref count"
entry[ENTRY_RF_KEY] += 1 if inc else -1
assert entry[ENTRY_RF_KEY] >= 0, "negative ref count"
# Update compressed flag.
entry[COMP_KEY] = compressed
entry[ENTRY_COMP_KEY] = compressed
self._json_dict[current_sha1_hash] = entry
self._store_json_dict()
return entry[RF_KEY]
return entry[ENTRY_RF_KEY]
......@@ -59,10 +59,11 @@ def test_api_init(capsys, repo):
@pytest.mark.parametrize("per_tensor", [True, False])
def test_api_add(capsys, repo, per_tensor):
@pytest.mark.parametrize("gzip", [True, False])
def test_api_add(capsys, repo, per_tensor, gzip):
fnum = random.randint(0, 2)
chkpt0 = f"checkpoint_{fnum}.pt"
repo.add(chkpt0, per_tensor)
repo.add(chkpt0, per_tensor=per_tensor, gzip=gzip)
if per_tensor:
# TODO (Min): test per_tensor add more.
return
......@@ -73,7 +74,7 @@ def test_api_add(capsys, repo, per_tensor):
json_data = json.load(f)
sha1_dir_0 = f"{sha1_hash[:2]}/" + f"{sha1_hash[2:]}"
assert json_data["SHA1"] == {"__sha1_full__": sha1_hash}
assert json_data["SHA1"] == sha1_hash
def test_api_commit(capsys, repo):
......
......@@ -64,7 +64,7 @@ def test_cli_add(create_test_dir, capsys):
json_data = json.load(f)
sha1_dir_0 = f"{sha1_hash[:2]}/" + f"{sha1_hash[2:]}"
assert json_data["SHA1"] == {"__sha1_full__": sha1_hash}
assert json_data["SHA1"] == sha1_hash
def test_cli_commit(capsys):
......
......@@ -19,6 +19,9 @@ from fairscale.internal import torch_version
# so that we can properly clean it up from any CWD.
TESTING_STORE_DIR = Path("sha1_store_testing").resolve()
# Used to filter metadata json keys.
SHA1_KEY_STR_LEN = 40
@pytest.fixture(scope="function")
def sha1_store(request):
......@@ -83,7 +86,7 @@ def test_sha1_add_file(sha1_store, compress):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
assert key in json_dict.keys() and json_dict[key]["ref_count"] == 2, json_dict
del json_dict["created_on"]
json_dict = dict(filter(lambda item: len(item[0]) == SHA1_KEY_STR_LEN, json_dict.items()))
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict
......@@ -101,7 +104,7 @@ def test_sha1_add_state_dict(sha1_store, compress):
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
del json_dict["created_on"]
json_dict = dict(filter(lambda item: len(item[0]) == SHA1_KEY_STR_LEN, json_dict.items()))
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict
......@@ -168,3 +171,32 @@ def test_sha1_delete(sha1_store, compress):
sha1_store.delete(sha1)
with pytest.raises(ValueError):
sha1_store.delete(sha1)
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_size_info_and_names(sha1_store, compress):
"""Testing the size_info() and names() APIs."""
os.chdir(TESTING_STORE_DIR)
# Add once & check.
tensor = torch.ones(300, 500)
sha1 = sha1_store.add(tensor, compress=compress, name="name1")
orig, dedup, gzip = sha1_store.size_info(sha1)
assert orig == dedup, "no dedup should happen"
if not compress:
assert orig == gzip, "no compression should happen"
else:
assert orig > gzip, "compressed size should be smaller"
assert (orig, dedup, gzip) == sha1_store.size_info(), "store and entry sizes should match"
names = sha1_store.names(sha1)
assert names == {"name1": 1}, names
# Add second time & check.
sha1 = sha1_store.add(tensor, compress=compress, name="name2")
orig2, dedup2, gzip2 = sha1_store.size_info(sha1)
assert orig2 == orig * 2 == dedup2 * 2, "dedup not correct"
assert gzip == gzip2, "compression shouldn't change"
names = sha1_store.names(sha1)
assert names == {"name1": 1, "name2": 1}, names