Unverified Commit c506e7ed authored by Riyasat Ohib's avatar Riyasat Ohib Committed by GitHub
Browse files

Addition of wgit add and wgit commit functionalities. Includes refactors and new classes. (#1002)

* [feat] Adds the implementation for the wgit add functionality, with sha1 hash creation, reference tracking, dependency graph creation and all related functionalities for the wgit add method.

* [feat] Adds the wgit add and wgit commit functionalities and major refactors.

1. Adds the wgit add and wgit commit functionalities to the api.
2. Introduces a new PyGit class that wraps the internal .wgit/.git repo.
3. Refactors the Repo class in the api, and introduces some methods.
4. Refactors all the classes, which no longer use @staticmethods and now use object instances instead.
5.  Moved many of the directory path handling code from os.path to pathlib library.

* [Feat] Combines the Repo and Weigit classes. Separate classes into separate modules.

1. Combines the functionalities of the WeiGit and Repo class into a single WeiGitRepo class.
2. Classes are now separated into their own modules.
3. Moved some functions and staticmethod to utils.
4. Adds a range of tests for add and commit functionalities of weigit.

* [fix] adds a new test to the ci_test_list_3

* [fix] test fix

* [fix] test fix

* [Feat] Directory restructuring, type checking and some standardization
1. Restructured the directory and moved wgit to fairscale/experimental/wgit so that it can be found as a package when pip installed.
2. Added a range of type checking
3. Some refactors

* [Feat][Refactor] Directory restructuring, test addition and type checking
1. Restructured the test directory
2. Added and modified a few wgit tests.
3. Added some type checking to the code

* test fix

* "setup fix and repo checking added in cli"

* [Feat] Better initialization and error handling for init and wgit subcommands. Test reorg.

* [refactor] Changes in classes, encapsulation and addition of PyGit test.

* [Feat][Refactor]
1. Changed some class method arguments for better encapsulation for Sha1_store.
2. Moved sha1 hash calculation within sha1_store.
3. Some standardization and code clean up of unnecessary snippets.
4. Added new tests for the PyGit and Sha1_Store class.
parent 2350968e
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
.testmondata .testmondata
# experimental weigit # experimental weigit
experimental/.gitignore fairscale/experimental/wgit/dev
.wgit .wgit
# Build and release # Build and release
......
from enum import Enum
class ExitCode(Enum):
    """Exit statuses used by the wgit code when it terminates the process."""

    # Nothing went wrong.
    CLEAN = 0
    # Something that was about to be created is already present on disk.
    FILE_EXISTS_ERROR = 1
    # Fallback for errors that have no dedicated code.
    ERROR = -1  # unknown errors
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
import pathlib
import sys
import pygit2
from experimental.wgit.utils import ExitCode
class WeiGit:
    """Entry point of the wgit API.

    Instantiating the class initializes a new ``.wgit`` repository in the current
    working directory. The static methods are placeholders for the wgit
    sub-commands: they only print a message, and most of them first check that a
    wgit repo can be found from the current working directory.
    """

    def __init__(self) -> None:
        """Create and populate the ``.wgit`` directory in the current working directory.

        Steps:
            1. Create the ``.wgit`` directory; exit if it already exists.
            2. Instantiate ``SHA1_store``.
            3. Create an empty ``.wgit/sha1_ref_count.json``.
            4. Initialize a git repository at ``.wgit/.git``.
            5. Write ``.wgit/.gitignore`` so the inner git repo ignores ``sha1_ref_count.json``.

        Raises:
            SystemExit: if ``.wgit`` already exists, or if the git repository
                cannot be initialized.
        """
        # Make .wgit directory. If already exists, we error out
        try:
            os.mkdir(".wgit")
        except FileExistsError:
            sys.stderr.write("An exception occured while wgit initialization: WeiGit already Initialized\n")
            # NOTE(review): sys.exit() receives an Enum member, not an int. Per the
            # sys.exit documentation a non-integer object is printed to stderr and the
            # process exits with status 1 -- confirm whether `.value` was intended here
            # and in the other sys.exit calls below.
            sys.exit(ExitCode.FILE_EXISTS_ERROR)

        # if no .wgit dir then initialize the following
        SHA1_store()

        # create sha1_ref_count and a .gitignore:
        try:
            ref_count_json = ".wgit/sha1_ref_count.json"
            # opening in "w" mode is enough to create the (empty) file
            with open(ref_count_json, "w"):
                pass
        except FileExistsError as error:
            sys.stderr.write(f"An exception occured while creating {ref_count_json}: {repr(error)}\n")
            sys.exit(ExitCode.FILE_EXISTS_ERROR)

        # Make the .wgit a git repo
        try:
            pygit2.init_repository(".wgit/.git", False)
        except Exception as error:
            # Exception (not BaseException) so that Ctrl-C and SystemExit are not
            # rewritten into a generic "initializing" failure.
            sys.stderr.write(f"An exception occurred while initializing .wgit/.git: {repr(error)}\n")
            sys.exit(ExitCode.ERROR)

        # add a .gitignore:
        try:
            gitignore = ".wgit/.gitignore"
            with open(gitignore, "w") as f:
                f.write("sha1_ref_count.json")
        except FileExistsError as error:
            sys.stderr.write(f"An exception occured while creating {gitignore}: {repr(error)}\n")
            sys.exit(ExitCode.FILE_EXISTS_ERROR)

    @staticmethod
    def add(file):
        """Placeholder for ``wgit add``: prints a message when a wgit repo is found."""
        if Repo(os.getcwd()).exists():
            print("wgit added")

    @staticmethod
    def status():
        """Placeholder for ``wgit status``: prints a message when a wgit repo is found."""
        if Repo(os.getcwd()).exists():
            print("wgit status")

    @staticmethod
    def log(file):
        """Placeholder for ``wgit log``; mentions ``file`` in the message when it is truthy."""
        if Repo(os.getcwd()).exists():
            if file:
                print(f"wgit log of the file: {file}")
            else:
                print("wgit log")

    @staticmethod
    def commit(message):
        """Placeholder for ``wgit commit``; echoes ``message`` when it is truthy."""
        if Repo(os.getcwd()).exists():
            if message:
                print(f"commited with message: {message}")
            else:
                print("wgit commit")

    @staticmethod
    def checkout():
        """Placeholder for ``wgit checkout``: prints a message when a wgit repo is found."""
        if Repo(os.getcwd()).exists():
            print("wgit checkout")

    @staticmethod
    def compression():
        """Not implemented yet; only prints a notice."""
        print("Not Implemented!")

    @staticmethod
    def checkout_by_steps():
        """Not implemented yet; only prints a notice."""
        print("Not Implemented!")
class SHA1_store:
    """
    Stub of the SHA1 store; nothing is implemented yet.

    Planned Features:
    1. def init
    2. def add <file or data> -> SHA1
    3. def remove (SHA1)
    4. def add_ref(children_SHA1, parent_SHA1)
    5. def read(SHA1): ->
    6. def lookup(SHA1): -> file path to the data. NotFound Exception if not found.
    """

    def __init__(self) -> None:
        # Intentionally empty: the store has no state for now.
        return None
class Repo:
    """
    Designates the weigit repo, which is identified by a path to the repo.

    Attributes:
        repo_path: path of the located ``.wgit`` directory, or None while no
            repo has been found.
        check_dir: directory currently being inspected. ``exists()`` moves it
            upwards while searching.
    """

    def __init__(self, check_dir) -> None:
        """Remember the (symlink-resolved) directory where the search starts."""
        self.repo_path = None
        self.check_dir = os.path.realpath(check_dir)

    def exists(self):
        """Locate a weigit repo in ``check_dir`` or one of its parent directories.

        The search walks upwards from ``check_dir`` and stops at the current
        working directory, or at the filesystem root when ``check_dir`` does not
        live below the working directory.

        Returns:
            True when a repo was found (``repo_path`` is then set), False
            otherwise. A hint is printed when nothing is found.
        """

        def weigit_repo_exists(check_dir):
            """
            checks if the input path to dir (check_dir) is a valid weigit repo
            with .git and sha1_ref_count in the repo.
            """
            is_wgit_in_curr = pathlib.Path(os.path.join(check_dir, ".wgit")).exists()
            is_refcount_in_wgit = pathlib.Path(os.path.join(check_dir, ".wgit/sha1_ref_count.json")).exists()
            is_git_in_wgit = pathlib.Path(os.path.join(check_dir, ".wgit/.git")).exists()
            return is_wgit_in_curr and is_refcount_in_wgit and is_git_in_wgit

        if weigit_repo_exists(self.check_dir):
            self.repo_path = os.path.join(self.check_dir, ".wgit")
        else:
            cwd = os.getcwd()
            while self.check_dir != cwd:
                parent = os.path.dirname(self.check_dir)
                if parent == self.check_dir:
                    # Reached the filesystem root: dirname() no longer changes the
                    # path, so without this check the loop would never terminate.
                    break
                self.check_dir = parent
                if weigit_repo_exists(self.check_dir):
                    self.repo_path = os.path.join(self.check_dir, ".wgit")
                    break

        if self.repo_path is None:
            print("Initialize a weigit repo first!!")
            return False
        return True

    def get_repo_path(self):
        """Return the path of the ``.wgit`` directory, searching for it first if needed."""
        if self.repo_path is None:
            self.exists()
        return self.repo_path
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .cli import main from .repo import Repo
from .weigit_api import WeiGit from .version import __version_tuple__
__version__ = "0.0.1" __version__ = ".".join([str(x) for x in __version_tuple__])
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse import argparse
from pathlib import Path
from typing import List
import experimental.wgit as wgit from . import Repo, version
import experimental.wgit.weigit_api as weigit_api
def main(argv=None): def main(argv: List[str] = None) -> None:
desc = "WeiGit: A git-like tool for model weight tracking" desc = "WeiGit: A git-like tool for model weight tracking"
# top level parser and corresponding subparser # top level parser and corresponding subparser
...@@ -71,25 +72,30 @@ def main(argv=None): ...@@ -71,25 +72,30 @@ def main(argv=None):
args = parser.parse_args(argv) args = parser.parse_args(argv)
if args.command == "init": if args.command == "init":
weigit = weigit_api.WeiGit() repo = Repo(Path.cwd(), init=True)
if args.command == "add": if args.command == "add":
weigit_api.WeiGit.add(args.add) repo = Repo(Path.cwd())
repo.add(args.add)
if args.command == "status": if args.command == "status":
weigit_api.WeiGit.status() repo = Repo(Path.cwd())
repo.status()
if args.command == "log": if args.command == "log":
weigit_api.WeiGit.log(args.file) repo = Repo(Path.cwd())
repo.log(args.file)
if args.command == "commit": if args.command == "commit":
weigit_api.WeiGit.commit(args.message) repo = Repo(Path.cwd())
repo.commit(args.message)
if args.command == "checkout": if args.command == "checkout":
weigit_api.WeiGit.checkout() repo = Repo(Path.cwd())
repo.checkout(args.checkout)
if args.command == "version": if args.command == "version":
print(wgit.__version__) print(".".join([str(x) for x in version.__version_tuple__]))
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import List, Optional
import pygit2
class PyGit:
    """Wrapper around the git repository that lives inside a ``.wgit`` directory."""

    def __init__(
        self,
        parent_path: Path,
        gitignore: Optional[List] = None,
        name: str = "user",
        email: str = "user@email.com",
    ) -> None:
        """
        PyGit class to wrap the wgit/.git repo and interact with the git repo.

        If ``parent_path`` already is a git repository it is wrapped as-is,
        otherwise a new repository is initialized there and a ``.gitignore``
        listing ``gitignore`` is created.

        Args:
            parent_path: Has to be the full path of the parent!
            gitignore: entries written to the new ``.gitignore`` (one per line).
                Only used when a repository has to be initialized. Defaults to
                no entries.
            name: author/committer name used for commits.
            email: author/committer email used for commits.
        """
        self._parent_path = parent_path
        self.name = name
        self.email = email
        # A fresh list per call instead of a shared mutable default argument.
        ignore_entries = [] if gitignore is None else gitignore

        # discover_repository returns the path of a git repo found from the given
        # directory, or a falsy value when there is none. The repo it finds may
        # belong to an enclosing project rather than to .wgit itself, hence the
        # comparison of its parent directory with parent_path.
        discovered = pygit2.discover_repository(str(self._parent_path))
        if discovered and Path(discovered).parent.absolute() == self._parent_path:
            # .wgit already is a git repo: just wrap it
            self.repo = pygit2.Repository(str(self._parent_path))
            self.path = self._parent_path.joinpath(".git")
        else:
            # either no repo at all, or the one found is a different git repo:
            # make .wgit a git repo of its own
            self._init_wgit_git(ignore_entries)

    def _init_wgit_git(self, gitignore: List) -> None:
        """Initializes a .git within .wgit directory, making it a git repo.

        Raises:
            FileExistsError: if a ``.gitignore`` is already present.
        """
        self.repo = pygit2.init_repository(str(self._parent_path), False)
        self.path = self._parent_path.joinpath(".git")

        # create and populate a .gitignore
        self._parent_path.joinpath(".gitignore").touch(exist_ok=False)
        with open(self._parent_path.joinpath(".gitignore"), "a") as file:
            for item in gitignore:
                file.write(f"{item}\n")

    def add(self) -> None:
        """git add all the untracked files not in gitignore, to the .wgit/.git repo."""
        # Nothing to do unless .wgit itself is the git repo
        if not self._exists:
            return
        self.repo.index.add_all()
        self.repo.index.write()

    def commit(self, message: str) -> None:
        """git commit the staged changes to the .wgit/.git repo."""
        # Nothing to do unless .wgit itself is the git repo
        if not self._exists:
            return
        # if no commit exists, set ref to HEAD and parents to empty
        try:
            ref = self.repo.head.name
            parents = [self.repo.head.target]
        except pygit2.GitError:
            ref = "HEAD"
            parents = []
        author = pygit2.Signature(self.name, self.email)
        committer = pygit2.Signature(self.name, self.email)
        tree = self.repo.index.write_tree()
        self.repo.create_commit(ref, author, committer, message, tree, parents)

    @property
    def _exists(self) -> bool:
        """returns True if wgit is a git repository"""
        return self._parent_path == Path(self.repo.path).parent

    @property
    def _path(self) -> str:
        """returns the path of the git repository PyGit is wrapped around"""
        return self.repo.path

    def status(self) -> None:
        """Print the status of the git repo"""
        print(self.repo.status())
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pathlib
from pathlib import Path
import sys
from typing import Union
from .pygit import PyGit
from .sha1_store import SHA1_store
class Repo:
    """A wgit repository: a ``.wgit`` directory holding a sha1 store and an internal git repo."""

    def __init__(self, parent_dir: Path, init: bool = False) -> None:
        """Open the wgit repo located in ``parent_dir``, optionally creating it.

        Features:
        1. Create the wgit directory if it does not exist.
        2. SHA1Store.init()
        3. Create SHA1 .wgit/sha1_refs.json
        3. Initialize a .git directory within the .wgit using `git init`.
        4. add a .gitignore within the .wgit directory, so that the git repo within will ignore `sha1_refs.json`

        Args:
            parent_dir: directory that contains (or will contain) ``.wgit``.
            init: create the repo when it does not exist yet. Without it, a
                missing repo only produces an error message on stderr.
        """
        self.wgit_parent = parent_dir
        self._repo_path: Union[None, Path] = None
        # Every internal path is anchored to parent_dir, the same location that
        # _exists() inspects, instead of to whatever the working directory is.
        self._wgit_dir = self.wgit_parent.joinpath(".wgit")
        self._metadata_file = self._wgit_dir.joinpath("checkpoint.pt")
        self._sha1_ref = self._wgit_dir.joinpath("sha1_refs.json")
        self._wgit_git_path = self._wgit_dir.joinpath(".git")
        self._sha1_store_path = self._wgit_dir.joinpath("sha1_store")

        if self._exists():
            # The repo is already there (with or without init=True): wrap the
            # existing sha1 store and .wgit/.git repo.
            self._sha1_store = SHA1_store(
                self._wgit_dir,
                self._metadata_file,
                self._sha1_ref,
            )
            self._pygit = PyGit(self._wgit_dir)
        elif init:
            # No weigit repo exists and is being initialized with init=True
            # Make .wgit directory, create sha1_refs and metadata file
            self._wgit_dir.mkdir(exist_ok=True)
            self._metadata_file.touch(exist_ok=False)
            self._sha1_ref.touch(exist_ok=False)

            # Make the .wgit a git repo
            gitignore_files = [self._sha1_store_path.name, self._sha1_ref.name]
            self._pygit = PyGit(self._wgit_dir, gitignore=gitignore_files)

            # Initializing sha1_store only after wgit has been initialized!
            self._sha1_store = SHA1_store(self._wgit_dir, self._metadata_file, self._sha1_ref, init=True)
        else:
            # weigit doesn't exist and is not trying to be initialized (triggers during non-init commands)
            sys.stderr.write("fatal: not a wgit repository!\n")

    def add(self, file_path: str) -> None:
        """Adds a file to the wgit repo"""
        if self._exists():
            self._sha1_store.add(file_path)  # add the file to the sha1_store
            self._pygit.add()  # add to the .wgit/.git repo

    def commit(self, message: str) -> None:
        """Commits staged changes to the repo"""
        if self._exists():
            self._pygit.commit(message)

    def status(self) -> None:
        """Skeleton"""
        if self._exists():
            print("wgit status")

    def log(self, file: str) -> None:
        """Skeleton: prints a placeholder for the commit history, mentioning ``file`` when given."""
        if self._exists():
            if file:
                print(f"wgit log of the file: {file}")
            else:
                print("wgit log")

    def checkout(self, sha1: str) -> None:
        """Skeleton: checkout a previously commited version of the checkpoint"""
        if self._exists():
            # echo the requested sha1 (the message used to contain the literal text "sha1")
            print(f"wgit checkout: {sha1}")

    def compression(self) -> None:
        """Not Implemented: Compression functionalities"""
        print("Not Implemented!")

    def checkout_by_steps(self) -> None:
        """Not Implemented: Checkout by steps"""
        print("Not Implemented!")

    @property
    def path(self) -> str:
        """Get the path to the WeiGit repo"""
        if self._repo_path is None:
            self._exists()
        return str(self._repo_path)

    def _exists(self) -> bool:
        """Returns True if a valid wgit exists within ``wgit_parent``, and keeps self._repo_path in sync.

        The check is redone on every call, so a repo that was removed after it
        had been found is reported as missing.
        """
        if self._weigit_repo_exists(self.wgit_parent):
            self._repo_path = self.wgit_parent.joinpath(".wgit")
            return True
        self._repo_path = None
        return False

    def _weigit_repo_exists(self, check_dir: pathlib.Path) -> bool:
        """Returns True if a valid WeiGit repo exists in the path: check_dir"""
        return all(self._weight_repo_file_check(check_dir))

    def _weight_repo_file_check(self, check_dir: Path) -> tuple:
        """Returns a tuple of boolean corresponding to the existence of each .wgit internally required files."""
        wgit_exists = check_dir.joinpath(".wgit").exists()
        sha1_refs = check_dir.joinpath(".wgit/sha1_refs.json").exists()
        git_exists = check_dir.joinpath(".wgit/.git").exists()
        gitignore_exists = check_dir.joinpath(".wgit/.gitignore").exists()
        return wgit_exists, sha1_refs, git_exists, gitignore_exists
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import hashlib
import json
import os
from pathlib import Path
import shutil
import sys
from typing import Union, cast
from .utils import ExitCode
class SHA1_store:
    """File store addressed by SHA1: a file with hash ``abcdef...`` is kept at ``<store>/ab/cdef...``."""

    def __init__(self, weigit_path: Path, metadata_file: Path, sha1_refs: Path, init: bool = False) -> None:
        """
        Planned Features:
            1. def init
            2. def add <file or data> -> SHA1
            3. def remove (SHA1)
            4. def add_ref(children_SHA1, parent_SHA1)
            5. def read(SHA1): ->
            6. def lookup(SHA1): -> file path to the data. NotFound Exception if not found.

        Args:
            weigit_path: the ``.wgit`` directory; the store lives in its ``sha1_store`` sub-directory.
            metadata_file: file that receives the metadata of the last added file.
                Relative paths are resolved against the current working directory.
            sha1_refs: json file tracking the parent/child references between
                added files. Relative paths are resolved against the current
                working directory.
            init: create the store directory when it does not exist yet.
        """
        # should use the sha1_refs.json to track the parent references.
        self.name = "sha1_store"
        self.path = weigit_path.joinpath(self.name)
        self._ref_file_name = Path.cwd().joinpath(sha1_refs)
        self._metadata_file = Path.cwd().joinpath(metadata_file)

        # initialize the sha1_store
        if init:
            try:
                if not self.path.exists():
                    self.path.mkdir(parents=False, exist_ok=False)
            except FileExistsError as error:
                sys.stderr.write(f"An exception occured while creating Sha1_store: {repr(error)}\n")
                sys.exit(ExitCode.FILE_EXISTS_ERROR)

    def add(self, file_path: str) -> None:
        """Adds a file/checkpoint to the internal sha1_store and update the metadata and the
        sha1 references accordingly.

        Raises:
            SystemExit: if a file with the same SHA1 is already in the store.
        """
        sha1_hash = self.get_sha1_hash(file_path)

        # first2 sha naming convention: <store>/<first 2 hex chars>/<remaining chars>
        repo_fdir = self.path.joinpath(sha1_hash[:2])
        repo_fpath = Path.cwd().joinpath(repo_fdir, sha1_hash[2:])

        # Only the very same object is a conflict. Different files whose hashes
        # merely share the first two characters live side by side in one bucket.
        if repo_fpath.exists():
            error = FileExistsError(f"{repo_fpath} is already in the sha1_store")
            sys.stderr.write(f"An exception occured: {repr(error)}\n")
            sys.exit(ExitCode.FILE_EXISTS_ERROR)

        bucket_created = not repo_fdir.exists()
        repo_fdir.mkdir(exist_ok=True)

        try:
            # First transfer the file to the internal sha1_store
            shutil.copy2(file_path, repo_fpath)
            change_time = Path(repo_fpath).stat().st_ctime

            # Create the dependency Graph and track reference
            self._add_ref(current_sha1_hash=sha1_hash)
            metadata = {
                "SHA1": {
                    "__sha1_full__": sha1_hash,
                },
                "file_path": str(repo_fpath),
                "time_stamp": change_time,
            }
            # Populate the meta_data file with the meta_data
            self._add_metadata_to_json(metadata)
        except Exception as error:
            # in case of failure: clean up what this call created, and nothing else,
            # so that other objects stored in the same bucket survive
            sys.stderr.write(f"An exception occured: {repr(error)}\n")
            if bucket_created:
                shutil.rmtree(repo_fdir, ignore_errors=True)
            elif repo_fpath.exists():
                repo_fpath.unlink()

    def _add_ref(self, current_sha1_hash: str) -> None:
        """Populates the sha1_refs.json file when file is added and keeps track of reference to earlier commits

        Every entry has the form ``{"parent": ..., "child": ..., "ref_count": ...}``.
        The most recently added hash is the entry whose child is ``"HEAD"``.
        """
        if not os.path.getsize(self._ref_file_name):  # If no entry yet
            ref_data = {
                current_sha1_hash: {"parent": "ROOT", "child": "HEAD", "ref_count": 0},
            }
        else:
            with open(self._ref_file_name, "r") as f:
                ref_data = json.load(f)

            # get the last head
            previous_head = None
            for key, vals in ref_data.items():
                if vals["child"] == "HEAD":
                    previous_head = key

            if previous_head is None:
                # no entry is marked as HEAD: start a new chain instead of failing
                parent = "ROOT"
            else:
                parent = previous_head
                # replace its child from HEAD -> this sha1 and
                # increase the ref counter of that (now parent) sha1
                ref_data[parent]["child"] = current_sha1_hash
                ref_data[parent]["ref_count"] = cast(int, ref_data[parent]["ref_count"]) + 1

            # Add this new sha1 as a new entry, make the earlier sha1 a parent
            # and make "HEAD" its child
            ref_data[current_sha1_hash] = {"parent": parent, "child": "HEAD", "ref_count": 0}

        with open(self._ref_file_name, "w", encoding="utf-8") as f:
            json.dump(ref_data, f, ensure_ascii=False, indent=4)

    def get_sha1_hash(self, file_path: Union[str, Path]) -> str:
        """return the sha1 hash of a file"""
        SHA1_BUF_SIZE = 104857600  # Reading file in 100MB chunks

        sha1 = hashlib.sha1()
        with open(file_path, "rb") as f:
            # read() returns b"" at end of file, which ends the iteration
            for chunk in iter(lambda: f.read(SHA1_BUF_SIZE), b""):
                sha1.update(chunk)
        return sha1.hexdigest()

    def _add_metadata_to_json(self, metadata: dict) -> None:
        """Populates the meta_data_file: checkpoint.pt with the meta_data"""
        with open(self._metadata_file, "w", encoding="utf-8") as f:
            json.dump(metadata, f, ensure_ascii=False, indent=4)
...@@ -3,4 +3,12 @@ ...@@ -3,4 +3,12 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .wgit import cli, weigit_api from enum import Enum
class ExitCode(Enum):
    """Exit codes of the wgit tool, gathered in a single Enum."""

    # Everything went fine.
    CLEAN = 0
    # Something that had to be created is already there.
    FILE_EXISTS_ERROR = 1
    # Something that was expected to be there is missing.
    FILE_DOES_NOT_EXIST_ERROR = 2
__version_tuple__ = (0, 0, 1)
...@@ -69,7 +69,7 @@ if __name__ == "__main__": ...@@ -69,7 +69,7 @@ if __name__ == "__main__":
author_email="todo@fb.com", author_email="todo@fb.com",
long_description="FairScale is a PyTorch extension library for high performance and large scale training on one or multiple machines/nodes. This library extends basic PyTorch capabilities while adding new experimental ones.", long_description="FairScale is a PyTorch extension library for high performance and large scale training on one or multiple machines/nodes. This library extends basic PyTorch capabilities while adding new experimental ones.",
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
entry_points={"console_scripts": ["wgit = experimental.wgit.__main__:main"]}, entry_points={"console_scripts": ["wgit = fairscale.experimental.wgit.__main__:main"]},
classifiers=[ classifiers=[
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
......
tests/experimental/tooling/test_layer_memory_tracker.py
tests/experimental/nn/test_mevo.py tests/experimental/nn/test_mevo.py
tests/experimental/nn/test_multiprocess_pipe.py tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_sync_batchnorm.py tests/experimental/nn/test_sync_batchnorm.py
...@@ -5,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py ...@@ -5,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/tooling/test_layer_memory_tracker.py
tests/experimental/nn/test_ssd_offload.py tests/experimental/nn/test_ssd_offload.py
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
tests/nn/data_parallel/test_fsdp_shared_weights.py tests/nn/data_parallel/test_fsdp_shared_weights.py
......
...@@ -23,4 +23,7 @@ tests/optim/test_oss_adascale.py ...@@ -23,4 +23,7 @@ tests/optim/test_oss_adascale.py
tests/optim/test_ddp_adascale.py tests/optim/test_ddp_adascale.py
tests/experimental/nn/data_parallel/test_gossip.py tests/experimental/nn/data_parallel/test_gossip.py
tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py
tests/experimental/wgit/cli/test_cli.py tests/experimental/wgit/test_cli.py
tests/experimental/wgit/test_api.py
tests/experimental/wgit/test_pygit.py
tests/experimental/wgit/test_sha1_store.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from pathlib import Path
import shutil
import pytest
from fairscale.experimental.wgit import cli
from fairscale.experimental.wgit import repo as api
@pytest.fixture
def create_test_dir():
    """Create a fresh ./experimental/wgit_testing directory, chdir into it and fill it with random checkpoints."""
    test_dir = Path.cwd().joinpath("experimental", "wgit_testing/")

    # creates a testing directory within ./experimental, replacing any left-over one
    try:
        test_dir.mkdir(parents=True)
    except FileExistsError:
        shutil.rmtree(test_dir)
        test_dir.mkdir(parents=True)
    os.chdir(test_dir)

    # create random checkpoints of roughly 3, 3.5 and 4 MB
    for index, num_bytes in enumerate([30e5, 35e5, 40e5]):
        Path(f"checkpoint_{index}.pt").write_bytes(os.urandom(int(num_bytes)))
    return test_dir
@pytest.fixture
def repo():
    """Repo object bound to the current working directory (opened without init)."""
    return api.Repo(Path.cwd())
def test_setup(create_test_dir):
assert str(create_test_dir.stem) == "wgit_testing"
def test_api_init(capsys, repo):
    """Initializing a repo must create the internal files of .wgit."""
    repo = api.Repo(Path.cwd(), init=True)

    for expected_file in (".wgit/sha1_refs.json", ".wgit/.gitignore"):
        assert Path(expected_file).is_file()
    for expected_path in (".wgit/.git", ".wgit/.gitignore"):
        assert Path(expected_path).exists()
def test_api_add(capsys, repo):
    """Adding a checkpoint must record its sha1 and its location in the store in .wgit/checkpoint.pt."""
    chkpt0 = "checkpoint_0.pt"
    repo.add("checkpoint_0.pt")

    digest = repo._sha1_store.get_sha1_hash(chkpt0)
    metadata_path = os.path.join(".wgit", "checkpoint.pt")
    with open(metadata_path, "r") as metadata_file:
        recorded = json.load(metadata_file)

    # the store keeps a file under <first two hash chars>/<remaining chars>
    relative_location = f"{digest[:2]}/" + f"{digest[2:]}"
    expected_location = os.path.join(os.getcwd(), ".wgit/sha1_store/", relative_location)
    assert recorded["SHA1"] == {"__sha1_full__": digest}
    assert recorded["file_path"] == expected_location
def test_api_commit(capsys, repo):
    """A commit must show up, with its message, as first entry of the internal git repo's HEAD log."""
    expected_message = "epoch_1"
    repo.commit(message=expected_message)

    with open(".wgit/.git/logs/HEAD") as head_log:
        entries = head_log.readlines()
    last_token = entries[0].rstrip().split()[-1]
    assert last_token == expected_message
def test_api_status(capsys, repo):
    """status() prints its placeholder message on stdout and nothing on stderr."""
    repo.status()
    out, err = capsys.readouterr()
    assert out == "wgit status\n"
    assert err == ""
def test_api_log(capsys, repo):
    """log(<file>) prints a message naming the file on stdout and nothing on stderr."""
    repo.log("testfile.pt")
    out, err = capsys.readouterr()
    assert out == "wgit log of the file: testfile.pt\n"
    assert err == ""
def test_cli_checkout(capsys):
    """`wgit checkout sha1` prints the checkout message on stdout and nothing on stderr."""
    cli.main(["checkout", "sha1"])
    out, err = capsys.readouterr()
    assert out == "wgit checkout: sha1\n"
    assert err == ""
def teardown_module(module):
# clean up: delete the .wgit directory created during this Test
# Making sure the current directory is ./experimental before removing test dir
if (Path.cwd().parent.name == "experimental") and (Path.cwd().name == "wgit_testing"):
os.chdir(Path.cwd().parent)
shutil.rmtree("./wgit_testing/")
else:
raise Exception("Exception in testing directory tear down!")
...@@ -3,16 +3,22 @@ ...@@ -3,16 +3,22 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import json
import os import os
import pathlib from pathlib import Path
import shutil import shutil
import experimental.wgit.cli as cli import pytest
import fairscale.experimental.wgit.cli as cli
from fairscale.experimental.wgit.sha1_store import SHA1_store
def setup_module(module):
@pytest.fixture
def create_test_dir():
curr_dir = Path.cwd()
parent_dir = "experimental" parent_dir = "experimental"
test_dir = os.path.join(parent_dir, "wgit_testing/") test_dir = curr_dir.joinpath(parent_dir, "wgit_testing/")
# creates a testing directory within ./experimental # creates a testing directory within ./experimental
try: try:
...@@ -20,23 +26,55 @@ def setup_module(module): ...@@ -20,23 +26,55 @@ def setup_module(module):
except FileExistsError: except FileExistsError:
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
os.makedirs(test_dir) os.makedirs(test_dir)
os.chdir(test_dir) os.chdir(test_dir)
# create random checkpoints
size_list = [30e5, 35e5, 40e5]
for i, size in enumerate(size_list):
with open(f"checkpoint_{i}.pt", "wb") as f:
f.write(os.urandom(int(size)))
return test_dir
def test_setup(create_test_dir):
cli.main(["init"]) cli.main(["init"])
assert str(create_test_dir.stem) == "wgit_testing"
def test_cli_init(capsys): def test_cli_init(capsys):
# Check if the json and other files have been created by the init # Check if the json and other files have been created by the init
assert pathlib.Path(".wgit/sha1_ref_count.json").is_file() assert Path(".wgit/sha1_refs.json").is_file()
assert pathlib.Path(".wgit/.gitignore").is_file() assert Path(".wgit/.gitignore").is_file()
assert pathlib.Path(".wgit/.git").exists() assert Path(".wgit/.git").exists()
def test_cli_add(capsys): def test_cli_add(capsys):
cli.main(["add", "test"]) chkpt0 = "checkpoint_0.pt"
captured = capsys.readouterr() cli.main(["add", "checkpoint_0.pt"])
assert captured.out == "wgit added\n"
assert captured.err == "" sha1_store = SHA1_store(
Path.cwd().joinpath(".wgit"),
Path.cwd().joinpath(".wgit", "checkpoint.pt"),
Path.cwd().joinpath(".wgit", "sha1_refs.json"),
init=False,
)
sha1_hash = sha1_store.get_sha1_hash(chkpt0)
with open(os.path.join(".wgit", "checkpoint.pt"), "r") as f:
json_data = json.load(f)
sha1_dir_0 = f"{sha1_hash[:2]}/" + f"{sha1_hash[2:]}"
assert json_data["SHA1"] == {"__sha1_full__": sha1_hash}
assert json_data["file_path"] == os.path.join(os.getcwd(), ".wgit/sha1_store/", sha1_dir_0)
def test_cli_commit(capsys):
commit_msg = "epoch_1"
cli.main(["commit", "-m", f"{commit_msg}"])
with open(".wgit/.git/logs/HEAD") as f:
line = f.readlines()
assert line[0].rstrip().split()[-1] == commit_msg
def test_cli_status(capsys): def test_cli_status(capsys):
...@@ -53,27 +91,18 @@ def test_cli_log(capsys): ...@@ -53,27 +91,18 @@ def test_cli_log(capsys):
assert captured.err == "" assert captured.err == ""
def test_cli_commit(capsys):
cli.main(["commit", "-m", "text"])
captured = capsys.readouterr()
assert captured.out == "commited with message: text\n"
assert captured.err == ""
def test_cli_checkout(capsys): def test_cli_checkout(capsys):
cli.main(["checkout", "sha1"]) cli.main(["checkout", "sha1"])
captured = capsys.readouterr() captured = capsys.readouterr()
assert captured.out == "wgit checkout\n" assert captured.out == "wgit checkout: sha1\n"
assert captured.err == "" assert captured.err == ""
def teardown_module(module): def teardown_module(module):
# clean up: delete the .wgit directory created during this Test # clean up: delete the .wgit directory created during this Test
parent_dir = "experimental"
os.chdir(os.path.dirname(os.getcwd()))
# Making sure the current directory is ./experimental before removing test dir # Making sure the current directory is ./experimental before removing test dir
if os.path.split(os.getcwd())[1] == parent_dir: if (Path.cwd().parent.name == "experimental") and (Path.cwd().name == "wgit_testing"):
os.chdir(Path.cwd().parent)
shutil.rmtree("./wgit_testing/") shutil.rmtree("./wgit_testing/")
else: else:
raise Exception("Exception in testing directory tear down!") raise Exception("Exception in testing directory tear down!")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import shutil
import pytest
from fairscale.experimental.wgit.pygit import PyGit
@pytest.fixture
def repo_data():
    """Location of the temporary .wgit directory and the two names put in its .gitignore."""
    wgit_dir = Path("temp_wgit_testing/.wgit")
    return {
        "test_path": Path.cwd().joinpath(wgit_dir),
        "file1": "test_file1",
        "file2": "test_file2",
    }
@pytest.fixture
def pygit_repo_wrap(repo_data):
    """PyGit object for the test .wgit directory described by ``repo_data``."""
    wgit_path = Path.cwd().joinpath(repo_data["test_path"])
    ignored = [repo_data["file1"], repo_data["file2"]]
    return PyGit(wgit_path, gitignore=ignored)
def test_setup(repo_data):
    """Create the test repo and two sample files that the later tests rely on."""
    wgit_dir = Path.cwd() / repo_data["test_path"]

    # First construction of PyGit for this directory.
    wrapper = PyGit(wgit_dir, gitignore=["test_file1", "test_file2"])

    # Sample files for the add/commit tests to pick up.
    for sample_name in ("test_file_1.pt", "test_file_2.pt"):
        (wgit_dir / sample_name).touch()

    assert wgit_dir.stem == str(wrapper.path.parent.stem)
def test_pygit_add(pygit_repo_wrap):
    """Tests that PyGit.add() moves every file in the test repo to the added state"""
    assert str(pygit_repo_wrap.path.parent.stem) == ".wgit"

    repo = pygit_repo_wrap.repo
    expected_files = (".gitignore", "test_file_1.pt", "test_file_2.pt")

    # File status 128 in pygit2 signifies file has Not been added yet
    for filename in expected_files:
        assert repo.status()[filename] == 128

    pygit_repo_wrap.add()

    # File status 1 in pygit2 signifies file has been added to git repo
    for filename in expected_files:
        assert repo.status()[filename] == 1
def test_pygit_commit(pygit_repo_wrap):
    """Tests the commit functionality of the PyGit class"""
    assert str(pygit_repo_wrap.path.parent.stem) == ".wgit"

    repo = pygit_repo_wrap.repo

    # Before committing, every file must still be in the added state
    # (status 1 in pygit2).
    for filename in (".gitignore", "test_file_1.pt", "test_file_2.pt"):
        assert repo.status()[filename] == 1

    pygit_repo_wrap.commit("random_message")

    # An empty status dict in pygit2 implies the commit has been made
    assert repo.status() == {}
def test_tear_down(repo_data):
    """Clean up by deleting the ``temp_wgit_testing`` directory used by these tests.

    Raises:
        Exception: if ``repo_data["test_path"]`` is not ``temp_wgit_testing/.wgit``,
            so that no other directory can be removed by accident.
    """
    wgit_dir = repo_data["test_path"]
    is_test_repo = (wgit_dir.parent.stem == "temp_wgit_testing") and (wgit_dir.stem == ".wgit")
    if not is_test_repo:
        raise Exception("Exception in testing directory tear down!")
    shutil.rmtree(wgit_dir.parent)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import json
import os
from pathlib import Path
import shutil
import pytest
from fairscale.experimental.wgit.sha1_store import SHA1_store
@pytest.fixture
def sha1_configs():
    """Provide the paths used by the SHA1_store tests.

    Returns:
        The ``Sha1StorePaths`` class itself (not an instance); the paths are
        read from it as class attributes.
    """

    @dataclass
    class Sha1StorePaths:
        # Relative and absolute location of the test .wgit directory.
        test_dirs = Path("temp_wgit_testing/.wgit")
        test_path = Path.cwd() / test_dirs
        # Files kept directly inside the .wgit directory.
        metadata_file = test_path / "checkpoint.pt"
        sha1_ref = test_path / "sha1_refs.json"
        # Directory holding the checkpoint files, and the files themselves.
        chkpt_dir = test_path / "checkpoint"
        checkpoint_1 = chkpt_dir / "checkpoint_1.pt"
        checkpoint_2 = chkpt_dir / "checkpoint_2.pt"
        checkpoint_3 = chkpt_dir / "checkpoint_3.pt"

    return Sha1StorePaths
@pytest.fixture
def sha1_store(sha1_configs):
    """Build a ``SHA1_store`` over the test paths with ``init=False``."""
    return SHA1_store(
        sha1_configs.test_dirs,
        sha1_configs.metadata_file,
        sha1_configs.sha1_ref,
        init=False,
    )
def test_setup(sha1_configs):
    """Create the test directory, three random checkpoint files and the store.

    Nothing is returned: pytest flags test functions that return a non-None
    value, and no other test consumes the store built here.
    """
    # Set up the testing directory
    sha1_configs.test_dirs.mkdir(parents=True, exist_ok=True)  # create test .wgit dir
    sha1_configs.metadata_file.touch()
    sha1_configs.sha1_ref.touch()

    # exist_ok=False: fail if a checkpoint directory from an earlier run is still there.
    sha1_configs.chkpt_dir.mkdir(exist_ok=False)

    # Create random checkpoints; writing in binary mode creates each file,
    # so no separate touch() is needed.
    checkpoint_sizes = {
        sha1_configs.checkpoint_1: 3_000_000,
        sha1_configs.checkpoint_2: 3_500_000,
        sha1_configs.checkpoint_3: 4_000_000,
    }
    for checkpoint_path, num_bytes in checkpoint_sizes.items():
        with open(checkpoint_path, "wb") as f:
            f.write(os.urandom(num_bytes))

    # Construct the store once with init=True, as the first use of this directory.
    SHA1_store(sha1_configs.test_dirs, sha1_configs.metadata_file, sha1_configs.sha1_ref, init=True)
def test_sha1_add(sha1_configs, sha1_store):
    """Add checkpoint 1 to the store and check the metadata file it writes."""
    checkpoint = sha1_configs.checkpoint_1
    sha1_store.add(checkpoint)

    with open(sha1_configs.metadata_file, "r") as metadata_fp:
        metadata = json.load(metadata_fp)
    recorded_sha1 = metadata["SHA1"]["__sha1_full__"]

    # The recorded hash must match the one the store computes for the file.
    assert recorded_sha1 == sha1_store.get_sha1_hash(checkpoint)

    # The recorded path is <test_path>/<store name>/<first 2 hash chars>/<remaining chars>.
    expected_path = sha1_configs.test_path.joinpath(sha1_store.name, recorded_sha1[:2], recorded_sha1[2:])
    assert metadata["file_path"] == str(expected_path)
def test_sha1_refs(sha1_configs, sha1_store):
    """Check the reference file before and after adding checkpoints 2 and 3."""

    def read_refs():
        # Re-read the reference file from disk each time it is needed.
        with open(sha1_configs.sha1_ref, "r") as refs_fp:
            return json.load(refs_fp)

    # State left by the earlier add of checkpoint 1.
    refs_before = read_refs()
    first_sha1 = sha1_store.get_sha1_hash(sha1_configs.checkpoint_1)
    first_entry = refs_before[first_sha1]
    assert first_entry["parent"] == "ROOT"
    assert first_entry["child"] == "HEAD"
    assert first_entry["ref_count"] == 0

    # add checkpoint 2 and checkpoint 3
    for checkpoint in (sha1_configs.checkpoint_2, sha1_configs.checkpoint_3):
        sha1_store.add(checkpoint)

    # load ref file after the two adds
    refs_after = read_refs()

    # hashes of checkpoint 2 and checkpoint 3
    second_sha1 = sha1_store.get_sha1_hash(sha1_configs.checkpoint_2)
    third_sha1 = sha1_store.get_sha1_hash(sha1_configs.checkpoint_3)

    # Checkpoint 2 must now sit between checkpoint 1 and checkpoint 3.
    second_entry = refs_after[second_sha1]
    assert second_entry["parent"] == first_sha1
    assert second_entry["child"] == third_sha1
    assert second_entry["ref_count"] == 1
def test_tear_down(sha1_configs):
    """Clean up by deleting the ``temp_wgit_testing`` directory used by these tests.

    Raises:
        Exception: if ``sha1_configs.test_path`` is not ``temp_wgit_testing/.wgit``,
            so that no other directory can be removed by accident.
    """
    wgit_dir = sha1_configs.test_path
    parent_dir = wgit_dir.parent
    looks_like_test_dir = (parent_dir.stem == "temp_wgit_testing") and (wgit_dir.stem == ".wgit")
    if not looks_like_test_dir:
        raise Exception("Exception in testing directory tear down!")
    shutil.rmtree(parent_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment