Unverified commit 68af57d8 authored by Min Xu, committed by GitHub

[feat] Sha1 Store enhancements (#1026)



* refactor SHA1_Store

- renamed the class
- added created_on field and refactored how init is done
- wrap long lines

* wrapped longer lines

* rename json file to ref_count.json

* make sha1_buf_size an argument

* update gitignore

* added tmp_dir

* added new sha1_store add and tests

* update chdir

* add debug to test

* fixing unit test for 1.8
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 03f3e833
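
At a glance, the reworked SHA1_Store API added by this commit can be used as follows. This is a minimal sketch, not code from the diff: the paths and sizes are hypothetical, and get()/delete() are still NotImplementedError stubs at this point; it only exercises the add() flavors introduced here.

from pathlib import Path

import torch
from torch import nn

from fairscale.experimental.wgit.sha1_store import SHA1_Store

parent = Path("demo_store")  # hypothetical location; must exist before init
parent.mkdir(exist_ok=True)
store = SHA1_Store(parent, init=True)

# add() accepts a path to a torch-saved file, a tensor, or a state_dict,
# and returns the content's SHA1 hex digest.
torch.save(nn.Linear(4, 4).state_dict(), "ckpt.pt")
sha1_file = store.add(Path("ckpt.pt"))             # on-disk, torch-saved file
sha1_tensor = store.add(torch.zeros(10))           # in-memory tensor
sha1_sd = store.add(nn.Linear(4, 4).state_dict())  # in-memory state_dict

# Adding identical content again only bumps its ref count (de-duplication);
# per the tests below, serialized bytes are deterministic on torch >= 1.9.
assert store.add(torch.zeros(10)) == sha1_tensor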
......@@ -12,10 +12,6 @@
*.egg-info/
.testmondata
# experimental weigit
fairscale/experimental/wgit/dev
.wgit
# Build and release
build/
dist/
......
......@@ -11,7 +11,7 @@ import sys
from typing import Dict, Tuple, Union
from .pygit import PyGit
from .sha1_store import SHA1_store
from .sha1_store import SHA1_Store
class Repo:
......@@ -42,28 +42,27 @@ class Repo:
exists = self._exists(self.wgit_parent)
if not exists and init:
# No weigit repo exists and is being initialized with init=True
# Make .wgit directory, create sha1_refs
# Make .wgit directory, create sha1_store
weigit_dir = self.wgit_parent.joinpath(".wgit")
weigit_dir.mkdir(parents=False, exist_ok=True)
# Initializing sha1_store only after wgit has been initialized!
self._sha1_store = SHA1_store(weigit_dir, init=True)
self._sha1_store = SHA1_Store(weigit_dir, init=True)
# Make the .wgit a git repo
gitignore_files = [
self._sha1_store_path.name,
self._sha1_store.ref_file_path.name,
]
self._pygit = PyGit(weigit_dir, gitignore=gitignore_files)
elif exists and init:
# if weigit repo already exists and init is being called, wrap the existing .wgit/.git repo with PyGit
self._sha1_store = SHA1_store(self.path)
self._sha1_store = SHA1_Store(self.path)
self._pygit = PyGit(self.path)
elif exists and not init:
# weigit exists and non-init commands are triggered
self._sha1_store = SHA1_store(self.path)
self._sha1_store = SHA1_Store(self.path)
self._pygit = PyGit(self.path)
else:
......@@ -86,7 +85,10 @@ class Repo:
metadata_file, parent_sha1 = self._process_metadata_file(rel_file_path)
# add the file to the sha1_store
sha1_hash = self._sha1_store.add(file_path, parent_sha1)
# TODO (Min): We don't add parent sha1 tracking to the sha1 store because
# de-duplication combined with dependency tracking can create cycles.
# We need to figure out a way to handle deletion.
sha1_hash = self._sha1_store.add(file_path)
# write metadata to the metadata-file
self._write_metadata(metadata_file, file_path, sha1_hash)
......@@ -183,7 +185,7 @@ class Repo:
metadata_d = dict()
for file in self.path.iterdir(): # iterate over the .wgit directory
# exclude all the .wgit files and directories
if file.name not in {"sha1_store", "sha1_refs.json", ".git", ".gitignore"}:
if file.name not in {"sha1_store", ".git", ".gitignore"}:
# perform a directory walk on the metadata_file directories to find the metadata files
for path in file.rglob("*"):
if path.is_file():
......@@ -275,16 +277,15 @@ class Repo:
def _weigit_repo_exists(self, check_dir: pathlib.Path) -> bool:
"""Returns True if a valid WeiGit repo exists in the path: check_dir"""
wgit_exists, sha1_refs, git_exists, gitignore_exists = self._weight_repo_file_check(check_dir)
return wgit_exists and sha1_refs and git_exists and gitignore_exists
wgit_exists, git_exists, gitignore_exists = self._weight_repo_file_check(check_dir)
return wgit_exists and git_exists and gitignore_exists
def _weight_repo_file_check(self, check_dir: Path) -> tuple:
"""Returns a tuple of boolean corresponding to the existence of each .wgit internally required files."""
wgit_exists = check_dir.joinpath(".wgit").exists()
sha1_refs = check_dir.joinpath(".wgit/sha1_refs.json").exists()
git_exists = check_dir.joinpath(".wgit/.git").exists()
gitignore_exists = check_dir.joinpath(".wgit/.gitignore").exists()
return wgit_exists, sha1_refs, git_exists, gitignore_exists
return wgit_exists, git_exists, gitignore_exists
class RepoStatus(Enum):
......
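
The TODO in Repo.add() above deserves a concrete illustration. Under content addressing, two versions with identical bytes share one SHA1, so a naive parent-pointer graph can loop back on itself. A minimal sketch of the failure mode (hypothetical contents, not wgit code):

import hashlib

def sha1(data: bytes) -> str:
    return hashlib.sha1(data).hexdigest()

# Three versions of one file; v3 happens to be byte-identical to v1.
v1, v2, v3 = b"weights-A", b"weights-B", b"weights-A"

parent = {}                  # maps child sha1 -> parent sha1
parent[sha1(v2)] = sha1(v1)  # v2 derives from v1
parent[sha1(v3)] = sha1(v2)  # but sha1(v3) == sha1(v1), so this edge closes
                             # the cycle A -> B -> A; any walk that follows
                             # parents to find deletable ancestors would
                             # never terminate.
assert sha1(v3) == sha1(v1)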
......@@ -4,160 +4,257 @@
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import hashlib
import json
from pathlib import Path
import shutil
import sys
from typing import Union
import tempfile
import time
from typing import Any, Dict, Union, cast
import torch
from torch import Tensor
from .utils import ExitCode
# This is a fixed dir name we use for sha1_store. It should not be changed
# for backward compatibility reasons.
SHA1_STORE_DIR_NAME = "sha1_store"
class SHA1_store:
class SHA1_Store:
"""
Represent the sha1_store within the WeiGit repo for handling files added to the store and managing references.
This class represents a SHA1 checksum based storage dir for state_dict
and tensors.
This means the same content will not be stored multiple times, resulting
in space savings (a.k.a. de-duplication).
To make things easier for the callers, this class accepts input data
as files, state_dict or tensors. This class always returns in-memory
data, not on-disk files. This class doesn't really care or know the actual
data types. It uses torch.save() and torch.load() to do serialization.
A key issue is dealing with content deletion. We use a reference counting
algorithm, which means the caller must have symmetrical add/remove calls
for each object.
We used to support a children-parent dependency graph and ref counting, but
it is flawed since a grand-child can have the same SHA1 as the grand-parent,
resulting in a cycle. This means the caller must compute which parent is safe
to delete in a version tracking graph. The lesson here is that content
addressability and dependency graphs do not mix well.
Args:
weigit_path (pathlib.Path)
The path to the weigit repo where a sha1_store will be created, or if already exists will be wrapped.
init (bool, optional)
- If ``True``, initializes a new sha1_store in the weigit_path. Initialization creates a `sha1_store` directory within the WeiGit repo in ./<weigit_path>/,
and a `sha1_refs.json` within ./<weigit_path>/.
- If ``False``, a new sha1_store is not initialized and the existing sha1_store is simply wrapped, populating the `name`, `path` and the `ref_file_path` attributes.
parent_path (Path):
The parent path in which a SHA1_Store will be created.
init (bool, optional):
- If ``True``, initializes a new SHA1_Store in the parent_path. Initialization
creates a `sha1_store` directory in ./<parent_path>/,
and a `ref_count.json` within ./<parent_path>/.
- If ``False``, a new `sha1_store` dir is not initialized and the existing
`sha1_store` is used to init this class, populating `_json_dict`, and other
attributes.
- Default: False
sha1_buf_size (int):
Buffer size used for checksumming. Default: 100MB.
tmp_dir (str):
Dir for temporary files if input is an in-memory object.
"""
def __init__(self, weigit_path: Path, init: bool = False) -> None:
"""Create or wrap (if already exists) a sha1_store within the WeiGit repo."""
# should use the sha1_refs.json to track the parent references.
self.name = "sha1_store"
self.path = weigit_path.joinpath(self.name)
self.ref_file_path = weigit_path.joinpath("sha1_refs.json")
def __init__(
self, parent_path: Path, init: bool = False, sha1_buf_size: int = 100 * 1024 * 1024, tmp_dir: str = ""
) -> None:
"""Create or wrap (if already exists) a sha1_store."""
self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
self._ref_file_path = self._path.joinpath("ref_count.json")
self._sha1_buf_size = sha1_buf_size
self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}
self._weigit_path = weigit_path
# initialize the sha1_store
if init:
# Initialize the sha1_store if it doesn't exist and init==True.
if init and not self._path.exists():
try:
if not self.path.exists():
Path.mkdir(self.path, parents=False, exist_ok=False)
self.ref_file_path.touch(exist_ok=False)
Path.mkdir(self._path, parents=False, exist_ok=False)
except FileExistsError as error:
sys.stderr.write(f"An exception occured while creating Sha1_store: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Create a new json file for this new store.
self._store_json_dict()
def add(self, file_path: Path, parent_sha1: str) -> str:
# This is an internal error since the caller of this is our own wgit code.
assert self._path.exists(), "SHA1 store does not exist and init==False"
# Init temp dir.
if tmp_dir:
# Caller supplied tmp dir
assert Path(tmp_dir).is_dir(), "incorrect input"
self._tmp_dir = Path(tmp_dir)
else:
# Default tmp dir, need to clean it.
self._tmp_dir = self._path.joinpath("tmp")
shutil.rmtree(self._tmp_dir, ignore_errors=True)
self._tmp_dir.mkdir()
def _load_json_dict(self) -> None:
"""Loading json dict from disk."""
with open(self._ref_file_path, "r") as f:
self._json_dict = json.load(f)
def _store_json_dict(self) -> None:
"""Storing json dict to disk."""
with open(self._ref_file_path, "w", encoding="utf-8") as f:
json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
"""
Adds a file/checkpoint to the internal sha1_store and the sha1 references accordingly.
First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file in <in_file_path> is moved
within the sha1_store and the sha1 reference file is updated accordingly with the information of its parent
node (if it exists) and whether the new version is a leaf node or not.
Adds a file/object to the internal sha1_store and updates the sha1 reference
count accordingly.
First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
in <file_or_obj> is copied into the sha1_store and the reference file is updated.
If the input is an object, it is first stored in self._tmp_dir and then moved.
Args:
in_file_path (str): path to the file to be added to the sha1_store.
file_or_obj (str or tensor or OrderedDict):
Path to the file to be added to the sha1_store or an in-memory object
that can be handled by torch.save.
"""
# Use `isinstance` not type() == Path since pathlib returns OS specific
# Path types, which inherit from the Path class.
if isinstance(file_or_obj, (Path, str)):
# Make sure it is a valid file.
torch.load(cast(Union[Path, str], file_or_obj))
file_path = Path(file_or_obj)
remove_tmp = False
elif isinstance(file_or_obj, (Tensor, OrderedDict)):
# Serialize the object into a tmp file.
file_path = self._get_tmp_file_path()
torch.save(cast(Union[Tensor, OrderedDict], file_or_obj), file_path)
remove_tmp = True
else:
assert False, f"incorrect input {type(file_or_obj)}"
# Get SHA1 from the file.
assert isinstance(file_path, Path), type(file_path)
sha1_hash = self._get_sha1_hash(file_path)
# use the sha1_hash to create a directory with the first-2 sha naming convention
# Add reference.
ref_count = self._add_ref(sha1_hash, True)
if ref_count == 1:
# First time adding
# Create the file dir, if needed.
repo_fdir = self._sha1_to_dir(sha1_hash)
if not repo_fdir.exists():
try:
repo_fdir = self.path.joinpath(sha1_hash[:2])
repo_fdir.mkdir(exist_ok=True)
repo_fdir.mkdir(exist_ok=True, parents=True)
except FileExistsError as error:
sys.stderr.write(f"An exception occured: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Transfer the file to the internal sha1_store
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
# First transfer the file to the internal sha1_store
repo_fpath = Path.cwd().joinpath(repo_fdir, sha1_hash[2:])
shutil.copy2(file_path, repo_fpath)
# Create the dependency Graph and track reference
self._add_ref(sha1_hash, parent_sha1)
except BaseException as error:
# in case of failure: Cleans up the sub-directories created to store sha1-named checkpoints
# Something went wrong, perhaps out of space, or a race condition due to lack of locking.
# TODO (Min): properly handle the error and recover when we learn more here.
sys.stderr.write(f"An exception occured: {repr(error)}\n")
shutil.rmtree(repo_fdir)
ref_count = self._add_ref(sha1_hash, False)
# Clean up if needed.
if remove_tmp:
file_path.unlink()
return sha1_hash
def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""return the sha1 hash of a file
def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
"""Get data from a SHA1
Args:
sha1 (str):
SHA1 of the object to get.
Returns:
(Tensor or OrderedDict):
In-memory object.
"""
raise NotImplementedError()
def delete(self, sha1: str) -> None:
"""Delete a SHA1
Args:
file_path (str, Path): Path to the file whose sha1 hash is to be calculated and returned.
sha1 (str):
SHA1 of the object to delete.
"""
SHA1_BUF_SIZE = 104857600 # Reading file in 100MB chunks
raise NotImplementedError()
def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""Return the sha1 hash of a file
Args:
file_path (str, Path):
Path to the file whose sha1 hash is to be calculated and returned.
Returns:
(str):
The SHA1 computed.
"""
sha1 = hashlib.sha1()
with open(file_path, "rb") as f:
while True:
data = f.read(SHA1_BUF_SIZE)
data = f.read(self._sha1_buf_size)
if not data:
break
sha1.update(data)
return sha1.hexdigest()
def _add_ref(self, current_sha1_hash: str, parent_hash: str) -> None:
def _get_tmp_file_path(self) -> Path:
"""Helper to get a tmp file name under self.tmp_dir."""
return Path(tempfile.mkstemp(dir=self._tmp_dir)[1])
def _sha1_to_dir(self, sha1: str) -> Path:
"""Helper to get the internal dir for a file based on its SHA1"""
# Using first 2 letters of the sha1, which results 26 * 26 = 676 subdirs under the top
# level. Then, using another 2 letters for sub-sub-dir. If each dir holds 1000 files, this
# can hold 450 millions files.
# NOTE: this can NOT be changed for backward compatible reasons once in production.
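# Example (hypothetical digest, illustration only): a sha1 starting with
# "da3e" is stored under <parent_path>/sha1_store/da/3e/<full sha1>.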
assert len(sha1) > 4, "sha1 too short"
part1, part2 = sha1[:2], sha1[2:4]
return self._path.joinpath(part1, part2)
def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:
"""
Populates the sha1_refs.json file when a file is added and keeps track of references to earlier file additions.
If the sha1_refs.json file is empty, a new tracking entry for the added file is logged in the sha1_refs file.
If the file already has an entry, it first checks whether the newly added file is a new version of any of the
existing entries. If it is, it logs the tracking info as a new version of that existing entry.
Otherwise a new entry for the added file is created for tracking.
Update the reference count.
If the reference counting file does not have this sha1, a new tracking
entry is created for it.
Args:
file_path (pathlib.Path)
Path to the incoming added file.
current_sha1_hash (str)
current_sha1_hash (str):
The sha1 hash of the incoming added file.
inc (bool):
Increment or decrement.
Returns:
(int):
Resulting ref count.
"""
# Check the current state of the reference file and check if the added file already has an entry.
sha1_refs_empty = self._sha1_refs_file_state()
# if the file is empty: add the first entry
if sha1_refs_empty:
with open(self.ref_file_path) as f:
ref_data = {current_sha1_hash: {"parent": "ROOT", "ref_count": 1, "is_leaf": True}}
self._write_to_json(self.ref_file_path, ref_data)
else:
# Open the sha1 reference file and check if there is a parent_hash not equal to ROOT.
# If yes, find the parent and add the child. Else, just add a new entry.
with open(self.ref_file_path, "r") as f:
ref_data = json.load(f)
if parent_hash != "ROOT":
# get the last head and replace its child from HEAD -> this sha1
ref_data[parent_hash]["is_leaf"] = False
ref_data[current_sha1_hash] = {"parent": parent_hash, "ref_count": 1, "is_leaf": True}
else:
ref_data[current_sha1_hash] = {"parent": "ROOT", "ref_count": 1, "is_leaf": True}
self._load_json_dict()
self._write_to_json(self.ref_file_path, ref_data)
# Init the entry if needed.
if current_sha1_hash not in self._json_dict:
self._json_dict[current_sha1_hash] = 0
def _sha1_refs_file_state(self) -> bool:
"""
Checks the state of the sha1 reference file, whether the file is empty or not.
If not empty, it checks whether the input file in <file_path> has an older entry (version)
in the reference file.
# Update the ref count.
self._json_dict[current_sha1_hash] += 1 if inc else -1
assert self._json_dict[current_sha1_hash] >= 0, "negative ref count"
Args:
file_path (pathlib.Path)
Input file whose entry will be checked in the reference file.
"""
try:
with open(self.ref_file_path, "r") as f:
ref_data = json.load(f)
sha1_refs_empty: bool = False
except json.JSONDecodeError as error:
if not self.ref_file_path.stat().st_size:
sha1_refs_empty = True
return sha1_refs_empty
def _write_to_json(self, file: Path, data: dict) -> None:
"""
Populates a json file with data.
Args:
file (pathlib.Path)
path to the file to be written in.
data (dict)
Data to be written in the file.
"""
with open(file, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=4)
self._store_json_dict()
return self._json_dict[current_sha1_hash]
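
With ref_count.json now a flat {sha1: count} map (plus the created_on field), deletion safety reduces to symmetric bookkeeping. Below is a small sketch of the expected counts, mirroring how the tests later in this diff peek at the private _json_dict; the exact digest is torch-version dependent, so it isn't hard-coded here.

from pathlib import Path
import tempfile

import torch
from fairscale.experimental.wgit.sha1_store import SHA1_Store

store = SHA1_Store(Path(tempfile.mkdtemp()), init=True)

sha1 = store.add(torch.ones(3))
store.add(torch.ones(3))  # identical bytes (on torch >= 1.9) -> same sha1

store._load_json_dict()   # private, but this is how the tests inspect it
assert store._json_dict[sha1] == 2
# On a failed copy, add() rolls the count back via _add_ref(sha1, False);
# a count of 0 is what a future delete() can treat as garbage-collectable.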
import os
import pickle
from pathlib import Path
from typing import Any, BinaryIO, Callable, IO, Union
DEFAULT_PROTOCOL: int = 2
......@@ -7,4 +8,4 @@ DEFAULT_PROTOCOL: int = 2
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
pickle_module: Any=pickle, pickle_protocol: int=DEFAULT_PROTOCOL, _use_new_zipfile_serialization: bool=True) -> None: ...
def load(f: Union[str, BinaryIO], map_location) -> Any: ...
def load(f: Union[str, BinaryIO, Path], map_location=None) -> Any: ...
......@@ -10,9 +10,10 @@ import random
import shutil
import pytest
import torch
from torch import nn
from fairscale.experimental.wgit import repo as api
from fairscale.experimental.wgit.repo import RepoStatus
from fairscale.experimental.wgit.repo import Repo, RepoStatus
@pytest.fixture
......@@ -32,14 +33,13 @@ def create_test_dir():
# create random checkpoints
size_list = [30e5, 35e5, 40e5, 40e5]
for i, size in enumerate(size_list):
with open(f"checkpoint_{i}.pt", "wb") as f:
f.write(os.urandom(int(size)))
torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
return test_dir
@pytest.fixture
def repo():
repo = api.Repo(Path.cwd(), init=True)
repo = Repo(Path.cwd(), init=True)
return repo
......@@ -48,8 +48,8 @@ def test_setup(create_test_dir):
def test_api_init(capsys, repo):
repo = api.Repo(Path.cwd(), init=True)
assert Path(".wgit/sha1_refs.json").is_file()
repo = Repo(Path.cwd(), init=True)
assert Path(".wgit/sha1_store").is_dir()
assert Path(".wgit/.gitignore").is_file()
assert Path(".wgit/.git").exists()
assert Path(".wgit/.gitignore").exists()
......@@ -80,7 +80,7 @@ def test_api_commit(capsys, repo):
def test_api_status(capsys, repo):
# delete the repo and initialize a new one:
shutil.rmtree(".wgit")
repo = api.Repo(Path.cwd(), init=True)
repo = Repo(Path.cwd(), init=True)
# check status before any file is added
out = repo.status()
......@@ -99,8 +99,7 @@ def test_api_status(capsys, repo):
assert out == {key_list[0]: RepoStatus.CLEAN}
# check status after a new change has been made to the file
with open(chkpt0, "wb") as f:
f.write(os.urandom(int(15e5)))
torch.save(nn.Linear(1, int(15e5)), chkpt0)
out = repo.status()
assert out == {key_list[0]: RepoStatus.CHANGES_NOT_ADDED}
......
......@@ -9,13 +9,18 @@ from pathlib import Path
import shutil
import pytest
import torch
from torch import nn
import fairscale.experimental.wgit.cli as cli
from fairscale.experimental.wgit.sha1_store import SHA1_store
from fairscale.experimental.wgit.sha1_store import SHA1_Store
@pytest.fixture
@pytest.fixture(scope="module")
def create_test_dir():
"""This setup function runs once per test of this module and
it creates a repo, in the process, testing the init function.
"""
curr_dir = Path.cwd()
parent_dir = "experimental"
test_dir = curr_dir.joinpath(parent_dir, "wgit_testing/")
......@@ -30,28 +35,27 @@ def create_test_dir():
# create random checkpoints
size_list = [30e5, 35e5, 40e5]
for i, size in enumerate(size_list):
with open(f"checkpoint_{i}.pt", "wb") as f:
f.write(os.urandom(int(size)))
return test_dir
torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
def test_setup(create_test_dir):
# Test init.
cli.main(["init"])
assert str(create_test_dir.stem) == "wgit_testing"
assert str(test_dir.stem) == "wgit_testing"
return test_dir
def test_cli_init(capsys):
def test_cli_init(create_test_dir, capsys):
# Check if the json and other files have been created by the init
assert Path(".wgit/sha1_refs.json").is_file()
assert Path(".wgit/sha1_store").is_dir()
assert Path(".wgit/.gitignore").is_file()
assert Path(".wgit/.git").exists()
def test_cli_add(capsys):
def test_cli_add(create_test_dir, capsys):
chkpt0 = "checkpoint_0.pt"
cli.main(["add", "checkpoint_0.pt"])
cli.main(["add", chkpt0])
sha1_store = SHA1_store(
sha1_store = SHA1_Store(
Path.cwd().joinpath(".wgit"),
init=False,
)
......
......@@ -3,148 +3,111 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import json
import os
from pathlib import Path
import shutil
import pytest
import torch
from torch import nn
from fairscale.experimental.wgit.repo import Repo
from fairscale.experimental.wgit.sha1_store import SHA1_store
from fairscale.experimental.wgit.sha1_store import SHA1_Store
from fairscale.internal import torch_version
# Get the absolute path of the parent at the beginning, before any os.chdir(),
# so that we can properly clean it up from any CWD.
PARENT_DIR = Path("sha1_store_testing").resolve()
@pytest.fixture
def sha1_configs():
@dataclass
class Sha1StorePaths:
test_dirs = Path("temp_wgit_testing/.wgit")
test_path = Path.cwd().joinpath(test_dirs)
sha1_ref = test_path.joinpath("sha1_refs.json")
chkpt1a_dir = test_path.joinpath("checkpoint_1a")
chkpt1b_dir = test_path.joinpath("checkpoint_1b")
chkpt1c_dir = test_path.joinpath("checkpoint_1c")
checkpoint_1a = test_path.joinpath("checkpoint_1a", "checkpoint_1.pt")
checkpoint_1b = test_path.joinpath("checkpoint_1b", "checkpoint_1.pt")
checkpoint_1c = test_path.joinpath("checkpoint_1c", "checkpoint_1.pt")
checkpoint_2 = test_path.joinpath("checkpoint_1a", "checkpoint_2.pt")
checkpoint_3 = test_path.joinpath("checkpoint_1a", "checkpoint_3.pt")
metadata_1 = test_path.joinpath("checkpoint_1.pt")
metadata_2 = test_path.joinpath("checkpoint_2.pt")
metadata_3 = test_path.joinpath("checkpoint_3.pt")
return Sha1StorePaths
@pytest.fixture(scope="function")
def sha1_store(request):
"""A fixture for setup and teardown.
This runs once per test function, so don't make it too slow.
@pytest.fixture
def sha1_store(sha1_configs):
repo = Repo(sha1_configs.test_path.parent, init=False)
sha1_store = SHA1_store(sha1_configs.test_dirs, init=False)
return repo, sha1_store
Tests must be written in a way that either all of the tests run
in the order they appear in this file or a specific test is
run separately by the user. Either way, the tests should work.
"""
# Attach a teardown function.
def teardown():
os.chdir(PARENT_DIR.joinpath("..").resolve())
shutil.rmtree(PARENT_DIR, ignore_errors=True)
request.addfinalizer(teardown)
def test_setup(sha1_configs):
# Set up the testing directory
sha1_configs.test_dirs.mkdir(parents=True, exist_ok=True) # create test .wgit dir
# Teardown in case last run didn't clean it up.
teardown()
# Create the test checkpoint files
sha1_configs.chkpt1a_dir.mkdir(exist_ok=False)
sha1_configs.chkpt1b_dir.mkdir(exist_ok=False)
sha1_configs.chkpt1c_dir.mkdir(exist_ok=False)
# Get an empty sha1 store.
PARENT_DIR.mkdir()
sha1_store = SHA1_Store(PARENT_DIR, init=True)
return sha1_store
def test_sha1_add_file(sha1_store):
os.chdir(PARENT_DIR)
# Create random checkpoints
size_list = [25e5, 27e5, 30e5, 35e5, 40e5]
chkpts = [
sha1_configs.checkpoint_1a,
sha1_configs.checkpoint_1b,
sha1_configs.checkpoint_1c,
sha1_configs.checkpoint_2,
sha1_configs.checkpoint_3,
"checkpoint_1a.pt",
"checkpoint_1b.pt",
"checkpoint_1c.pt",
"checkpoint_2.pt",
"checkpoint_3.pt",
]
for file, size in zip(chkpts, size_list):
with open(file, "wb") as f:
f.write(os.urandom(int(size)))
repo = Repo(sha1_configs.test_path.parent, init=True)
sha1_store = SHA1_store(sha1_configs.test_dirs, init=True)
return sha1_store
def test_sha1_add(sha1_configs, sha1_store):
repo, sha1_store = sha1_store
# Add checkpoint_1: Create the meta_data
chkpt1 = sha1_configs.checkpoint_1a
metadata_file, parent_sha1 = repo._process_metadata_file(chkpt1.name)
sha1_hash = sha1_store.add(sha1_configs.checkpoint_1a, parent_sha1)
repo._write_metadata(metadata_file, chkpt1, sha1_hash)
# for checkpoint 1
metadata_file = sha1_configs.test_path.joinpath(sha1_configs.checkpoint_1a.name)
with open(metadata_file, "r") as file:
metadata = json.load(file)
assert metadata["SHA1"]["__sha1_full__"] == sha1_hash
def test_sha1_refs(sha1_configs, sha1_store):
repo, sha1_store = sha1_store
def add_checkpoint(checkpoint):
metadata_file, parent_sha1 = repo._process_metadata_file(checkpoint.name)
sha1_hash = sha1_store.add(checkpoint, parent_sha1)
repo._write_metadata(metadata_file, checkpoint, sha1_hash)
return sha1_hash
with open(sha1_configs.sha1_ref, "r") as file:
refs_data = json.load(file)
# get checkpoint1 sha1
sha1_chkpt1a_hash = sha1_store._get_sha1_hash(sha1_configs.checkpoint_1a)
assert refs_data[sha1_chkpt1a_hash]["parent"] == "ROOT"
assert refs_data[sha1_chkpt1a_hash]["ref_count"] == 1
ck1a_sha1_hash = sha1_store._get_sha1_hash(sha1_configs.checkpoint_1a)
# add a new version of checkpoint-1
ck1b_sha1_hash = add_checkpoint(sha1_configs.checkpoint_1b)
# Add new checkpoints 2 and 3
ck2_sha1_hash = add_checkpoint(sha1_configs.checkpoint_2)
ck3_sha1_hash = add_checkpoint(sha1_configs.checkpoint_3)
# add another version of checkpoint 1
ck1c_sha1_hash = add_checkpoint(sha1_configs.checkpoint_1c)
# load ref file after Sha1 add
with open(sha1_configs.sha1_ref, "r") as file:
refs_data = json.load(file)
# Tests for same file versions
assert refs_data[ck1b_sha1_hash]["parent"] == ck1a_sha1_hash
assert refs_data[ck1c_sha1_hash]["parent"] == ck1b_sha1_hash
assert refs_data[ck1b_sha1_hash]["ref_count"] == 1
assert refs_data[ck1a_sha1_hash]["is_leaf"] is False
assert refs_data[ck1a_sha1_hash]["is_leaf"] is False
assert refs_data[ck1b_sha1_hash]["is_leaf"] is False
assert refs_data[ck1c_sha1_hash]["is_leaf"] is True
# Tests for new files
assert refs_data[ck2_sha1_hash]["parent"] == "ROOT"
assert refs_data[ck2_sha1_hash]["is_leaf"] is True
assert refs_data[ck3_sha1_hash]["parent"] == "ROOT"
assert refs_data[ck3_sha1_hash]["is_leaf"] is True
def test_tear_down(sha1_configs):
# clean up: delete the .wgit directory created during this Test
# Making sure the current directory is ./temp_wgit_testing before removing test dir
test_parent_dir = sha1_configs.test_path.parent
if (test_parent_dir.stem == "temp_wgit_testing") and (sha1_configs.test_path.stem == ".wgit"):
shutil.rmtree(test_parent_dir)
else:
raise Exception("Exception in testing directory tear down!")
torch.save(nn.Linear(1, int(size)), file)
# Add those 5 random files.
for c in chkpts:
sha1_store.add(c)
# Add a fixed data twice.
module = nn.Linear(100, 100, bias=False)
module.weight.data = torch.zeros(100, 100)
zeros_file = "zeros.pt"
torch.save(module.state_dict(), zeros_file)
sha1_store.add(zeros_file)
sha1_store.add(zeros_file)
# Assert the ref counts are 1,1,1,1,1 and 2
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint files from fixed tensors/state_dicts.
key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
assert key in json_dict.keys() and json_dict[key] == 2, json_dict
del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 1, 1, 2], json_dict
def test_sha1_add_state_dict(sha1_store):
os.chdir(PARENT_DIR)
# add once
for i in range(3):
sha1_store.add(nn.Linear(10, 10).state_dict())
# add twice
for i in range(3):
sd = nn.Linear(8, 8).state_dict()
sha1_store.add(sd)
sha1_store.add(sd)
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 2, 2, 2], json_dict
def test_sha1_add_tensor(sha1_store):
os.chdir(PARENT_DIR)
sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]))
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint files from fixed tensors/state_dicts.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
assert key in json_dict.keys() and json_dict[key] == 1, json_dict