Unverified Commit 4d58a294 authored by Min Xu, committed by GitHub

[feat]: add per-tensor add to repo (#1033)



* formatting change, no logical change

* formatting and name change, no logical change

* [refactor] sha1_store's path arg

- make sha1_store's path arg directly the path, not its parent
- this is because sha1_store is not like a .git or a .wgit dir, which is
  nested inside another "working" dir. It is simply a store that uses a
  given dir.
- updated repo and tests as well.

* remove a test warning due to deprecated API from torch

* [refactor] change how dot_wgit_dir_path is used

- it should only be assigned in __init__.
- we use it for error checking in the rest of the APIs.

* simplify the init a bit

* refactor the sanity check

* moved some functions, no code change

* [feat] added per-tensor add to the repo

* enabled gzip compression on add

* fix a unit test

* add a note

* make sha1 store work on general dict

* handle general state_dict from a model, not just a module's one-level OrderedDict

* formatting
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent d0ad08c0
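At a high level, the per-tensor add splits a checkpoint into its individual tensors and stores each one in the SHA1_Store separately, so unchanged tensors can be deduplicated across checkpoints. Below is a minimal sketch of that idea only; it is not the repo's implementation, and the helper name split_and_add and its name -> sha1 return format are hypothetical.

# Sketch of the per-tensor add idea, not the repo's actual code.
# Assumes a SHA1_Store-like object with add(obj, compress) -> sha1, as in the
# diff below; split_and_add and the name -> sha1 mapping are hypothetical.
from typing import Any, Dict

import torch


def split_and_add(sha1_store: Any, state_dict: Dict[str, Any], compress: bool = True) -> Dict[str, str]:
    """Add each entry of a checkpoint dict to the store; return name -> sha1."""
    name_to_sha1 = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # Each tensor is hashed and stored on its own, so an unchanged
            # tensor maps to the same sha1 in later checkpoints.
            name_to_sha1[name] = sha1_store.add(value, compress=compress)
        else:
            # Non-tensor entries (e.g. a step counter) are wrapped in a small
            # dict so they can still go through torch.save.
            name_to_sha1[name] = sha1_store.add({name: value}, compress=compress)
    return name_to_sha1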
@@ -503,7 +503,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
         return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
     elif torch.is_tensor(a):
         try:
-            # assert_allclose doesn't strictly test shape, dtype and device
+            # assert_close doesn't strictly test shape, dtype and device
             shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
             if not shape_dtype_device_match:
                 if raise_exception:
@@ -513,8 +513,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                     raise AssertionError(msg)
                 else:
                     return False
-            # assert_allclose.
-            torch.testing.assert_allclose(a, b)
+            # assert_close.
+            if torch_version() < (1, 12, 0):
+                torch.testing.assert_allclose(a, b)
+            else:
+                torch.testing.assert_close(a, b)
             return True
         except (AssertionError, RuntimeError) as e:
             if raise_exception:
...
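The gate above keeps older torch releases on the deprecated assert_allclose while newer ones use assert_close. A standalone sketch of the same pattern, reusing the torch_version() helper that fairscale.internal already provides:

# Standalone sketch of the version gate shown in the hunk above.
import torch
from fairscale.internal import torch_version


def tensors_allclose(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if the two tensors match within default tolerances."""
    try:
        if torch_version() < (1, 12, 0):
            torch.testing.assert_allclose(a, b)  # deprecated on newer torch
        else:
            torch.testing.assert_close(a, b)
        return True
    except (AssertionError, RuntimeError):
        return False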
@@ -4,7 +4,6 @@
 # LICENSE file in the root directory of this source tree.
-from collections import OrderedDict
 import hashlib
 import json
 from pathlib import Path
@@ -20,10 +19,6 @@ from torch import Tensor
 from .utils import ExitCode

-# This is a fixed dir name we use for sha1_store. It should not be changed
-# for backward compatibility reasons.
-SHA1_STORE_DIR_NAME = "sha1_store"
 # Const string keys for json file. Do not change for backward compatibilities.
 RF_KEY = "ref_count"
 COMP_KEY = "compressed"
@@ -92,15 +87,10 @@ class SHA1_Store:
         compression/decompression on top of it to use all the cores.

     Args:
-        parent_path (Path):
-            The parent path in which a SHA1_Store will be created.
+        path (Path):
+            The path in which a SHA1_Store will be created.
         init (bool, optional):
-            - If ``True``, initializes a new SHA1_Store in the parent_path. Initialization
-              creates a `sha1_store` directory in ./<parent_path>/,
-              and a `ref_count.json` within ./<parent_path>/.
-            - If ``False``, a new `sha1_store` dir is not initialized and the existing
-              `sha1_store` is used to init this class, populating `_json_dict`, and other
-              attributes.
+            - If ``True``, a new SHA1_Store is created in the path if it does not already exist.
             - Default: False
         sha1_buf_size (int):
             Buffer size used for checksumming. Default: 100MB.
@@ -115,22 +105,22 @@ class SHA1_Store:
     def __init__(
         self,
-        parent_path: Path,
+        path: Path,
         init: bool = False,
         sha1_buf_size: int = 100 * 1024 * 1024,
         tmp_dir: str = "",
         pgzip_threads: Optional[int] = None,
         pgzip_block_size: int = 10 * 1024 * 1024,
     ) -> None:
-        """Create or wrap (if already exists) a sha1_store."""
-        self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
-        self._ref_file_path = self._path.joinpath("ref_count.json")
+        """Create or wrap (if already exists) a store."""
+        self._path = path
+        self._metadata_file_path = self._path.joinpath("metadata.json")
         self._sha1_buf_size = sha1_buf_size
         self._pgzip_threads = pgzip_threads
         self._pgzip_block_size = pgzip_block_size
         self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}

-        # Initialize the sha1_store if not exist and init==True.
+        # Initialize the store if it does not exist and init is True.
         if init and not self._path.exists():
             try:
                 Path.mkdir(self._path, parents=False, exist_ok=False)
@@ -141,7 +131,13 @@ class SHA1_Store:
             self._store_json_dict()

         # This is an internal error since caller of this our own wgit code.
-        assert self._path.exists(), "SHA1 store does not exist and init==False"
+        assert (
+            self._path.exists() and self._metadata_file_path.exists()
+        ), f"SHA1 store {self._path} does not exist and init is False"
+
+        # Make sure there is a valid metadata file.
+        self._load_json_dict()
+        assert "created_on" in self._json_dict, f"Invalid SHA1 Store in {self._path}"

         # Init temp dir.
         if tmp_dir:
@@ -156,30 +152,31 @@ class SHA1_Store:
     def _load_json_dict(self) -> None:
         """Loading json dict from disk."""
-        with open(self._ref_file_path, "r") as f:
+        with open(self._metadata_file_path, "r") as f:
             self._json_dict = json.load(f)

     def _store_json_dict(self) -> None:
         """Storing json dict to disk."""
-        with open(self._ref_file_path, "w", encoding="utf-8") as f:
+        with open(self._metadata_file_path, "w", encoding="utf-8") as f:
             json.dump(self._json_dict, f, ensure_ascii=False, indent=4)

-    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
-        """
-        Adds a file/object to the internal sha1_store and the sha1 references
-        accordingly.
+    def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = False) -> str:
+        """Adds a file/object to this store and the sha1 references accordingly.

         First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
-        in <file_or_obj> is moved within the sha1_store and the reference file is updated.
+        in <file_or_obj> is moved within the store and the reference file is updated.
         If the input is an object, it will be store in the self._tmp_dir and then moved.

         If compress is True, the stored file is also compressed, which is useful for tensors
         with a lot of zeros.

         Args:
-            file_or_obj (str or tensor or OrderedDict):
-                Path to the file to be added to the sha1_store or an in-memory object
-                that can be handled by torch.save.
+            file_or_obj (str or tensor or Dict):
+                Path to the file to be added to the store or an in-memory object
+                that can be handled by torch.save. Note, an OrderedDict is what you
+                get from calling `state_dict()` on a nn.Module, and it is an instance
+                of Dict too. A model's full checkpoint can be a plain dict because
+                it may contain both a module's state_dict and other non-tensor info.
         """
         # Use `isinstance` not type() == Path since pathlib returns OS specific
         # Path types, which inherit from the Path class.
@@ -188,10 +185,10 @@ class SHA1_Store:
             torch.load(cast(Union[Path, str], file_or_obj))
             file_path = Path(file_or_obj)
             remove_tmp = False
-        elif isinstance(file_or_obj, (Tensor, OrderedDict)):
+        elif isinstance(file_or_obj, (Tensor, Dict)):
             # Serialize the object into a tmp file.
             file_path = self._get_tmp_file_path()
-            torch.save(cast(Union[Tensor, OrderedDict], file_or_obj), file_path)
+            torch.save(cast(Union[Tensor, Dict], file_or_obj), file_path)
             remove_tmp = True
         else:
             assert False, f"incorrect input {type(file_or_obj)}"
@@ -215,7 +212,7 @@
             sys.stderr.write(f"An exception occured: {repr(error)}\n")
             sys.exit(ExitCode.FILE_EXISTS_ERROR)

-        # Transfer the file to the internal sha1_store
+        # Transfer the file to the store.
         repo_fpath = repo_fdir.joinpath(sha1_hash)
         try:
             if compress:
@@ -234,7 +231,7 @@
         return sha1_hash

-    def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
+    def get(self, sha1: str) -> Union[Tensor, Dict]:
         """Get data from a SHA1

         Args:
@@ -242,7 +239,7 @@
                 SHA1 of the object to get.

         Returns:
-            (Tensor or OrderedDict):
+            (Tensor or Dict):
                 In-memory object.

         Throws:
...
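With this refactor the constructor takes the store directory itself rather than a parent dir, the metadata file is metadata.json, and add()/get() accept general dicts. A small usage sketch based on the signatures above; the import path is assumed and the store location is illustrative:

# Usage sketch of the refactored SHA1_Store API shown above.
# The import path and store location are assumptions for illustration.
from pathlib import Path

import torch
from fairscale.experimental.wgit.sha1_store import SHA1_Store

store = SHA1_Store(Path("my_sha1_store"), init=True)  # creates the dir and metadata.json

# Add a tensor and a general checkpoint dict; compress=True enables pgzip compression.
t_sha1 = store.add(torch.ones(10, 10), compress=True)
ckpt_sha1 = store.add({"model": torch.nn.Linear(4, 4).state_dict(), "step": 100}, compress=True)

# Read them back by hash.
tensor = store.get(t_sha1)
ckpt = store.get(ckpt_sha1)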
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 #MODIFIED FOR FlattenParamsWrapper
 from typing import Any
+# Deprecate allclose when we move to newer versions.
 def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
+def assert_close(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
 #END
@@ -33,7 +33,10 @@ def create_test_dir():
     # create random checkpoints
     size_list = [30e5, 35e5, 40e5, 40e5]
     for i, size in enumerate(size_list):
-        torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
+        sd = {}
+        sd["model"] = nn.Linear(1, int(size)).state_dict()
+        sd["step"] = 100
+        torch.save(sd, f"checkpoint_{i}.pt")

     return test_dir
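The test checkpoints are now full checkpoint dicts instead of pickled modules, which is exactly the "general state_dict" case the store has to handle. A quick check of what such a checkpoint looks like once loaded (run after create_test_dir() has produced the files):

# What the new test checkpoints contain: a plain dict whose "model" entry is
# a module's state_dict (an OrderedDict) and whose "step" is a plain int.
from collections import OrderedDict

import torch

ckpt = torch.load("checkpoint_0.pt")
assert isinstance(ckpt, dict) and not isinstance(ckpt, OrderedDict)
assert isinstance(ckpt["model"], OrderedDict)
assert ckpt["step"] == 100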
@@ -55,10 +58,14 @@ def test_api_init(capsys, repo):
     assert Path(".wgit/.gitignore").exists()

-def test_api_add(capsys, repo):
+@pytest.mark.parametrize("per_tensor", [True, False])
+def test_api_add(capsys, repo, per_tensor):
     fnum = random.randint(0, 2)
     chkpt0 = f"checkpoint_{fnum}.pt"
-    repo.add(chkpt0)
+    repo.add(chkpt0, per_tensor)
+    if per_tensor:
+        # TODO (Min): test per_tensor add more.
+        return
     sha1_hash = repo._sha1_store._get_sha1_hash(chkpt0)
     metadata_path = repo._rel_file_path(Path(chkpt0))
...
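For reference, the new mode from a caller's point of view. This is a hedged sketch: `repo` stands for a wgit repo handle like the one the test fixture provides, and the per_tensor keyword name simply mirrors the test parameter above.

# Hedged sketch of calling the new per-tensor add; `repo` and the per_tensor
# keyword name are assumptions based on the test above.
from typing import Any


def add_checkpoints(repo: Any) -> None:
    repo.add("checkpoint_0.pt")                   # whole-file add, as before
    repo.add("checkpoint_1.pt", per_tensor=True)  # new: split into per-tensor objects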
@@ -56,7 +56,7 @@ def test_cli_add(create_test_dir, capsys):
     cli.main(["add", chkpt0])
     sha1_store = SHA1_Store(
-        Path.cwd().joinpath(".wgit"),
+        Path.cwd().joinpath(".wgit", "sha1_store"),
         init=False,
     )
     sha1_hash = sha1_store._get_sha1_hash(chkpt0)
...
@@ -17,7 +17,7 @@ from fairscale.internal import torch_version
 # Get the absolute path of the parent at the beginning before any os.chdir(),
 # so that we can proper clean it up at any CWD.
-PARENT_DIR = Path("sha1_store_testing").resolve()
+TESTING_STORE_DIR = Path("sha1_store_testing").resolve()

 @pytest.fixture(scope="function")
@@ -32,8 +32,9 @@ def sha1_store(request):
     """
     # Attach a teardown function.
     def teardown():
-        os.chdir(PARENT_DIR.joinpath("..").resolve())
-        shutil.rmtree(PARENT_DIR, ignore_errors=True)
+        os.chdir(TESTING_STORE_DIR.joinpath("..").resolve())
+        if TESTING_STORE_DIR.exists():
+            shutil.rmtree(TESTING_STORE_DIR)

     request.addfinalizer(teardown)
@@ -41,15 +42,14 @@ def sha1_store(request):
     teardown()

     # Get an empty sha1 store.
-    PARENT_DIR.mkdir()
-    sha1_store = SHA1_Store(PARENT_DIR, init=True)
+    sha1_store = SHA1_Store(TESTING_STORE_DIR, init=True)

     return sha1_store

 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_file(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Create random checkpoints
     size_list = [25e5, 27e5, 30e5, 35e5, 40e5]
@@ -89,7 +89,7 @@ def test_sha1_add_file(sha1_store, compress):
 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_state_dict(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)
     # add once
     for i in range(3):
         sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
@@ -107,7 +107,7 @@ def test_sha1_add_state_dict(sha1_store, compress):
 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_add_tensor(sha1_store, compress):
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)
     sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
     sha1_store._load_json_dict()
     json_dict = sha1_store._json_dict
@@ -120,7 +120,7 @@ def test_sha1_add_tensor(sha1_store, compress):
 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_get(sha1_store, compress):
     """Testing the get() API: normal and exception cases."""
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Add a file, a state dict and a tensor.
     file = "test_get.pt"
@@ -149,7 +149,7 @@ def test_sha1_get(sha1_store, compress):
 @pytest.mark.parametrize("compress", [True, False])
 def test_sha1_delete(sha1_store, compress):
     """Testing the delete() API: with ref counting behavior."""
-    os.chdir(PARENT_DIR)
+    os.chdir(TESTING_STORE_DIR)

     # Add once and delete, second delete should throw an exception.
     tensor = torch.ones(30, 50)
...
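The delete test relies on the store's reference counting (the RF_KEY = "ref_count" entry in the metadata json): the first delete drops the count and removes the object, and a second delete is expected to fail. A hedged sketch of that flow; the import path, the delete() signature and the exception type are not shown in this diff, so they are assumptions here.

# Hedged sketch of the ref-counting behavior the delete test exercises.
# The import path, delete() signature and exception type are assumptions.
from pathlib import Path

import torch
from fairscale.experimental.wgit.sha1_store import SHA1_Store

store = SHA1_Store(Path("refcount_demo_store"), init=True)
sha1 = store.add(torch.ones(30, 50), compress=False)  # first add: ref count becomes 1
store.delete(sha1)                                     # count drops to 0, object removed
try:
    store.delete(sha1)                                 # a second delete should fail
except Exception:
    print("second delete raised, as the test expects")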