Unverified commit d0ad08c0, authored by Min Xu, committed by GitHub

[feat] add compression and tests to sha1 store (#1032)


Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent c8327e1c
@@ -3,8 +3,23 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import sys
from typing import List

# Check for user requirements before we import our code.
try:
    import pygit2
except ImportError:
    print("Error: please pip install pygit2 module to use wgit")
    sys.exit(1)

try:
    import pgzip
except ImportError:
    print("Error: please pip install pgzip module to use wgit")
    sys.exit(1)

from .repo import Repo
from .signal_sparsity import Algo, SignalSparsity
from .version import __version_tuple__
...
@@ -12,8 +12,9 @@ import shutil
import sys
import tempfile
import time
from typing import Any, Dict, Optional, Union, cast

import pgzip
import torch
from torch import Tensor
@@ -25,6 +26,7 @@ SHA1_STORE_DIR_NAME = "sha1_store"
# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY = "ref_count"
COMP_KEY = "compressed"


def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -38,6 +40,28 @@ def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
    return d
def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
    """Helper to copy a file and compress it at the same time."""
    with open(str(src), "rb") as srcf:
        with pgzip.open(str(dest), "wb", compresslevel=5, thread=thread, blocksize=blocksize) as destf:
            while True:
                buf = srcf.read(blocksize)
                if len(buf) == 0:
                    break
                destf.write(buf)


def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
    """Helper to copy a file and uncompress it at the same time."""
    with open(str(dest), "wb") as destf:
        with pgzip.open(str(src), "rb", thread=thread, blocksize=blocksize) as srcf:
            while True:
                buf = srcf.read(blocksize)
                if len(buf) == 0:
                    break
                destf.write(buf)
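As a usage note (not part of the commit), here is a minimal round-trip sketch of the two helpers above; the file names are hypothetical, and thread=None lets pgzip pick the number of cores, per the docstrings below:

# Hypothetical round-trip of the helpers above; assumes "checkpoint.pt" exists.
from pathlib import Path

src = Path("checkpoint.pt")           # any existing file (illustrative name)
comp = Path("checkpoint.pt.gz")       # compressed copy
back = Path("checkpoint.roundtrip")   # decompressed copy

_copy_compressed(src, comp, thread=None, blocksize=10 * 1024 * 1024)
_copy_uncompressed(comp, back, thread=None, blocksize=10 * 1024 * 1024)
assert src.read_bytes() == back.read_bytes()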
class SHA1_Store:
    """
    This class represents a SHA1 checksum based storage dir for state_dict
@@ -61,6 +85,12 @@ class SHA1_Store:
    to delete in a version tracking graph. The lesson here is that content
    addressability and dependency graphs do not mix well.
    We support multicore compression for the data to be stored on a per-object basis.
    The ``torch.save()`` API uses a zip format to store the data, but the data appears
    to be uncompressed. Even if it could be made to compress, it would likely use
    single-threaded compression. Therefore, we use pgzip to do parallel
    compression/decompression on top of it and make use of all the cores.
    Args:
        parent_path (Path):
            The parent path in which a SHA1_Store will be created.
@@ -75,16 +105,29 @@
        sha1_buf_size (int):
            Buffer size used for checksumming. Default: 100MB.
        tmp_dir (str):
            Dir for temporary files if input is an in-memory object or output data needs
            to be decompressed first.
        pgzip_threads (int, optional):
            Number of threads (cores) used in compression. Default: None to use all cores.
        pgzip_block_size (int):
            Per-thread block size for compression. Default: 10MB.
    """
    def __init__(
        self,
        parent_path: Path,
        init: bool = False,
        sha1_buf_size: int = 100 * 1024 * 1024,
        tmp_dir: str = "",
        pgzip_threads: Optional[int] = None,
        pgzip_block_size: int = 10 * 1024 * 1024,
    ) -> None:
        """Create or wrap (if already exists) a sha1_store."""
        self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
        self._ref_file_path = self._path.joinpath("ref_count.json")
        self._sha1_buf_size = sha1_buf_size
        self._pgzip_threads = pgzip_threads
        self._pgzip_block_size = pgzip_block_size
        self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}
        # Initialize the sha1_store if not exist and init==True.
@@ -121,7 +164,7 @@
        with open(self._ref_file_path, "w", encoding="utf-8") as f:
            json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
        """
        Adds a file/object to the internal sha1_store and the sha1 references
        accordingly.
@@ -130,6 +173,9 @@
        in <file_or_obj> is moved within the sha1_store and the reference file is updated.
        If the input is an object, it will be stored in self._tmp_dir and then moved.

        If compress is True, the stored file is also compressed, which is useful for tensors
        with a lot of zeros.

        Args:
            file_or_obj (str or tensor or OrderedDict):
                Path to the file to be added to the sha1_store or an in-memory object
@@ -155,7 +201,7 @@
        sha1_hash = self._get_sha1_hash(file_path)
        # Add reference.
        ref_count = self._add_ref(sha1_hash, True, compress)
        if ref_count == 1:
            # First time adding
@@ -172,12 +218,15 @@
            # Transfer the file to the internal sha1_store
            repo_fpath = repo_fdir.joinpath(sha1_hash)
            try:
                if compress:
                    _copy_compressed(file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size)
                else:
                    shutil.copy2(file_path, repo_fpath)
            except BaseException as error:
                # Something went wrong, perhaps out of space, or race condition due to lack of locking.
                # TODO (Min): properly handle the error and recover when we learn more here.
                sys.stderr.write(f"An exception occurred: {repr(error)}\n")
                ref_count = self._add_ref(sha1_hash, False, compress)
        # Clean up if needed.
        if remove_tmp:
@@ -210,6 +259,17 @@
        #
        # TODO (Min): we could also keep stats in the metadata on how many
        #             times the object is read. Will add if that's needed.
        self._load_json_dict()
        if self._json_dict[sha1][COMP_KEY]:
            # Compressed. Because pgzip doesn't support tell() yet, we need to
            # uncompress into a temp file and return it.
            tmp = self._get_tmp_file_path()
            _copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
            obj = torch.load(tmp)
            tmp.unlink()
            return obj
        else:
            # Uncompressed.
            return torch.load(path)
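For context (not part of the diff), a hedged sketch of the compressed add/get path end to end; the import path, directories, and thread count are assumptions:

# Hypothetical end-to-end use of the compressed path (sketch only; paths are illustrative).
from pathlib import Path
import torch
from fairscale.experimental.wgit.sha1_store import SHA1_Store  # assumed module location

parent = Path("/tmp/wgit_demo")
parent.mkdir(parents=True, exist_ok=True)
store = SHA1_Store(parent, init=True, pgzip_threads=4)

sha1 = store.add(torch.zeros(1000, 1000), compress=True)  # zero-heavy tensors compress well
restored = store.get(sha1)  # decompressed into a temp file, then loaded with torch.load()
assert torch.equal(restored, torch.zeros(1000, 1000))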
    def delete(self, sha1: str) -> None:
@@ -282,7 +342,7 @@
        part1, part2 = sha1[:2], sha1[2:4]
        return self._path.joinpath(part1, part2)
    def _add_ref(self, current_sha1_hash: str, inc: bool, compressed: bool) -> int:
        """
        Update the reference count.
@@ -312,6 +372,9 @@
        entry[RF_KEY] += 1 if inc else -1
        assert entry[RF_KEY] >= 0, "negative ref count"

        # Update compressed flag.
        entry[COMP_KEY] = compressed

        self._json_dict[current_sha1_hash] = entry
        self._store_json_dict()
...
@@ -34,5 +34,8 @@ numpy == 1.22.0
# For layerwise gradient scaler
sklearn >= 0.0

# For weigit. These are actually user requirements, not developer requirements.
# However, due to the experimental nature of weigit, we don't expose them to the
# general users of fairscale yet. We check for them in weigit's init code.
pygit2==1.9.2
pgzip==0.3.1
@@ -47,7 +47,8 @@ def sha1_store(request):
    return sha1_store


@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_file(sha1_store, compress):
    os.chdir(PARENT_DIR)
    # Create random checkpoints
@@ -65,15 +66,15 @@ def test_sha1_add_file(sha1_store):
    # Add those 5 random files.
    for c in chkpts:
        sha1_store.add(c, compress)

    # Add a fixed data twice.
    module = nn.Linear(100, 100, bias=False)
    module.weight.data = torch.zeros(100, 100)
    zeros_file = "zeros.pt"
    torch.save(module.state_dict(), zeros_file)
    sha1_store.add(zeros_file, compress)
    sha1_store.add(zeros_file, compress)

    # Assert the ref counts are 1,1,1,1,1 and 2
    sha1_store._load_json_dict()
@@ -86,16 +87,17 @@
    assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_state_dict(sha1_store, compress):
    os.chdir(PARENT_DIR)
    # add once
    for i in range(3):
        sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
    # add twice
    for i in range(3):
        sd = nn.Linear(8, 8).state_dict()
        sha1_store.add(sd, compress)
        sha1_store.add(sd, compress)

    sha1_store._load_json_dict()
    json_dict = sha1_store._json_dict
@@ -103,9 +105,10 @@ def test_sha1_add_state_dict(sha1_store):
    assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_tensor(sha1_store, compress):
    os.chdir(PARENT_DIR)
    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
    sha1_store._load_json_dict()
    json_dict = sha1_store._json_dict
    if torch_version() >= (1, 9, 0):
@@ -114,7 +117,8 @@ def test_sha1_add_tensor(sha1_store):
        assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_get(sha1_store, compress):
    """Testing the get() API: normal and exception cases."""
    os.chdir(PARENT_DIR)
@@ -125,15 +129,15 @@ def test_sha1_get(sha1_store):
    tensor = torch.ones(20, 30)

    # Check that we can get them back.
    file_sha1 = sha1_store.add(file, compress)
    sd = sha1_store.get(file_sha1)
    assert objects_are_equal(sd, torch.load(file))

    sd_sha1 = sha1_store.add(state_dict, compress)
    sd = sha1_store.get(sd_sha1)
    assert objects_are_equal(sd, state_dict)

    tensor_sha1 = sha1_store.add(tensor, compress)
    tensor_got = sha1_store.get(tensor_sha1)
    assert objects_are_equal(tensor_got, tensor)
@@ -142,22 +146,23 @@ def test_sha1_get(sha1_store):
        sha1_store.get(tensor_sha1[:-1])
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_delete(sha1_store, compress):
    """Testing the delete() API: with ref counting behavior."""
    os.chdir(PARENT_DIR)

    # Add once and delete, second delete should throw an exception.
    tensor = torch.ones(30, 50)
    sha1 = sha1_store.add(tensor, compress)
    sha1_store.delete(sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(sha1)

    # Add multiple times and delete should match that.
    state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
    sha1 = sha1_store.add(state_dict, compress)
    for i in range(3):
        new_sha1 = sha1_store.add(state_dict, compress)
        assert sha1 == new_sha1, f"{sha1} vs. {new_sha1}"
    for i in range(4):
        sha1_store.delete(sha1)
...
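Finally, a hedged sketch (not from the commit) of one way to eyeball the size benefit that the add() docstring claims for zero-heavy tensors, reusing the on-disk layout visible in _sha1_to_dir() above; the import path and store locations are assumptions:

# Hypothetical size comparison between a compressed and an uncompressed add (sketch only).
from pathlib import Path
import torch
from fairscale.experimental.wgit.sha1_store import SHA1_Store  # assumed module location

def stored_size(store: SHA1_Store, sha1: str) -> int:
    # Mirror _sha1_to_dir(): objects live at <sha1_store>/<sha1[:2]>/<sha1[2:4]>/<sha1>.
    return store._path.joinpath(sha1[:2], sha1[2:4], sha1).stat().st_size

zeros = torch.zeros(1000, 1000)
for parent, compress in [(Path("/tmp/wgit_plain"), False), (Path("/tmp/wgit_packed"), True)]:
    parent.mkdir(parents=True, exist_ok=True)
    store = SHA1_Store(parent, init=True)
    size = stored_size(store, store.add(zeros, compress=compress))
    print(f"compress={compress}: {size} bytes on disk")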