Unverified Commit 4d58a294 authored by Min Xu, committed by GitHub

[feat]: add per-tensor add to repo (#1033)



* formatting change, no logical change

* formatting and name change, no logical change

* [refactor] sha1_store's path arg

- make sha1_store's path arg directly the path, not its parent
- this is because sha1_store is not like a .git or a .wgit dir, which is
  nested inside another "working" dir. It is simply a store that
  uses a given dir.
- updated repo and tests as well.

* remove a test warning due to deprecated API from torch

* [refactor] change how dot_wgit_dir_path is used

- it should only be assigned in __init__.
- we use it for error checking in the rest of the APIs.

* simplify the init a bit

* refactor the sanity check

* moved some functions, no code change

* [feat] added per-tensor add to the repo

* enabled gzip compression on add

* fix a unit test

* add a note

* make sha1 store work on general dict

* handle general state_dict from a model, not just a module's one-level OrderedDict

* formatting
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent d0ad08c0
......@@ -503,7 +503,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
elif torch.is_tensor(a):
try:
# assert_allclose doesn't strictly test shape, dtype and device
# assert_close doesn't strictly test shape, dtype and device
shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
if not shape_dtype_device_match:
if raise_exception:
......@@ -513,8 +513,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
raise AssertionError(msg)
else:
return False
# assert_allclose.
torch.testing.assert_allclose(a, b)
# assert_close.
if torch_version() < (1, 12, 0):
torch.testing.assert_allclose(a, b)
else:
torch.testing.assert_close(a, b)
return True
except (AssertionError, RuntimeError) as e:
if raise_exception:
......
......@@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import hashlib
import json
from pathlib import Path
......@@ -20,10 +19,6 @@ from torch import Tensor
from .utils import ExitCode
# This is a fixed dir name we use for sha1_store. It should not be changed
# for backward compatibility reasons.
SHA1_STORE_DIR_NAME = "sha1_store"
# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY = "ref_count"
COMP_KEY = "compressed"
......@@ -92,15 +87,10 @@ class SHA1_Store:
compression/decompression on top of it to use all the cores.
Args:
parent_path (Path):
The parent path in which a SHA1_Store will be created.
path (Path):
The path in which a SHA1_Store will be created.
init (bool, optional):
- If ``True``, initializes a new SHA1_Store in the parent_path. Initialization
creates a `sha1_store` directory in ./<parent_path>/,
and a `ref_count.json` within ./<parent_path>/.
- If ``False``, a new `sha1_store` dir is not initialized and the existing
`sha1_store` is used to init this class, populating `_json_dict`, and other
attributes.
- If ``True``, a new SHA1_Store is created in the path if one does not already exist.
- Default: False
sha1_buf_size (int):
Buffer size used for checksumming. Default: 100MB.
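For illustration, a minimal sketch of the new constructor contract (the import path is an assumption, not part of this diff):

    from pathlib import Path
    from fairscale.experimental.wgit.sha1_store import SHA1_Store  # assumed module path

    # The path argument now *is* the store directory, rather than a parent
    # dir that receives a nested "sha1_store" subdir.
    store = SHA1_Store(Path("my_sha1_store"), init=True)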
......@@ -115,22 +105,22 @@ class SHA1_Store:
def __init__(
self,
parent_path: Path,
path: Path,
init: bool = False,
sha1_buf_size: int = 100 * 1024 * 1024,
tmp_dir: str = "",
pgzip_threads: Optional[int] = None,
pgzip_block_size: int = 10 * 1024 * 1024,
) -> None:
"""Create or wrap (if already exists) a sha1_store."""
self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
self._ref_file_path = self._path.joinpath("ref_count.json")
"""Create or wrap (if already exists) a store."""
self._path = path
self._metadata_file_path = self._path.joinpath("metadata.json")
self._sha1_buf_size = sha1_buf_size
self._pgzip_threads = pgzip_threads
self._pgzip_block_size = pgzip_block_size
self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}
# Initialize the sha1_store if not exist and init==True.
# Initialize the store if it does not exist and init is True.
if init and not self._path.exists():
try:
Path.mkdir(self._path, parents=False, exist_ok=False)
......@@ -141,7 +131,13 @@ class SHA1_Store:
self._store_json_dict()
# This is an internal error since the caller of this is our own wgit code.
assert self._path.exists(), "SHA1 store does not exist and init==False"
assert (
self._path.exists() and self._metadata_file_path.exists()
), f"SHA1 store {self._path} does not exist and init is False"
# Make sure there is a valid metadata file.
self._load_json_dict()
assert "created_on" in self._json_dict, f"Invalid SHA1 Store in {self._path}"
# Init temp dir.
if tmp_dir:
......@@ -156,30 +152,31 @@ class SHA1_Store:
def _load_json_dict(self) -> None:
"""Loading json dict from disk."""
with open(self._ref_file_path, "r") as f:
with open(self._metadata_file_path, "r") as f:
self._json_dict = json.load(f)
def _store_json_dict(self) -> None:
"""Storing json dict to disk."""
with open(self._ref_file_path, "w", encoding="utf-8") as f:
with open(self._metadata_file_path, "w", encoding="utf-8") as f:
json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
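For orientation, a sketch of what metadata.json may contain (the layout is an assumption inferred from the `created_on`, `ref_count`, and `compressed` keys used in this file; per-object entries are keyed by sha1 hash):

    # Hypothetical metadata.json contents, shown as the equivalent Python dict:
    json_dict = {
        "created_on": "Mon Jun 27 10:00:00 2022",
        "<sha1_hash>": {"ref_count": 1, "compressed": True},
    }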
def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
"""
Adds a file/object to the internal sha1_store and the sha1 references
accordingly.
def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = False) -> str:
"""Adds a file/object to this store and the sha1 references accordingly.
First, a sha1 hash is calculated. Utilizing the sha1 hash string, the actual file
in <file_or_obj> is moved within the sha1_store and the reference file is updated.
in <file_or_obj> is moved within the store and the reference file is updated.
If the input is an object, it will be stored in self._tmp_dir and then moved.
If compress is True, the stored file is also compressed, which is useful for tensors
with a lot of zeros.
Args:
file_or_obj (str or tensor or OrderedDict):
Path to the file to be added to the sha1_store or an in-memory object
that can be handled by torch.save.
file_or_obj (str or tensor or Dict):
Path to the file to be added to the store, or an in-memory object
that can be handled by torch.save. Note that calling `state_dict()`
on an nn.Module returns an OrderedDict, which is an instance of
Dict as well. A full checkpoint can be a plain dict, since it may
contain both a module's state_dict and other non-tensor info.
"""
# Use `isinstance` not type() == Path since pathlib returns OS specific
# Path types, which inherit from the Path class.
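A usage sketch of the broadened `add()` (reusing the `store` from the constructor sketch above; the returned hashes are illustrative):

    import torch
    import torch.nn as nn

    # Add a raw tensor; compression helps when it contains many zeros.
    sha1_tensor = store.add(torch.zeros(1024), compress=True)

    # Add a general checkpoint dict, not just a module's one-level OrderedDict.
    ckpt = {"model": nn.Linear(4, 4).state_dict(), "step": 100}
    sha1_ckpt = store.add(ckpt)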
......@@ -188,10 +185,10 @@ class SHA1_Store:
torch.load(cast(Union[Path, str], file_or_obj))
file_path = Path(file_or_obj)
remove_tmp = False
elif isinstance(file_or_obj, (Tensor, OrderedDict)):
elif isinstance(file_or_obj, (Tensor, Dict)):
# Serialize the object into a tmp file.
file_path = self._get_tmp_file_path()
torch.save(cast(Union[Tensor, OrderedDict], file_or_obj), file_path)
torch.save(cast(Union[Tensor, Dict], file_or_obj), file_path)
remove_tmp = True
else:
assert False, f"incorrect input {type(file_or_obj)}"
......@@ -215,7 +212,7 @@ class SHA1_Store:
sys.stderr.write(f"An exception occured: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Transfer the file to the internal sha1_store
# Transfer the file to the store.
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
if compress:
......@@ -234,7 +231,7 @@ class SHA1_Store:
return sha1_hash
def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
def get(self, sha1: str) -> Union[Tensor, Dict]:
"""Get data from a SHA1
Args:
......@@ -242,7 +239,7 @@ class SHA1_Store:
SHA1 of the object to get.
Returns:
(Tensor or OrderedDict):
(Tensor or Dict):
In-memory object.
Throws:
......
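And a matching retrieval sketch (again reusing names from the `add()` sketch above):

    # get() deserializes the stored object back into memory, returning a
    # Tensor or a Dict depending on what was stored under that sha1.
    restored = store.get(sha1_ckpt)
    assert restored["step"] == 100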
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#MODIFIED FOR FlattenParamsWrapper
from typing import Any
# assert_allclose is deprecated; remove it when we move to newer torch versions.
def assert_allclose(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
#END
def assert_close(actual: Any, expected: Any, rtol: float = ..., atol: float = ..., equal_nan: bool = ..., msg: str = ...) -> None: ...
......@@ -33,7 +33,10 @@ def create_test_dir():
# create random checkpoints
size_list = [30e5, 35e5, 40e5, 40e5]
for i, size in enumerate(size_list):
torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
sd = {}
sd["model"] = nn.Linear(1, int(size)).state_dict()
sd["step"] = 100
torch.save(sd, f"checkpoint_{i}.pt")
return test_dir
......@@ -55,10 +58,14 @@ def test_api_init(capsys, repo):
assert Path(".wgit/.gitignore").exists()
def test_api_add(capsys, repo):
@pytest.mark.parametrize("per_tensor", [True, False])
def test_api_add(capsys, repo, per_tensor):
fnum = random.randint(0, 2)
chkpt0 = f"checkpoint_{fnum}.pt"
repo.add(chkpt0)
repo.add(chkpt0, per_tensor)
if per_tensor:
# TODO (Min): test per_tensor add more.
return
sha1_hash = repo._sha1_store._get_sha1_hash(chkpt0)
metadata_path = repo._rel_file_path(Path(chkpt0))
......
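The new per-tensor path can be exercised the same way from the repo API (a sketch assuming the `repo` fixture and the `Repo.add` signature used in the test above):

    # Store each tensor of the checkpoint as its own sha1 object.
    repo.add("checkpoint_0.pt", per_tensor=True)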
......@@ -56,7 +56,7 @@ def test_cli_add(create_test_dir, capsys):
cli.main(["add", chkpt0])
sha1_store = SHA1_Store(
Path.cwd().joinpath(".wgit"),
Path.cwd().joinpath(".wgit", "sha1_store"),
init=False,
)
sha1_hash = sha1_store._get_sha1_hash(chkpt0)
......
......@@ -17,7 +17,7 @@ from fairscale.internal import torch_version
# Get the absolute path of the testing dir at the beginning, before any os.chdir(),
# so that we can properly clean it up from any CWD.
PARENT_DIR = Path("sha1_store_testing").resolve()
TESTING_STORE_DIR = Path("sha1_store_testing").resolve()
@pytest.fixture(scope="function")
......@@ -32,8 +32,9 @@ def sha1_store(request):
"""
# Attach a teardown function.
def teardown():
os.chdir(PARENT_DIR.joinpath("..").resolve())
shutil.rmtree(PARENT_DIR, ignore_errors=True)
os.chdir(TESTING_STORE_DIR.joinpath("..").resolve())
if TESTING_STORE_DIR.exists():
shutil.rmtree(TESTING_STORE_DIR)
request.addfinalizer(teardown)
......@@ -41,15 +42,14 @@ def sha1_store(request):
teardown()
# Get an empty sha1 store.
PARENT_DIR.mkdir()
sha1_store = SHA1_Store(PARENT_DIR, init=True)
sha1_store = SHA1_Store(TESTING_STORE_DIR, init=True)
return sha1_store
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_file(sha1_store, compress):
os.chdir(PARENT_DIR)
os.chdir(TESTING_STORE_DIR)
# Create random checkpoints
size_list = [25e5, 27e5, 30e5, 35e5, 40e5]
......@@ -89,7 +89,7 @@ def test_sha1_add_file(sha1_store, compress):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_state_dict(sha1_store, compress):
os.chdir(PARENT_DIR)
os.chdir(TESTING_STORE_DIR)
# add once
for i in range(3):
sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
......@@ -107,7 +107,7 @@ def test_sha1_add_state_dict(sha1_store, compress):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_tensor(sha1_store, compress):
os.chdir(PARENT_DIR)
os.chdir(TESTING_STORE_DIR)
sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
......@@ -120,7 +120,7 @@ def test_sha1_add_tensor(sha1_store, compress):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_get(sha1_store, compress):
"""Testing the get() API: normal and exception cases."""
os.chdir(PARENT_DIR)
os.chdir(TESTING_STORE_DIR)
# Add a file, a state dict and a tensor.
file = "test_get.pt"
......@@ -149,7 +149,7 @@ def test_sha1_get(sha1_store, compress):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_delete(sha1_store, compress):
"""Testing the delete() API: with ref counting behavior."""
os.chdir(PARENT_DIR)
os.chdir(TESTING_STORE_DIR)
# Add once and delete, second delete should throw an exception.
tensor = torch.ones(30, 50)
......