Unverified commit 68af57d8 authored by Min Xu, committed by GitHub

[feat] Sha1 Store enhancements (#1026)



* refactor SHA1_Store

- renamed the class
- added created_on field and refactored how init is done
- wrap long lines

* wrapped longer lines

* rename json file ref_count.json

* make sha1_buf_size an argument

* update gitignore

* added tmp_dir

* added new sha1_store add and tests

* update chdir

* add debug to test

* fixing unit test for 1.8
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 03f3e833
.gitignore
@@ -12,10 +12,6 @@
 *.egg-info/
 .testmondata
-
-# experimental weigit
-fairscale/experimental/wgit/dev
-.wgit
 # Build and release
 build/
 dist/
...
fairscale/experimental/wgit/repo.py
@@ -11,7 +11,7 @@ import sys
 from typing import Dict, Tuple, Union

 from .pygit import PyGit
-from .sha1_store import SHA1_store
+from .sha1_store import SHA1_Store


 class Repo:
@@ -42,28 +42,27 @@ class Repo:
         exists = self._exists(self.wgit_parent)
         if not exists and init:
             # No weigit repo exists and one is being initialized with init=True
-            # Make the .wgit directory, create sha1_refs
+            # Make the .wgit directory, create the sha1_store
             weigit_dir = self.wgit_parent.joinpath(".wgit")
             weigit_dir.mkdir(parents=False, exist_ok=True)

             # Initializing sha1_store only after wgit has been initialized!
-            self._sha1_store = SHA1_store(weigit_dir, init=True)
+            self._sha1_store = SHA1_Store(weigit_dir, init=True)

             # Make the .wgit a git repo
             gitignore_files = [
                 self._sha1_store_path.name,
-                self._sha1_store.ref_file_path.name,
             ]
             self._pygit = PyGit(weigit_dir, gitignore=gitignore_files)
         elif exists and init:
             # If a weigit repo already exists and init is being called, wrap the existing .wgit/.git repo with PyGit
-            self._sha1_store = SHA1_store(self.path)
+            self._sha1_store = SHA1_Store(self.path)
             self._pygit = PyGit(self.path)
         elif exists and not init:
             # weigit exists and non-init commands are triggered
-            self._sha1_store = SHA1_store(self.path)
+            self._sha1_store = SHA1_Store(self.path)
             self._pygit = PyGit(self.path)
         else:
@@ -86,7 +85,10 @@ class Repo:
         metadata_file, parent_sha1 = self._process_metadata_file(rel_file_path)

         # add the file to the sha1_store
-        sha1_hash = self._sha1_store.add(file_path, parent_sha1)
+        # TODO (Min): We don't add parent sha1 tracking to the sha1 store because
+        #             de-duplication plus dependency tracking can create cycles.
+        #             We need to figure out a way to handle deletion.
+        sha1_hash = self._sha1_store.add(file_path)

         # write metadata to the metadata file
         self._write_metadata(metadata_file, file_path, sha1_hash)
@@ -183,7 +185,7 @@ class Repo:
         metadata_d = dict()
         for file in self.path.iterdir():  # iterate over the .wgit directory
             # exclude all the .wgit files and directories
-            if file.name not in {"sha1_store", "sha1_refs.json", ".git", ".gitignore"}:
+            if file.name not in {"sha1_store", ".git", ".gitignore"}:
                 # perform a directory walk on the metadata_file directories to find the metadata files
                 for path in file.rglob("*"):
                     if path.is_file():
@@ -275,16 +277,15 @@ class Repo:
     def _weigit_repo_exists(self, check_dir: pathlib.Path) -> bool:
         """Returns True if a valid WeiGit repo exists in the path: check_dir"""
-        wgit_exists, sha1_refs, git_exists, gitignore_exists = self._weight_repo_file_check(check_dir)
-        return wgit_exists and sha1_refs and git_exists and gitignore_exists
+        wgit_exists, git_exists, gitignore_exists = self._weight_repo_file_check(check_dir)
+        return wgit_exists and git_exists and gitignore_exists

     def _weight_repo_file_check(self, check_dir: Path) -> tuple:
         """Returns a tuple of booleans corresponding to the existence of each internally required .wgit file."""
         wgit_exists = check_dir.joinpath(".wgit").exists()
-        sha1_refs = check_dir.joinpath(".wgit/sha1_refs.json").exists()
         git_exists = check_dir.joinpath(".wgit/.git").exists()
         gitignore_exists = check_dir.joinpath(".wgit/.gitignore").exists()
-        return wgit_exists, sha1_refs, git_exists, gitignore_exists
+        return wgit_exists, git_exists, gitignore_exists


 class RepoStatus(Enum):
...
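The TODO note in Repo.add above deserves a concrete illustration. Under content addressing, two different points in a version history can serialize to byte-identical files and therefore collapse into one SHA1 entry, so per-SHA1 parent pointers can form cycles. Below is a minimal sketch of that failure mode; the `parents` dict is a hypothetical stand-in for the old sha1_refs.json tracking, not actual wgit code.

import hashlib

def sha1_of(data: bytes) -> str:
    return hashlib.sha1(data).hexdigest()

# Version history v1 -> v2 -> v3, where v3's bytes equal v1's
# (e.g. an edit to a checkpoint that was later reverted).
v1, v2, v3 = b"weights-A", b"weights-B", b"weights-A"

parents = {}  # hypothetical: child sha1 -> parent sha1
parents[sha1_of(v2)] = sha1_of(v1)
parents[sha1_of(v3)] = sha1_of(v2)  # sha1_of(v3) == sha1_of(v1), so A's entry now points at B

# Walking the parent chain from v2 never reaches a root: B -> A -> B -> ...
node, seen = sha1_of(v2), set()
while node in parents:
    if node in seen:
        print("cycle detected at", node[:8])
        break
    seen.add(node)
    node = parents[node]

This is why the new store keeps only a flat reference count per SHA1 and leaves version ancestry to the metadata files.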
fairscale/experimental/wgit/sha1_store.py
@@ -4,160 +4,257 @@
 # LICENSE file in the root directory of this source tree.

+from collections import OrderedDict
 import hashlib
 import json
 from pathlib import Path
 import shutil
 import sys
-from typing import Union
+import tempfile
+import time
+from typing import Any, Dict, Union, cast
+
+import torch
+from torch import Tensor

 from .utils import ExitCode

+# This is a fixed dir name we use for the sha1_store. It should not be changed
+# for backward compatibility reasons.
+SHA1_STORE_DIR_NAME = "sha1_store"
+

-class SHA1_store:
+class SHA1_Store:
     """
-    Represents the sha1_store within the WeiGit repo, handling files added to the store and managing references.
-
-    Args:
-        weigit_path (pathlib.Path)
-            The path to the weigit repo where a sha1_store will be created, or, if one already exists, wrapped.
-        init (bool, optional)
-            - If ``True``, initializes a new sha1_store in the weigit_path. Initialization creates a `sha1_store`
-              directory within the WeiGit repo in ./<weigit_path>/ and a `sha1_refs.json` within ./<weigit_path>/.
-            - If ``False``, a new sha1_store is not initialized and the existing sha1_store is simply wrapped,
-              populating the `name`, `path` and `ref_file_path` attributes.
-            - Default: False
+    This class represents a SHA1-checksum-based storage dir for state_dicts
+    and tensors.
+
+    This means the same content will not be stored multiple times, resulting
+    in space savings (a.k.a. de-duplication).
+
+    To make things easier for the callers, this class accepts input data
+    as files, state_dicts or tensors. This class always returns in-memory
+    data, not on-disk files. This class doesn't really care or know the actual
+    data types. It uses torch.save() and torch.load() to do serialization.
+
+    A key issue is dealing with content deletion. We use a reference counting
+    algorithm, which means the caller must have symmetrical add/remove calls
+    for each object.
+
+    We used to support a children-parent dependency graph with ref counting, but
+    it is flawed since a grand-child can have the same SHA1 as the grand-parent,
+    resulting in a cycle. This means the caller must compute which parent is safe
+    to delete in a version tracking graph. The lesson here is that content
+    addressability and dependency graphs do not mix well.
+
+    Args:
+        parent_path (Path):
+            The parent path in which a SHA1_Store will be created.
+        init (bool, optional):
+            - If ``True``, initializes a new SHA1_Store in the parent_path. Initialization
+              creates a `sha1_store` directory in ./<parent_path>/ and a `ref_count.json`
+              within it.
+            - If ``False``, a new `sha1_store` dir is not initialized and the existing
+              `sha1_store` is used to init this class, populating `_json_dict` and other
+              attributes.
+            - Default: False
+        sha1_buf_size (int):
+            Buffer size used for checksumming. Default: 100MB.
+        tmp_dir (str):
+            Dir for temporary files if the input is an in-memory object.
     """

-    def __init__(self, weigit_path: Path, init: bool = False) -> None:
-        """Create or wrap (if already exists) a sha1_store within the WeiGit repo."""
-        # should use the sha1_refs.json to track the parent references.
-        self.name = "sha1_store"
-        self.path = weigit_path.joinpath(self.name)
-        self.ref_file_path = weigit_path.joinpath("sha1_refs.json")
-        self._weigit_path = weigit_path
-        # initialize the sha1_store
-        if init:
-            try:
-                if not self.path.exists():
-                    Path.mkdir(self.path, parents=False, exist_ok=False)
-                    self.ref_file_path.touch(exist_ok=False)
-            except FileExistsError as error:
-                sys.stderr.write(f"An exception occurred while creating Sha1_store: {repr(error)}\n")
-                sys.exit(ExitCode.FILE_EXISTS_ERROR)
+    def __init__(
+        self, parent_path: Path, init: bool = False, sha1_buf_size: int = 100 * 1024 * 1024, tmp_dir: str = ""
+    ) -> None:
+        """Create or wrap (if already exists) a sha1_store."""
+        self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
+        self._ref_file_path = self._path.joinpath("ref_count.json")
+        self._sha1_buf_size = sha1_buf_size
+        self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}
+
+        # Initialize the sha1_store if it doesn't exist and init==True.
+        if init and not self._path.exists():
+            try:
+                Path.mkdir(self._path, parents=False, exist_ok=False)
+            except FileExistsError as error:
+                sys.stderr.write(f"An exception occurred while creating Sha1_store: {repr(error)}\n")
+                sys.exit(ExitCode.FILE_EXISTS_ERROR)
+            # Create a new json file for this new store.
+            self._store_json_dict()
+
+        # This is an internal error since the caller of this is our own wgit code.
+        assert self._path.exists(), "SHA1 store does not exist and init==False"
+
+        # Init the temp dir.
+        if tmp_dir:
+            # Caller-supplied tmp dir.
+            assert Path(tmp_dir).is_dir(), "incorrect input"
+            self._tmp_dir = Path(tmp_dir)
+        else:
+            # Default tmp dir, need to clean it.
+            self._tmp_dir = self._path.joinpath("tmp")
+            shutil.rmtree(self._tmp_dir, ignore_errors=True)
+            self._tmp_dir.mkdir()
+
+    def _load_json_dict(self) -> None:
+        """Load the json dict from disk."""
+        with open(self._ref_file_path, "r") as f:
+            self._json_dict = json.load(f)
+
+    def _store_json_dict(self) -> None:
+        """Store the json dict to disk."""
+        with open(self._ref_file_path, "w", encoding="utf-8") as f:
+            json.dump(self._json_dict, f, ensure_ascii=False, indent=4)

-    def add(self, file_path: Path, parent_sha1: str) -> str:
+    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
         """
-        Adds a file/checkpoint to the internal sha1_store and updates the sha1 references accordingly.
-        First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file in <in_file_path> is moved
-        within the sha1_store and the sha1 reference file is updated with information about its parent
-        node (if one exists) and whether the new version is a leaf node or not.
+        Adds a file/object to the internal sha1_store and updates the sha1
+        references accordingly.
+
+        First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
+        in <file_or_obj> is moved within the sha1_store and the reference file is updated.
+        If the input is an object, it will be stored in self._tmp_dir and then moved.

         Args:
-            in_file_path (str): path to the file to be added to the sha1_store.
+            file_or_obj (str or tensor or OrderedDict):
+                Path to the file to be added to the sha1_store, or an in-memory object
+                that can be handled by torch.save.
         """
+        # Use `isinstance`, not type() == Path, since pathlib returns OS-specific
+        # Path types, which inherit from the Path class.
+        if isinstance(file_or_obj, (Path, str)):
+            # Make sure it is a valid file.
+            torch.load(cast(Union[Path, str], file_or_obj))
+            file_path = Path(file_or_obj)
+            remove_tmp = False
+        elif isinstance(file_or_obj, (Tensor, OrderedDict)):
+            # Serialize the object into a tmp file.
+            file_path = self._get_tmp_file_path()
+            torch.save(cast(Union[Tensor, OrderedDict], file_or_obj), file_path)
+            remove_tmp = True
+        else:
+            assert False, f"incorrect input {type(file_or_obj)}"
+
+        # Get the SHA1 from the file.
+        assert isinstance(file_path, Path), type(file_path)
         sha1_hash = self._get_sha1_hash(file_path)
-        # use the sha1_hash to create a directory with the first-2-chars naming convention
-        try:
-            repo_fdir = self.path.joinpath(sha1_hash[:2])
-            repo_fdir.mkdir(exist_ok=True)
-        except FileExistsError as error:
-            sys.stderr.write(f"An exception occurred: {repr(error)}\n")
-            sys.exit(ExitCode.FILE_EXISTS_ERROR)
-        try:
-            # First transfer the file to the internal sha1_store
-            repo_fpath = Path.cwd().joinpath(repo_fdir, sha1_hash[2:])
-            shutil.copy2(file_path, repo_fpath)
-            # Create the dependency graph and track the reference
-            self._add_ref(sha1_hash, parent_sha1)
-        except BaseException as error:
-            # In case of failure: clean up the sub-directories created to store sha1-named checkpoints
-            sys.stderr.write(f"An exception occurred: {repr(error)}\n")
-            shutil.rmtree(repo_fdir)
+
+        # Add a reference.
+        ref_count = self._add_ref(sha1_hash, True)
+
+        if ref_count == 1:
+            # First time adding.
+
+            # Create the file dir, if needed.
+            repo_fdir = self._sha1_to_dir(sha1_hash)
+            if not repo_fdir.exists():
+                try:
+                    repo_fdir.mkdir(exist_ok=True, parents=True)
+                except FileExistsError as error:
+                    sys.stderr.write(f"An exception occurred: {repr(error)}\n")
+                    sys.exit(ExitCode.FILE_EXISTS_ERROR)
+
+            # Transfer the file to the internal sha1_store.
+            repo_fpath = repo_fdir.joinpath(sha1_hash)
+            try:
+                shutil.copy2(file_path, repo_fpath)
+            except BaseException as error:
+                # Something went wrong, perhaps out of space, or a race condition due to lack of locking.
+                # TODO (Min): properly handle the error and recover when we learn more here.
+                sys.stderr.write(f"An exception occurred: {repr(error)}\n")
+                ref_count = self._add_ref(sha1_hash, False)
+
+        # Clean up if needed.
+        if remove_tmp:
+            file_path.unlink()
+
         return sha1_hash

+    def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
+        """Get data from a SHA1.
+
+        Args:
+            sha1 (str):
+                SHA1 of the object to get.
+
+        Returns:
+            (Tensor or OrderedDict):
+                In-memory object.
+        """
+        raise NotImplementedError()
+
+    def delete(self, sha1: str) -> None:
+        """Delete a SHA1.
+
+        Args:
+            sha1 (str):
+                SHA1 of the object to delete.
+        """
+        raise NotImplementedError()
+
     def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
-        """return the sha1 hash of a file
+        """Return the sha1 hash of a file.

         Args:
-            file_path (str, Path): Path to the file whose sha1 hash is to be calculated and returned.
+            file_path (str, Path):
+                Path to the file whose sha1 hash is to be calculated and returned.
+
+        Returns:
+            (str):
+                The SHA1 computed.
         """
-        SHA1_BUF_SIZE = 104857600  # Read the file in 100MB chunks
         sha1 = hashlib.sha1()
         with open(file_path, "rb") as f:
             while True:
-                data = f.read(SHA1_BUF_SIZE)
+                data = f.read(self._sha1_buf_size)
                 if not data:
                     break
                 sha1.update(data)
         return sha1.hexdigest()

-    def _add_ref(self, current_sha1_hash: str, parent_hash: str) -> None:
+    def _get_tmp_file_path(self) -> Path:
+        """Helper to get a tmp file name under self._tmp_dir."""
+        return Path(tempfile.mkstemp(dir=self._tmp_dir)[1])
+
+    def _sha1_to_dir(self, sha1: str) -> Path:
+        """Helper to get the internal dir for a file based on its SHA1.
+
+        Uses the first 2 hex digits of the sha1, which gives 16 * 16 = 256 subdirs under
+        the top level, then another 2 hex digits for a sub-sub-dir. If each dir holds
+        1000 files, this can hold about 65 million files.
+
+        NOTE: this can NOT be changed, for backward compatibility reasons, once in production.
+        """
+        assert len(sha1) > 4, "sha1 too short"
+        part1, part2 = sha1[:2], sha1[2:4]
+        return self._path.joinpath(part1, part2)
+
+    def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:
         """
-        Populates the sha1_refs.json file when a file is added and keeps track of references to earlier file additions.
-        If the sha1_refs.json file is empty, a new tracking entry for the added file is logged in the sha1_refs file.
-        If the file already has an entry, it first checks whether the incoming file is a new version of any of the
-        existing entries. If it is, the tracking info is logged as a new version of that existing entry.
-        Otherwise a new entry is created to track the added file.
+        Update the reference count.
+
+        If the reference counting file does not have this sha1, a new tracking
+        entry is added.

         Args:
-            file_path (pathlib.Path)
-                Path to the incoming added file.
-            current_sha1_hash (str)
+            current_sha1_hash (str):
                 The sha1 hash of the incoming added file.
+            inc (bool):
+                Increment or decrement.
+
+        Returns:
+            (int):
+                Resulting ref count.
         """
-        # Check the current state of the reference file and whether the added file already has an entry.
-        sha1_refs_empty = self._sha1_refs_file_state()
-        # if the file is empty: add the first entry
-        if sha1_refs_empty:
-            with open(self.ref_file_path) as f:
-                ref_data = {current_sha1_hash: {"parent": "ROOT", "ref_count": 1, "is_leaf": True}}
-            self._write_to_json(self.ref_file_path, ref_data)
-        else:
-            # Open the sha1 reference file and check if there is a parent_hash not equal to ROOT.
-            # If yes, find the parent and add the child. Else, just add a new entry.
-            with open(self.ref_file_path, "r") as f:
-                ref_data = json.load(f)
-            if parent_hash != "ROOT":
-                # get the last head and replace its child from HEAD -> this sha1
-                ref_data[parent_hash]["is_leaf"] = False
-                ref_data[current_sha1_hash] = {"parent": parent_hash, "ref_count": 1, "is_leaf": True}
-            else:
-                ref_data[current_sha1_hash] = {"parent": "ROOT", "ref_count": 1, "is_leaf": True}
-            self._write_to_json(self.ref_file_path, ref_data)
-
-    def _sha1_refs_file_state(self) -> bool:
-        """
-        Checks the state of the sha1 reference file: whether it is empty or not.
-        If not empty, it checks whether the input file in <file_path> has an older entry (version)
-        in the reference file.
-
-        Args:
-            file_path (pathlib.Path)
-                Input file whose entry will be checked for existence in the reference file.
-        """
-        try:
-            with open(self.ref_file_path, "r") as f:
-                ref_data = json.load(f)
-            sha1_refs_empty: bool = False
-        except json.JSONDecodeError as error:
-            if not self.ref_file_path.stat().st_size:
-                sha1_refs_empty = True
-        return sha1_refs_empty
-
-    def _write_to_json(self, file: Path, data: dict) -> None:
-        """
-        Populates a json file with data.
-
-        Args:
-            file (pathlib.Path)
-                Path to the file to be written to.
-            data (dict)
-                Data to be written to the file.
-        """
-        with open(file, "w", encoding="utf-8") as f:
-            json.dump(data, f, ensure_ascii=False, indent=4)
+        self._load_json_dict()
+
+        # Init the entry if needed.
+        if current_sha1_hash not in self._json_dict:
+            self._json_dict[current_sha1_hash] = 0
+
+        # Update the ref count.
+        self._json_dict[current_sha1_hash] += 1 if inc else -1
+        assert self._json_dict[current_sha1_hash] >= 0, "negative ref count"
+
+        self._store_json_dict()
+
+        return self._json_dict[current_sha1_hash]
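Taken together, SHA1_Store now exposes a small content-addressed API: add() for files, tensors and state_dicts, with get()/delete() stubbed out for a later change. A minimal usage sketch follows; the scratch directory is illustrative only, and the de-duplication assert relies on torch re-serializing the same state_dict to identical bytes, which the tests below demonstrate:

from pathlib import Path

import torch
from torch import nn

from fairscale.experimental.wgit.sha1_store import SHA1_Store

scratch = Path("/tmp/wgit_scratch")  # illustrative location
scratch.mkdir(exist_ok=True)
store = SHA1_Store(scratch, init=True)  # creates sha1_store/ with ref_count.json inside it

sd = nn.Linear(4, 4).state_dict()
sha1_a = store.add(sd)   # serialized to a tmp file, hashed, copied into the store
sha1_b = store.add(sd)   # same bytes => same sha1; only the ref count goes up
assert sha1_a == sha1_b  # de-duplication: one blob on disk, ref count == 2

# On disk the blob lands under the 2+2 hex fan-out, e.g. sha1_store/da/3e/<full sha1>.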
(torch serialization type stubs)
 import os
 import pickle
+from pathlib import Path
 from typing import Any, BinaryIO, Callable, IO, Union

 DEFAULT_PROTOCOL: int = 2
@@ -7,4 +8,4 @@ DEFAULT_PROTOCOL: int = 2
 def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
          pickle_module: Any=pickle, pickle_protocol: int=DEFAULT_PROTOCOL, _use_new_zipfile_serialization: bool=True) -> None: ...
-def load(f: Union[str, BinaryIO], map_location) -> Any: ...
+def load(f: Union[str, BinaryIO, Path], map_location=None) -> Any: ...
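The stub update mirrors the call site introduced above: SHA1_Store.add() validates file inputs via torch.load() with a pathlib.Path argument and no explicit map_location, which the old stub would not type-check. A tiny illustration (the checkpoint file name is hypothetical):

from pathlib import Path

import torch

# Path argument, no map_location: the same shape as the validity check inside
# SHA1_Store.add(); the real torch.load defaults map_location to None.
obj = torch.load(Path("checkpoint_0.pt"))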
(wgit repo API tests)
@@ -10,9 +10,10 @@ import random
 import shutil

 import pytest
+import torch
+from torch import nn

-from fairscale.experimental.wgit import repo as api
-from fairscale.experimental.wgit.repo import RepoStatus
+from fairscale.experimental.wgit.repo import Repo, RepoStatus


 @pytest.fixture
@@ -32,14 +33,13 @@ def create_test_dir():
     # create random checkpoints
     size_list = [30e5, 35e5, 40e5, 40e5]
     for i, size in enumerate(size_list):
-        with open(f"checkpoint_{i}.pt", "wb") as f:
-            f.write(os.urandom(int(size)))
+        torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")

     return test_dir


 @pytest.fixture
 def repo():
-    repo = api.Repo(Path.cwd(), init=True)
+    repo = Repo(Path.cwd(), init=True)
     return repo
@@ -48,8 +48,8 @@ def test_setup(create_test_dir):

 def test_api_init(capsys, repo):
-    repo = api.Repo(Path.cwd(), init=True)
-    assert Path(".wgit/sha1_refs.json").is_file()
+    repo = Repo(Path.cwd(), init=True)
+    assert Path(".wgit/sha1_store").is_dir()
     assert Path(".wgit/.gitignore").is_file()
     assert Path(".wgit/.git").exists()
     assert Path(".wgit/.gitignore").exists()
@@ -80,7 +80,7 @@ def test_api_commit(capsys, repo):
 def test_api_status(capsys, repo):
     # delete the repo and initialize a new one:
     shutil.rmtree(".wgit")
-    repo = api.Repo(Path.cwd(), init=True)
+    repo = Repo(Path.cwd(), init=True)

     # check status before any file is added
     out = repo.status()
@@ -99,8 +99,7 @@ def test_api_status(capsys, repo):
     assert out == {key_list[0]: RepoStatus.CLEAN}

     # check status after a new change has been made to the file
-    with open(chkpt0, "wb") as f:
-        f.write(os.urandom(int(15e5)))
+    torch.save(nn.Linear(1, int(15e5)), chkpt0)

     out = repo.status()
     assert out == {key_list[0]: RepoStatus.CHANGES_NOT_ADDED}
...
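A note on why these fixtures switched from os.urandom() to torch.save(): the new SHA1_Store.add() round-trips file inputs through torch.load() as a validity check, and a blob of random bytes is not a loadable checkpoint. A quick sketch of what would now fail (the file name is hypothetical):

import os

import torch

with open("junk.pt", "wb") as f:
    f.write(os.urandom(1024))

try:
    torch.load("junk.pt")  # the same check SHA1_Store.add() performs
except Exception as error:  # typically an unpickling/zipfile error
    print("rejected:", type(error).__name__)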
(wgit CLI tests)
@@ -9,13 +9,18 @@ from pathlib import Path
 import shutil

 import pytest
+import torch
+from torch import nn

 import fairscale.experimental.wgit.cli as cli
-from fairscale.experimental.wgit.sha1_store import SHA1_store
+from fairscale.experimental.wgit.sha1_store import SHA1_Store


-@pytest.fixture
+@pytest.fixture(scope="module")
 def create_test_dir():
+    """This setup function runs once for this test module and
+    it creates a repo, in the process testing the init function.
+    """
     curr_dir = Path.cwd()
     parent_dir = "experimental"
     test_dir = curr_dir.joinpath(parent_dir, "wgit_testing/")
@@ -30,28 +35,27 @@ def create_test_dir():
     # create random checkpoints
     size_list = [30e5, 35e5, 40e5]
     for i, size in enumerate(size_list):
-        with open(f"checkpoint_{i}.pt", "wb") as f:
-            f.write(os.urandom(int(size)))
+        torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")

-    return test_dir
-
-
-def test_setup(create_test_dir):
+    # Test init.
     cli.main(["init"])
-    assert str(create_test_dir.stem) == "wgit_testing"
+    assert str(test_dir.stem) == "wgit_testing"
+
+    return test_dir


-def test_cli_init(capsys):
+def test_cli_init(create_test_dir, capsys):
     # Check if the json and other files have been created by the init
-    assert Path(".wgit/sha1_refs.json").is_file()
+    assert Path(".wgit/sha1_store").is_dir()
     assert Path(".wgit/.gitignore").is_file()
     assert Path(".wgit/.git").exists()


-def test_cli_add(capsys):
+def test_cli_add(create_test_dir, capsys):
     chkpt0 = "checkpoint_0.pt"
-    cli.main(["add", "checkpoint_0.pt"])
-    sha1_store = SHA1_store(
+    cli.main(["add", chkpt0])
+    sha1_store = SHA1_Store(
         Path.cwd().joinpath(".wgit"),
         init=False,
     )
...
(SHA1_Store tests)
@@ -3,148 +3,111 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass
-import json
 import os
 from pathlib import Path
 import shutil

 import pytest
+import torch
+from torch import nn

-from fairscale.experimental.wgit.repo import Repo
-from fairscale.experimental.wgit.sha1_store import SHA1_store
+from fairscale.experimental.wgit.sha1_store import SHA1_Store
+from fairscale.internal import torch_version

+# Get the absolute path of the parent at the beginning, before any os.chdir(),
+# so that we can properly clean it up from any CWD.
+PARENT_DIR = Path("sha1_store_testing").resolve()

-@pytest.fixture
-def sha1_configs():
-    @dataclass
-    class Sha1StorePaths:
-        test_dirs = Path("temp_wgit_testing/.wgit")
-        test_path = Path.cwd().joinpath(test_dirs)
-        sha1_ref = test_path.joinpath("sha1_refs.json")
-        chkpt1a_dir = test_path.joinpath("checkpoint_1a")
-        chkpt1b_dir = test_path.joinpath("checkpoint_1b")
-        chkpt1c_dir = test_path.joinpath("checkpoint_1c")
-        checkpoint_1a = test_path.joinpath("checkpoint_1a", "checkpoint_1.pt")
-        checkpoint_1b = test_path.joinpath("checkpoint_1b", "checkpoint_1.pt")
-        checkpoint_1c = test_path.joinpath("checkpoint_1c", "checkpoint_1.pt")
-        checkpoint_2 = test_path.joinpath("checkpoint_1a", "checkpoint_2.pt")
-        checkpoint_3 = test_path.joinpath("checkpoint_1a", "checkpoint_3.pt")
-        metadata_1 = test_path.joinpath("checkpoint_1.pt")
-        metadata_2 = test_path.joinpath("checkpoint_2.pt")
-        metadata_3 = test_path.joinpath("checkpoint_3.pt")
-
-    return Sha1StorePaths
-
-
-@pytest.fixture
-def sha1_store(sha1_configs):
-    repo = Repo(sha1_configs.test_path.parent, init=False)
-    sha1_store = SHA1_store(sha1_configs.test_dirs, init=False)
-    return repo, sha1_store
+@pytest.fixture(scope="function")
+def sha1_store(request):
+    """A fixture for setup and teardown.
+
+    This only runs once per test function. So don't make this too slow.
+
+    Tests must be written in a way that either all of the tests run
+    in the order they appear in this file, or a specific test is
+    run separately by the user. Either way, the tests should work.
+    """
+    # Attach a teardown function.
+    def teardown():
+        os.chdir(PARENT_DIR.joinpath("..").resolve())
+        shutil.rmtree(PARENT_DIR, ignore_errors=True)
+
+    request.addfinalizer(teardown)
+
+    # Teardown in case the last run didn't clean up.
+    teardown()
+
+    # Get an empty sha1 store.
+    PARENT_DIR.mkdir()
+    sha1_store = SHA1_Store(PARENT_DIR, init=True)
+
+    return sha1_store

-def test_setup(sha1_configs):
-    # Set up the testing directory
-    sha1_configs.test_dirs.mkdir(parents=True, exist_ok=True)  # create the test .wgit dir
-
-    # Create the test checkpoint dirs
-    sha1_configs.chkpt1a_dir.mkdir(exist_ok=False)
-    sha1_configs.chkpt1b_dir.mkdir(exist_ok=False)
-    sha1_configs.chkpt1c_dir.mkdir(exist_ok=False)
+def test_sha1_add_file(sha1_store):
+    os.chdir(PARENT_DIR)

     # Create random checkpoints
     size_list = [25e5, 27e5, 30e5, 35e5, 40e5]
     chkpts = [
-        sha1_configs.checkpoint_1a,
-        sha1_configs.checkpoint_1b,
-        sha1_configs.checkpoint_1c,
-        sha1_configs.checkpoint_2,
-        sha1_configs.checkpoint_3,
+        "checkpoint_1a.pt",
+        "checkpoint_1b.pt",
+        "checkpoint_1c.pt",
+        "checkpoint_2.pt",
+        "checkpoint_3.pt",
     ]
     for file, size in zip(chkpts, size_list):
-        with open(file, "wb") as f:
-            f.write(os.urandom(int(size)))
-
-    repo = Repo(sha1_configs.test_path.parent, init=True)
-    sha1_store = SHA1_store(sha1_configs.test_dirs, init=True)
-    return sha1_store
-
-
-def test_sha1_add(sha1_configs, sha1_store):
-    repo, sha1_store = sha1_store
-
-    # Add checkpoint_1: create the metadata
-    chkpt1 = sha1_configs.checkpoint_1a
-    metadata_file, parent_sha1 = repo._process_metadata_file(chkpt1.name)
-    sha1_hash = sha1_store.add(sha1_configs.checkpoint_1a, parent_sha1)
-    repo._write_metadata(metadata_file, chkpt1, sha1_hash)
-
-    # for checkpoint 1
-    metadata_file = sha1_configs.test_path.joinpath(sha1_configs.checkpoint_1a.name)
-    with open(metadata_file, "r") as file:
-        metadata = json.load(file)
-    assert metadata["SHA1"]["__sha1_full__"] == sha1_hash
-
-
-def test_sha1_refs(sha1_configs, sha1_store):
-    repo, sha1_store = sha1_store
-
-    def add_checkpoint(checkpoint):
-        metadata_file, parent_sha1 = repo._process_metadata_file(checkpoint.name)
-        sha1_hash = sha1_store.add(checkpoint, parent_sha1)
-        repo._write_metadata(metadata_file, checkpoint, sha1_hash)
-        return sha1_hash
-
-    with open(sha1_configs.sha1_ref, "r") as file:
-        refs_data = json.load(file)
-
-    # get checkpoint1's sha1
-    sha1_chkpt1a_hash = sha1_store._get_sha1_hash(sha1_configs.checkpoint_1a)
-    assert refs_data[sha1_chkpt1a_hash]["parent"] == "ROOT"
-    assert refs_data[sha1_chkpt1a_hash]["ref_count"] == 1
-
-    ck1a_sha1_hash = sha1_store._get_sha1_hash(sha1_configs.checkpoint_1a)
-
-    # add a new version of checkpoint-1
-    ck1b_sha1_hash = add_checkpoint(sha1_configs.checkpoint_1b)
-
-    # add new checkpoints 2 and 3
-    ck2_sha1_hash = add_checkpoint(sha1_configs.checkpoint_2)
-    ck3_sha1_hash = add_checkpoint(sha1_configs.checkpoint_3)
-
-    # add another version of checkpoint 1
-    ck1c_sha1_hash = add_checkpoint(sha1_configs.checkpoint_1c)
-
-    # load the ref file after the sha1 adds
-    with open(sha1_configs.sha1_ref, "r") as file:
-        refs_data = json.load(file)
-
-    # Tests for same-file versions
-    assert refs_data[ck1b_sha1_hash]["parent"] == ck1a_sha1_hash
-    assert refs_data[ck1c_sha1_hash]["parent"] == ck1b_sha1_hash
-    assert refs_data[ck1b_sha1_hash]["ref_count"] == 1
-    assert refs_data[ck1a_sha1_hash]["is_leaf"] is False
-    assert refs_data[ck1b_sha1_hash]["is_leaf"] is False
-    assert refs_data[ck1c_sha1_hash]["is_leaf"] is True
-
-    # Tests for new files
-    assert refs_data[ck2_sha1_hash]["parent"] == "ROOT"
-    assert refs_data[ck2_sha1_hash]["is_leaf"] is True
-    assert refs_data[ck3_sha1_hash]["parent"] == "ROOT"
-    assert refs_data[ck3_sha1_hash]["is_leaf"] is True
-
-
-def test_tear_down(sha1_configs):
-    # Clean up: delete the .wgit directory created during this test.
-    # Make sure the current directory is ./temp_wgit_testing before removing the test dir.
-    test_parent_dir = sha1_configs.test_path.parent
-    if (test_parent_dir.stem == "temp_wgit_testing") and (sha1_configs.test_path.stem == ".wgit"):
-        shutil.rmtree(test_parent_dir)
-    else:
-        raise Exception("Exception in testing directory tear down!")
+    for file, size in zip(chkpts, size_list):
+        torch.save(nn.Linear(1, int(size)), file)
+
+    # Add those 5 random files.
+    for c in chkpts:
+        sha1_store.add(c)
+
+    # Add a fixed data twice.
+    module = nn.Linear(100, 100, bias=False)
+    module.weight.data = torch.zeros(100, 100)
+    zeros_file = "zeros.pt"
+    torch.save(module.state_dict(), zeros_file)
+    sha1_store.add(zeros_file)
+    sha1_store.add(zeros_file)
+
+    # Assert the ref counts are 1, 1, 1, 1, 1 and 2.
+    sha1_store._load_json_dict()
+    json_dict = sha1_store._json_dict
+    if torch_version() >= (1, 9, 0):
+        # torch 1.8 LTS doesn't produce deterministic checkpoint files from fixed tensors/state_dicts.
+        key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
+        assert key in json_dict.keys() and json_dict[key] == 2, json_dict
+    del json_dict["created_on"]
+    assert sorted(json_dict.values()) == [1, 1, 1, 1, 1, 2], json_dict
+
+
+def test_sha1_add_state_dict(sha1_store):
+    os.chdir(PARENT_DIR)
+    # add once
+    for i in range(3):
+        sha1_store.add(nn.Linear(10, 10).state_dict())
+    # add twice
+    for i in range(3):
+        sd = nn.Linear(8, 8).state_dict()
+        sha1_store.add(sd)
+        sha1_store.add(sd)
+
+    sha1_store._load_json_dict()
+    json_dict = sha1_store._json_dict
+    del json_dict["created_on"]
+    assert sorted(json_dict.values()) == [1, 1, 1, 2, 2, 2], json_dict
+
+
+def test_sha1_add_tensor(sha1_store):
+    os.chdir(PARENT_DIR)
+    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]))
+    sha1_store._load_json_dict()
+    json_dict = sha1_store._json_dict
+    if torch_version() >= (1, 9, 0):
+        # torch 1.8 LTS doesn't produce deterministic checkpoint files from fixed tensors/state_dicts.
+        key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
+        assert key in json_dict.keys() and json_dict[key] == 1, json_dict
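For reference, the assertions in test_sha1_add_file imply a ref_count.json of the following shape: the created_on timestamp written at init time, plus one counter per unique SHA1. The values below are illustrative except for the hash pinned by the test; the five random checkpoints hash differently on every run:

expected_shape = {
    "created_on": "Mon Jul 11 10:00:00 2022",  # illustrative timestamp
    "da3e19590de8f77fcf7a09c888c526b0149863a0": 2,  # zeros.pt added twice (torch >= 1.9)
    # ... five more entries, one per random checkpoint, each with count 1 ...
}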