Unverified Commit 4d58a294 authored by Min Xu, committed by GitHub

[feat]: add per-tensor add to repo (#1033)



* formatting change, no logical change

* formatting and name change, no logical change

* [refactor] sha1_store's path arg

- make sha1_store's path arg directly the path, not its parent
- this is because sha1_store is not like a .git or a .wgit dir, which is
  nested inside another "working" dir. It is simply a store that uses
  a given dir.
- updated repo and tests as well.

* remove a test warning due to deprecated API from torch

* [refactor] change how dot_wgit_dir_path is used

- it should only be assigned in __init__.
- we use it in error checking in the rest APIs.

* simplify the init a bit

* refactor the sanity check

* moved some functions, no code change

* [feat] added per-tensor add to the repo

* enabled gzip compression on add

* fix a unit test

* add a note

* make sha1 store work on general dict

* handle general state_dict from a model, not just a module's one-level OrderedDict

* formatting
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent d0ad08c0
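For orientation, here is a minimal usage sketch of the per-tensor add flow this commit enables, based on the Repo API in the diff below. The module path and file names are assumptions for illustration, not part of the commit:

    import torch
    from torch import nn

    from fairscale.experimental.wgit.repo import Repo  # module path assumed

    # A checkpoint is a plain dict wrapping a module state_dict plus extra info,
    # matching the format the updated tests create.
    sd = {"model": nn.Linear(4, 4).state_dict(), "step": 100}
    torch.save(sd, "checkpoint_0.pt")

    repo = Repo(init=True)                        # creates ./.wgit and its sha1_store
    repo.add("checkpoint_0.pt", per_tensor=True)  # each tensor is hashed into the sha1 store
    repo.commit("add checkpoint 0")
    print(repo.status())                          # {metadata file: RepoStatus, ...}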
@@ -503,7 +503,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
         return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
     elif torch.is_tensor(a):
         try:
-            # assert_allclose doesn't strictly test shape, dtype and device
+            # assert_close doesn't strictly test shape, dtype and device
             shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
             if not shape_dtype_device_match:
                 if raise_exception:
@@ -513,8 +513,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                     raise AssertionError(msg)
                 else:
                     return False
-            # assert_allclose.
-            torch.testing.assert_allclose(a, b)
+            # assert_close.
+            if torch_version() < (1, 12, 0):
+                torch.testing.assert_allclose(a, b)
+            else:
+                torch.testing.assert_close(a, b)
             return True
         except (AssertionError, RuntimeError) as e:
             if raise_exception:
@@ -3,199 +3,297 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from dataclasses import dataclass
 from enum import Enum
 import json
-import pathlib
 from pathlib import Path
 import sys
-from typing import Dict, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from torch import Tensor

 from .pygit import PyGit
 from .sha1_store import SHA1_Store

+# This is a fixed dir name we use for sha1_store. It should not be changed
+# for backward compatibility reasons.
+SHA1_STORE_DIR_NAME = "sha1_store"
+
+
+@dataclass
+class _SHA1_Tensor:
+    """Representing a tensor using sha1(s) from SHA1 store.
+
+    It can be either a dense one or 2 sparse one with SST and DST.
+    """
+
+    is_dense: bool = True
+    dense_sha1: str = ""
+    sst_sha1: str = ""
+    dst_sha1: str = ""
+
+
+def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any) -> None:
+    """Helper function to traverse a dict recursively and apply a function to leafs.
+
+    The input `data` is a dict or a list and it should only contain dict and list.
+    """
+    if isinstance(data, list):
+        for i, _ in enumerate(data):
+            if isinstance(data[i], (list, dict)):
+                _recursive_apply_to_elements(data[i], fn)
+            else:
+                data[i] = fn(data[i])
+    elif isinstance(data, dict):
+        for key in data.keys():
+            if isinstance(data[key], (list, dict)):
+                _recursive_apply_to_elements(data[key], fn)
+            else:
+                data[key] = fn(data[key])
+    else:
+        assert False, f"Unexpected data type: {type(data)}"
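A small illustration of the helper above; the toy data and import path are assumptions, not part of the commit. It traverses nested dicts/lists and rewrites the leaves in place, which is how add() later swaps each tensor for a _SHA1_Tensor record:

    from fairscale.experimental.wgit.repo import _recursive_apply_to_elements  # module path assumed

    # Toy data: every leaf of the nested dict/list structure is rewritten in place.
    data = {"model": {"weight": 1, "bias": 2}, "steps": [10, 20]}
    _recursive_apply_to_elements(data, lambda x: x * 2)
    assert data == {"model": {"weight": 2, "bias": 4}, "steps": [20, 40]}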
 class Repo:
     """
     Represents the WeiGit repo for tracking neural network weights and their versions.

+    A WeiGit repo is like a git repo. It is a dir, in which a .wgit dir exists to keep
+    track of the content.
+
     Args:
-        parent_dir (pathlib.Path, str)
-            The path to the parent directory where a weigit repo will be created. In the case a repo already exists, it will be wrapped with this class.
-        init (bool, optional)
-            - If ``True``, initializes a new WeiGit repo in the parent_dir. Initialization creates a `.wgit` directory within the <parent_dir>, triggers an initialization
-              of a sha1_store in the ./<parent_dir>/.wgit directory, and makes the ./<parent_dir>/.wgit a git repository through git initialization.
-            - If ``False``, a new WeiGit repo is not initialized and the existing repo is simply wrapped, populating the `wgit_parent` and other internal attributes.
+        parent_dir (Path, str):
+            Parent dir in which to make or to load a .wgit dir.
+            Default: "", which means CWD.
+        init (bool, optional):
+            - If ``True``, initializes a new WeiGit repo in the parent_dir. Initialization
+              creates a `.wgit` directory within the <parent_dir>, triggers an initialization
+              of a sha1_store in the ./<parent_dir>/.wgit directory, and makes the
+              ./<parent_dir>/.wgit a git repository through git initialization.
+            - If ``False``, a new WeiGit repo is not initialized and the existing repo is
+              wrapped, populating the `_wgit_parent` and other internal attributes.
             - Default: False
     """

-    def __init__(self, parent_dir: Union[Path, str] = Path.cwd(), init: bool = False) -> None:
-        """initialize a weigit repo: Subsequently, also initialize a sha1_store and a pygit git repo within as
-        part of the weigit initialization process"""
-        # If repo does not exist, creates a new wgit repo object with self.repo.path pointing to the path of repo
-        # and notes all the internal files.
-        # else, if repo already exists: create a pygit object from the .wgit/.git.
-        self.wgit_parent = Path(parent_dir)
-        self._repo_path: Union[None, Path] = None
-        self._wgit_git_path = Path(".wgit/.git")
-        self._sha1_store_path = Path(".wgit/sha1_store")
-        exists = self._exists(self.wgit_parent)
+    def __init__(self, parent_dir: Union[Path, str] = "", init: bool = False) -> None:
+        # Set _wgit_parent.
+        self._wgit_parent = Path(parent_dir if parent_dir != "" else Path.cwd())
+
+        # Set _dot_wgit_dir_path.
+        self._dot_wgit_dir_path: Optional[Path] = None
+        exists = self._recursive_search_and_may_init_dot_wgit_dir_path(self._wgit_parent)
+
         if not exists and init:
             # No weigit repo exists and is being initialized with init=True
             # Make .wgit directory, create sha1_store
-            weigit_dir = self.wgit_parent.joinpath(".wgit")
-            weigit_dir.mkdir(parents=False, exist_ok=True)
+            self._dot_wgit_dir_path = self._wgit_parent.joinpath(".wgit")
+            self._dot_wgit_dir_path.mkdir(parents=False, exist_ok=True)

             # Initializing sha1_store only after wgit has been initialized!
-            self._sha1_store = SHA1_Store(weigit_dir, init=True)
+            self._sha1_store = SHA1_Store(self._dot_wgit_dir_path.joinpath(SHA1_STORE_DIR_NAME), init=True)

-            # # Make the .wgit a git repo
-            gitignore_files = [
-                self._sha1_store_path.name,
-            ]
-            self._pygit = PyGit(weigit_dir, gitignore=gitignore_files)
+            # Create a git repo for the metadata versioning.
+            self._pygit = PyGit(self._dot_wgit_dir_path, gitignore=[SHA1_STORE_DIR_NAME])

-        elif exists and init:
-            # if weigit repo already exists and init is being called, wrap the existing .wgit/.git repo with PyGit
-            self._sha1_store = SHA1_Store(self.path)
-            self._pygit = PyGit(self.path)
-
-        elif exists and not init:
-            # weigit exists and non-init commands are triggered
-            self._sha1_store = SHA1_Store(self.path)
-            self._pygit = PyGit(self.path)
+        elif exists:
+            # Weigit repo already exists, populate this object.
+            assert self._dot_wgit_dir_path is not None
+            self._sha1_store = SHA1_Store(self._dot_wgit_dir_path.joinpath(SHA1_STORE_DIR_NAME))
+            self._pygit = PyGit(self._dot_wgit_dir_path)

         else:
-            # weigit doesn't exist and is not trying to be initialized (triggers during non-init commands)
+            # weigit doesn't exist and is not trying to be initialized (triggers
+            # during non-init commands)
             sys.stderr.write("fatal: not a wgit repository!\n")
             sys.exit(1)
-    def add(self, in_file_path: str) -> None:
-        """
-        Adds a file to the wgit repo.
-
-        Args:
-            file_path (str)
-                Path to the file to be added to the weigit repo
-        """
-        if self._exists(self.wgit_parent):
-            # create the corresponding metadata file
-            file_path = Path(in_file_path)
-            rel_file_path = self._rel_file_path(file_path)
-            metadata_file, parent_sha1 = self._process_metadata_file(rel_file_path)
-
-            # add the file to the sha1_store
-            # TODO (Min): We don't add parent sha1 tracking to sha1 store due to
-            #             de-duplication & dependency tracking can create cycles.
-            #             We need to figure out a way to handle deletion.
-            sha1_hash = self._sha1_store.add(file_path)
-
-            # write metadata to the metadata-file
-            self._write_metadata(metadata_file, file_path, sha1_hash)
-            self._pygit.add()  # add to the .wgit/.git repo
-        else:
+        # We are done init. Do a check.
+        self._sanity_check()
+
+    def _recursive_search_and_may_init_dot_wgit_dir_path(self, check_dir: Path) -> bool:
+        """Search for a wgit repo top level dir from potentiall a subdir of a repo.
+
+        This may set the self._dot_wgit_dir_path if a repo is found.
+
+        Args:
+            check_dir (Path):
+                Path to the directory from where search is started.
+
+        Returns:
+            Returns True if a repo is found.
+        """
+        assert self._dot_wgit_dir_path is None, f"_dot_wgit_dir_path is already set to {self._dot_wgit_dir_path}"
+        if self._weigit_repo_exists(check_dir):
+            self._dot_wgit_dir_path = check_dir.joinpath(".wgit")
+        else:
+            root = Path(check_dir.parts[0])
+            while check_dir != root:
+                check_dir = check_dir.parent
+                if self._weigit_repo_exists(check_dir):
+                    self._dot_wgit_dir_path = check_dir.joinpath(".wgit")
+                    break
+        return True if self._dot_wgit_dir_path is not None else False
+
+    def _weigit_repo_exists(self, check_dir: Path) -> bool:
+        """Returns True if a valid WeiGit repo exists in the path: check_dir."""
+        wgit_exists, git_exists, gitignore_exists = self._weigit_repo_file_check(check_dir)
+        return wgit_exists and git_exists and gitignore_exists
+
+    def _weigit_repo_file_check(self, check_dir: Path) -> tuple:
+        """Returns a tuple of boolean corresponding to the existence of each
+        .wgit internally required files.
+        """
+        wgit_exists = check_dir.joinpath(".wgit").exists()
+        git_exists = check_dir.joinpath(".wgit/.git").exists()
+        gitignore_exists = check_dir.joinpath(".wgit/.gitignore").exists()
+        return wgit_exists, git_exists, gitignore_exists
+
+    def _sanity_check(self) -> None:
+        """Helper to check if on-disk state matches what we expect."""
+        if not self._weigit_repo_exists(self._wgit_parent):
             sys.stderr.write("fatal: no wgit repo exists!\n")
             sys.exit(1)
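For reference, a sketch of what the repo-existence check above looks for; the directory path is hypothetical:

    from pathlib import Path

    # Illustration only: the three entries _weigit_repo_file_check() requires
    # for a directory to count as a wgit repo.
    check_dir = Path("/tmp/work")
    required = [
        check_dir / ".wgit",             # wgit metadata dir
        check_dir / ".wgit/.git",        # internal git repo used for metadata versioning
        check_dir / ".wgit/.gitignore",  # ignores the sha1_store dir inside .wgit
    ]
    is_wgit_repo = all(p.exists() for p in required)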
-    def commit(self, message: str) -> None:
-        """
-        Commits staged changes to the repo.
-
-        Args:
-            message (str)
-                The commit message
-        """
-        if self._exists(self.wgit_parent):
-            self._pygit.commit(message)
-        else:
-            sys.stderr.write("fatal: no wgit repo exists!\n")
-            sys.exit(1)
+    def add(self, in_file_path: str, per_tensor: bool = False) -> None:
+        """Add a file to the wgit repo.
+
+        Args:
+            in_file_path (str):
+                Path to the file to be added.
+            per_tensor (bool):
+                Add a file in a per-tensor fashion.
+        """
+        self._sanity_check()
+
+        # create the corresponding metadata file
+        file_path = Path(in_file_path)
+        rel_file_path = self._rel_file_path(file_path)
+        metadata_file = self._process_metadata_file(rel_file_path)
+
+        # add the file to the sha1_store
+        # TODO (Min): We don't add parent sha1 tracking to sha1 store due to
+        #             de-duplication & dependency tracking can create cycles.
+        #             We need to figure out a way to handle deletion.
+        sha1_dict = {}
+        if per_tensor:
+
+            def fn(element: Any) -> Any:
+                """Callback on each leaf object for _recursive_apply_to_elements below."""
+                if isinstance(element, Tensor):
+                    # TODO (Min): here we will optionally do SST/DST and add those
+                    #             tensors with sparsity.
+                    sha1 = self._sha1_store.add(element, compress=True)
+                    return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
+                else:
+                    return element
+
+            state_dict = torch.load(file_path)
+            _recursive_apply_to_elements(state_dict, fn)
+            sha1_dict = {"__sha1_full__": self._sha1_store.add(state_dict)}
+        else:
+            sha1_dict = {"__sha1_full__": self._sha1_store.add(file_path)}
+
+        # write metadata to the metadata-file
+        self._write_metadata(metadata_file, file_path, sha1_dict)
+        self._pygit.add()  # add to the .wgit/.git repo
+
+    def commit(self, message: str) -> None:
+        """Commits staged changes to the repo.
+
+        Args:
+            message (str):
+                The commit message
+        """
+        self._sanity_check()
+
+        self._pygit.commit(message)
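Given the add() and _write_metadata() changes above, the per-file metadata JSON ends up with roughly this shape; the values below are placeholders for illustration, not output from the commit:

    # Rough shape of the metadata written by _write_metadata() after repo.add().
    metadata = {
        "SHA1": {"__sha1_full__": "<sha1 of the stored file, or of the transformed state_dict>"},
        "file_path": "checkpoint_0.pt",
        "last_modified_time_stamp": 1656000000.0,  # st_mtime of the added file
    }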
     def status(self) -> Dict:
-        """Show the state of the weigit working tree. State can be
-        1. dirty with changes/modifications not added to weigit repo,
-        2. dirty with a file changes added but not committed
-        3. clean and tracking files after a change has been committed, or clean with with an empty repo.
+        """Show the state of the weigit working tree.
+
+        State can be
+        1. dirty with changes/modifications not added to weigit repo.
+        2. dirty with a file changes added but not committed.
+        3. clean and tracking files after a change has been committed,
+           or clean with with an empty repo.
+
+        Returns:
+            (dict):
+                A dict keyed with files and their status.
         """
-        if self._exists(self.wgit_parent):
-            pygit_status = self._pygit.status()
-            status = self._get_metdata_files()
-            if status:
-                out_status = dict()
-                for metadata_file, is_modified in status.items():
-                    # if metadata_file is among the keys of pygit_status dict, it has not been commited to git yet.
-                    if is_modified:
-                        out_status[str(metadata_file)] = RepoStatus.CHANGES_NOT_ADDED
-                    elif not is_modified and metadata_file in pygit_status.keys():
-                        out_status[str(metadata_file)] = RepoStatus.CHANGES_ADDED_NOT_COMMITED
-                    elif not is_modified and metadata_file not in pygit_status.keys():
-                        out_status[str(metadata_file)] = RepoStatus.CLEAN
-                return out_status
-            else:  # if status dict is empty, nothing has been added so far.
-                return {"": RepoStatus.CLEAN}  # sub case of case-3, clean with an empty repo
-        else:
-            sys.stderr.write("fatal: no wgit repo exists!\n")
-            sys.exit(1)
+        self._sanity_check()
+
+        pygit_status = self._pygit.status()
+        status = self._get_metdata_files()
+        if status:
+            out_status = dict()
+            for metadata_file, is_modified in status.items():
+                # if metadata_file is among the keys of pygit_status dict, it has not been commited to git yet.
+                if is_modified:
+                    out_status[str(metadata_file)] = RepoStatus.CHANGES_NOT_ADDED
+                elif not is_modified and metadata_file in pygit_status.keys():
+                    out_status[str(metadata_file)] = RepoStatus.CHANGES_ADDED_NOT_COMMITED
+                elif not is_modified and metadata_file not in pygit_status.keys():
+                    out_status[str(metadata_file)] = RepoStatus.CLEAN
+            return out_status
+        else:  # if status dict is empty, nothing has been added so far.
+            return {"": RepoStatus.CLEAN}  # sub case of case-3, clean with an empty repo
     def log(self, file: str) -> None:
-        """
-        Returns the WeiGit log of commit history.
+        """Returns the WeiGit log of commit history.

         Args:
-            file (str, optional)
-                Show the log of the commit history of the repo. Optionally, show the log history of a specific file.
+            file (str, optional):
+                Show the log of the commit history of the repo. Optionally, show
+                the log history of a specific file.
         """
-        if self._exists(self.wgit_parent):
-            if file:
-                print(f"wgit log of the file: {file}")
-            else:
-                print("wgit log")
-        else:
-            sys.stderr.write("fatal: no wgit repo exists!\n")
-            sys.exit(1)
+        self._sanity_check()
+
+        if file:
+            print(f"wgit log of the file: {file}")
+        else:
+            print("wgit log")

     def checkout(self, sha1: str) -> None:
-        """
-        Checkout a previously commited version of the checkpoint.
+        """Checkout a previously commited version of the checkpoint.

         Args:
-            sha1 (str) The sha1 hash of the file version to checkout.
+            sha1 (str):
+                The sha1 hash of the file version to checkout.
         """
+        self._sanity_check()
+
         raise NotImplementedError

     def compression(self) -> None:
         """Not Implemented: Compression functionalities"""
+        self._sanity_check()
+
         raise NotImplementedError

     def checkout_by_steps(self) -> None:
         """Not Implemented: Checkout by steps"""
+        self._sanity_check()
+
         raise NotImplementedError
-    @property
-    def path(self) -> Path:
-        """Get the path to the WeiGit repo"""
-        if self._repo_path is None:
-            self._exists(self.wgit_parent)
-        return self._repo_path
-
     def _get_metdata_files(self) -> Dict:
-        """Walk the directories that contain the metadata files and check the status of those files,
-        whether they have been modified or not.
+        """Walk the directories that contain the metadata files and check the
+        status of those files, whether they have been modified or not.
         """
         metadata_d = dict()
-        for file in self.path.iterdir():  # iterate over the .wgit directory
+        for file in self._dot_wgit_dir_path.iterdir():  # iterate over the .wgit directory
             # exlude all the .wgit files and directory
             if file.name not in {"sha1_store", ".git", ".gitignore"}:
                 # perform a directory walk on the metadata_file directories to find the metadata files
                 for path in file.rglob("*"):
                     if path.is_file():
-                        rel_path = str(path.relative_to(self.path))  # metadata path relative to .wgit dir
+                        rel_path = str(path.relative_to(self._dot_wgit_dir_path))  # metadata path relative to .wgit dir
                         metadata_d[rel_path] = self._is_file_modified(path)
         return metadata_d

     def _is_metadata_file(self, file: Path) -> bool:
-        """Checks whether a file is a valid metadata file by matching keys and checking if it has valid
-        json data."""
+        """Checks whether a file is a valid metadata file by matching keys and
+        checking if it has valid json data.
+        """
         try:
             with open(file) as f:
                 metadata = json.load(f)
@@ -209,37 +307,38 @@ class Repo:
         return is_metadata

     def _is_file_modified(self, file: Path) -> bool:
-        """Checks whether a file has been modified since its last recorded modification time recorded in the metadata_file"""
+        """Checks whether a file has been modified since its last recorded modification
+        time recorded in the metadata_file.
+        """
         with open(file) as f:
             data = json.load(f)
-        # get the last modified timestamp recorded by weigit and the current modified timestamp. If not the
-        # same, then file has been modified since last weigit updated metadata
+        # Get the last modified timestamp recorded by weigit and the current modified
+        # timestamp. If not the same, then file has been modified since last weigit
+        # updated metadata.
         last_mod_timestamp = data["last_modified_time_stamp"]
         curr_mod_timestamp = Path(data["file_path"]).stat().st_mtime
         return not curr_mod_timestamp == last_mod_timestamp

-    def _process_metadata_file(self, metadata_fname: Path) -> Tuple[Path, str]:
-        """Create a metadata_file corresponding to the file to be tracked by weigit if the first version of the file
-        is encountered. If a version already exists, open the file and get the sha1_hash of the last version as parent_sha1"""
-        metadata_file = self.path.joinpath(metadata_fname)
+    def _process_metadata_file(self, metadata_fname: Path) -> Path:
+        """Create a metadata_file corresponding to the file to be tracked by weigit if
+        the first version of the file is encountered. If a version already exists, open
+        the file and get the sha1_hash of the last version as parent_sha1.
+        """
+        metadata_file = self._dot_wgit_dir_path.joinpath(metadata_fname)
         metadata_file.parent.mkdir(parents=True, exist_ok=True)  # create parent dirs for metadata file

         if not metadata_file.exists() or not metadata_file.stat().st_size:
             metadata_file.touch()
-            parent_sha1 = "ROOT"
         else:
             with open(metadata_file, "r") as f:
                 ref_data = json.load(f)
-            parent_sha1 = ref_data["SHA1"]["__sha1_full__"]
-        return metadata_file, parent_sha1
+        return metadata_file

-    def _write_metadata(self, metadata_file: Path, file_path: Path, sha1_hash: str) -> None:
+    def _write_metadata(self, metadata_file: Path, file_path: Path, sha1_dict: Dict) -> None:
         """Write metadata to the metadata file"""
         change_time = Path(file_path).stat().st_mtime
         metadata = {
-            "SHA1": {
-                "__sha1_full__": sha1_hash,
-            },
+            "SHA1": sha1_dict,
             "file_path": str(file_path),
             "last_modified_time_stamp": change_time,
         }
@@ -247,7 +346,9 @@ class Repo:
             json.dump(metadata, f, ensure_ascii=False, indent=4)

     def _rel_file_path(self, filepath: Path) -> Path:
-        """Find the relative part to the filepath from the current working directory and return the relative path."""
+        """Find the relative part to the filepath from the current working
+        directory and return the relative path.
+        """
         # get the absolute path
         filepath = filepath.resolve()
         # using zipped loop we get the path common to the filepath and cwd
@@ -256,37 +357,6 @@
         # return the relative part (path not common to cwd)
         return Path(*filepath.parts[i:])

-    def _exists(self, check_dir: Path) -> bool:
-        """Returns True if a valid wgit exists within the cwd and iteratively checks to the root directory and
-        sets the self._repo_path attribute to the wgit path.
-
-        Args:
-            check_dir (Path)
-                path to the directory from where search is started.
-        """
-        if self._weigit_repo_exists(check_dir):
-            self._repo_path = check_dir.joinpath(".wgit")
-        else:
-            root = Path(check_dir.parts[0])
-            while check_dir != root:
-                check_dir = check_dir.parent
-                if self._weigit_repo_exists(check_dir):
-                    self._repo_path = check_dir.joinpath(".wgit")
-                    break
-        return True if self._repo_path is not None else False
-
-    def _weigit_repo_exists(self, check_dir: pathlib.Path) -> bool:
-        """Returns True if a valid WeiGit repo exists in the path: check_dir"""
-        wgit_exists, git_exists, gitignore_exists = self._weight_repo_file_check(check_dir)
-        return wgit_exists and git_exists and gitignore_exists
-
-    def _weight_repo_file_check(self, check_dir: Path) -> tuple:
-        """Returns a tuple of boolean corresponding to the existence of each .wgit internally required files."""
-        wgit_exists = check_dir.joinpath(".wgit").exists()
-        git_exists = check_dir.joinpath(".wgit/.git").exists()
-        gitignore_exists = check_dir.joinpath(".wgit/.gitignore").exists()
-        return wgit_exists, git_exists, gitignore_exists
-

 class RepoStatus(Enum):
     """Collections of Repo Statuses"""
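RepoStatus is the per-file value returned by status(); a brief consuming sketch, with the import path and file name assumed for illustration:

    from fairscale.experimental.wgit.repo import Repo, RepoStatus  # module path assumed

    repo = Repo()  # wraps an existing .wgit repo in the CWD
    status = repo.status()  # e.g. {"checkpoint_0.pt": RepoStatus.CHANGES_ADDED_NOT_COMMITED}
    for name, state in status.items():
        if state == RepoStatus.CHANGES_NOT_ADDED:
            print(f"{name}: modified since last add; re-run repo.add()")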
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.

-from collections import OrderedDict
 import hashlib
 import json
 from pathlib import Path
@@ -20,10 +19,6 @@ from torch import Tensor

 from .utils import ExitCode

-# This is a fixed dir name we use for sha1_store. It should not be changed
-# for backward compatibility reasons.
-SHA1_STORE_DIR_NAME = "sha1_store"
-
 # Const string keys for json file. Do not change for backward compatibilities.
 RF_KEY = "ref_count"
 COMP_KEY = "compressed"
@@ -92,15 +87,10 @@ class SHA1_Store:
     compression/decompression on top of it to use all the cores.

     Args:
-        parent_path (Path):
-            The parent path in which a SHA1_Store will be created.
+        path (Path):
+            The path in which a SHA1_Store will be created.
         init (bool, optional):
-            - If ``True``, initializes a new SHA1_Store in the parent_path. Initialization
-              creates a `sha1_store` directory in ./<parent_path>/,
-              and a `ref_count.json` within ./<parent_path>/.
-            - If ``False``, a new `sha1_store` dir is not initialized and the existing
-              `sha1_store` is used to init this class, populating `_json_dict`, and other
-              attributes.
+            - If ``True``, a new SHA1_Store in the path if not already exists.
             - Default: False
         sha1_buf_size (int):
             Buffer size used for checksumming. Default: 100MB.
@@ -115,22 +105,22 @@
     def __init__(
         self,
-        parent_path: Path,
+        path: Path,
         init: bool = False,
         sha1_buf_size: int = 100 * 1024 * 1024,
         tmp_dir: str = "",
         pgzip_threads: Optional[int] = None,
         pgzip_block_size: int = 10 * 1024 * 1024,
     ) -> None:
-        """Create or wrap (if already exists) a sha1_store."""
-        self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
-        self._ref_file_path = self._path.joinpath("ref_count.json")
+        """Create or wrap (if already exists) a store."""
+        self._path = path
+        self._metadata_file_path = self._path.joinpath("metadata.json")
         self._sha1_buf_size = sha1_buf_size
         self._pgzip_threads = pgzip_threads
         self._pgzip_block_size = pgzip_block_size
         self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}

-        # Initialize the sha1_store if not exist and init==True.
+        # Initialize the store if not exist and if init is True.
         if init and not self._path.exists():
             try:
                 Path.mkdir(self._path, parents=False, exist_ok=False)
@@ -141,7 +131,13 @@ class SHA1_Store:
             self._store_json_dict()

         # This is an internal error since caller of this our own wgit code.
-        assert self._path.exists(), "SHA1 store does not exist and init==False"
+        assert (
+            self._path.exists() and self._metadata_file_path.exists()
+        ), f"SHA1 store {self._path} does not exist and init is False"
+
+        # Make sure there is a valid metadata file.
+        self._load_json_dict()
+        assert "created_on" in self._json_dict, f"Invalid SHA1 Store in {self._path}"

         # Init temp dir.
         if tmp_dir:
@@ -156,30 +152,31 @@ class SHA1_Store:

     def _load_json_dict(self) -> None:
         """Loading json dict from disk."""
-        with open(self._ref_file_path, "r") as f:
+        with open(self._metadata_file_path, "r") as f:
             self._json_dict = json.load(f)

     def _store_json_dict(self) -> None:
         """Storing json dict to disk."""
-        with open(self._ref_file_path, "w", encoding="utf-8") as f:
+        with open(self._metadata_file_path, "w", encoding="utf-8") as f:
             json.dump(self._json_dict, f, ensure_ascii=False, indent=4)

-    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
-        """
-        Adds a file/object to the internal sha1_store and the sha1 references
-        accordingly.
+    def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = False) -> str:
+        """Adds a file/object to this store and the sha1 references accordingly.

         First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
-        in <file_or_obj> is moved within the sha1_store and the reference file is updated.
+        in <file_or_obj> is moved within the store and the reference file is updated.
         If the input is an object, it will be store in the self._tmp_dir and then moved.

         If compress is True, the stored file is also compressed, which is useful for tensors
         with a lot of zeros.

         Args:
-            file_or_obj (str or tensor or OrderedDict):
-                Path to the file to be added to the sha1_store or an in-memory object
-                that can be handled by torch.save.
+            file_or_obj (str or tensor or Dict):
+                Path to the file to be added to the store or an in-memory object
+                that can be handled by torch.save. Note, OrderedDict is used when
+                you call `state_dict()` on a nn.Module, and it is an instance
+                of a Dict too. A model's state_dict can be a simple dict because
+                it may contain both model state_dict and other non-tensor info.
         """
         # Use `isinstance` not type() == Path since pathlib returns OS specific
         # Path types, which inherit from the Path class.
@@ -188,10 +185,10 @@ class SHA1_Store:
             torch.load(cast(Union[Path, str], file_or_obj))
             file_path = Path(file_or_obj)
             remove_tmp = False
-        elif isinstance(file_or_obj, (Tensor, OrderedDict)):
+        elif isinstance(file_or_obj, (Tensor, Dict)):
             # Serialize the object into a tmp file.
             file_path = self._get_tmp_file_path()
-            torch.save(cast(Union[Tensor, OrderedDict], file_or_obj), file_path)
+            torch.save(cast(Union[Tensor, Dict], file_or_obj), file_path)
             remove_tmp = True
         else:
             assert False, f"incorrect input {type(file_or_obj)}"
@@ -215,7 +212,7 @@ class SHA1_Store:
             sys.stderr.write(f"An exception occured: {repr(error)}\n")
             sys.exit(ExitCode.FILE_EXISTS_ERROR)

-        # Transfer the file to the internal sha1_store
+        # Transfer the file to the store.
         repo_fpath = repo_fdir.joinpath(sha1_hash)
         try:
             if compress:
@@ -234,7 +231,7 @@ class SHA1_Store:

         return sha1_hash

-    def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
+    def get(self, sha1: str) -> Union[Tensor, Dict]:
         """Get data from a SHA1

         Args:
@@ -242,7 +239,7 @@ class SHA1_Store:
                 SHA1 of the object to get.

         Returns:
-            (Tensor or OrderedDict):
+            (Tensor or Dict):
                 In-memory object.

         Throws:
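A minimal round trip against the refactored store API; the import path and directory name are assumed for illustration:

    from pathlib import Path

    import torch

    from fairscale.experimental.wgit.sha1_store import SHA1_Store  # module path assumed

    # SHA1_Store now takes the store directory itself (not its parent), per the refactor above.
    store = SHA1_Store(Path("my_sha1_store"), init=True)
    sha1 = store.add(torch.ones(4), compress=True)  # hash and store the tensor, gzip-compressed
    restored = store.get(sha1)                      # load the tensor back by its sha1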
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #MODIFIED FOR FlattenParamsWrapper

 from typing import Any

+# Deprecate allclose when we move to newer versions.
 def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
+def assert_close(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
 #END
@@ -33,7 +33,10 @@ def create_test_dir():
     # create random checkpoints
     size_list = [30e5, 35e5, 40e5, 40e5]
     for i, size in enumerate(size_list):
-        torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
+        sd = {}
+        sd["model"] = nn.Linear(1, int(size)).state_dict()
+        sd["step"] = 100
+        torch.save(sd, f"checkpoint_{i}.pt")
     return test_dir
@@ -55,10 +58,14 @@ def test_api_init(capsys, repo):
     assert Path(".wgit/.gitignore").exists()

-def test_api_add(capsys, repo):
+@pytest.mark.parametrize("per_tensor", [True, False])
+def test_api_add(capsys, repo, per_tensor):
     fnum = random.randint(0, 2)
     chkpt0 = f"checkpoint_{fnum}.pt"
-    repo.add(chkpt0)
+    repo.add(chkpt0, per_tensor)
+    if per_tensor:
+        # TODO (Min): test per_tensor add more.
+        return
     sha1_hash = repo._sha1_store._get_sha1_hash(chkpt0)
     metadata_path = repo._rel_file_path(Path(chkpt0))
@@ -56,7 +56,7 @@ def test_cli_add(create_test_dir, capsys):
     cli.main(["add", chkpt0])
     sha1_store = SHA1_Store(
-        Path.cwd().joinpath(".wgit"),
+        Path.cwd().joinpath(".wgit", "sha1_store"),
         init=False,
     )
     sha1_hash = sha1_store._get_sha1_hash(chkpt0)
@@ -17,7 +17,7 @@ from fairscale.internal import torch_version

 # Get the absolute path of the parent at the beginning before any os.chdir(),
 # so that we can proper clean it up at any CWD.
-PARENT_DIR = Path("sha1_store_testing").resolve()
+TESTING_STORE_DIR = Path("sha1_store_testing").resolve()


 @pytest.fixture(scope="function")
@@ -32,8 +32,9 @@ def sha1_store(request):
     """
     # Attach a teardown function.
     def teardown():
-        os.chdir(PARENT_DIR.joinpath("..").resolve())
-        shutil.rmtree(PARENT_DIR, ignore_errors=True)
+        os.chdir(TESTING_STORE_DIR.joinpath("..").resolve())
+        if TESTING_STORE_DIR.exists():
+            shutil.rmtree(TESTING_STORE_DIR)

     request.addfinalizer(teardown)
@@ -41,15 +42,14 @@ def sha1_store(request):
     teardown()

     # Get an empty sha1 store.
-    PARENT_DIR.mkdir()
-    sha1_store = SHA1_Store(PARENT_DIR, init=True)
+    sha1_store = SHA1_Store(TESTING_STORE_DIR, init=True)

     return sha1_store


 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_file(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Create random checkpoints
     size_list = [25e5, 27e5, 30e5, 35e5, 40e5]
@@ -89,7 +89,7 @@ def test_sha1_add_file(sha1_store, compress):

 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_state_dict(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)
     # add once
     for i in range(3):
         sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
@@ -107,7 +107,7 @@ def test_sha1_add_state_dict(sha1_store, compress):

 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_tensor(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)
     sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
     sha1_store._load_json_dict()
     json_dict = sha1_store._json_dict
@@ -120,7 +120,7 @@ def test_sha1_add_tensor(sha1_store, compress):

 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_get(sha1_store, compress):
     """Testing the get() API: normal and exception cases."""
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Add a file, a state dict and a tensor.
     file = "test_get.pt"
@@ -149,7 +149,7 @@ def test_sha1_get(sha1_store, compress):

 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_delete(sha1_store, compress):
     """Testing the delete() API: with ref counting behavior."""
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Add once and delete, second delete should throw an exception.
     tensor = torch.ones(30, 50)