Unverified Commit fd7b962f authored by Min Xu, committed by GitHub

[fix] unclosed FD and avoid loading/storing metadata many times (#1038)



* [fix] unclosed FD and avoid loading/storing metadata many times

* one more stat

* Update fairscale/experimental/wgit/sha1_store.py

* add name to the objects when added

* dict key can be int from a state_dict

* removed top_level_objects key; it should be added into repo, not sha1_store
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent b0c3fe1e
@@ -28,7 +28,7 @@ REL_PATH_KEY = "file_path" # this will be removed from the json since it is red
class RepoStatus(Enum):
"""Collections of Repo Statuses"""
"""Repo Statuses"""
CLEAN = 1
CHANGES_NOT_ADDED = 2
@@ -39,7 +39,7 @@ class RepoStatus(Enum):
class SizeInfo:
"""Size info for a file or the repo in bytes.
Deduped size can't be disabled. So it will always be there.
Deduped size can't be disabled. So it is always performed.
Both sparsified and gzipped are optional. They are applied in the following
order if both are enabled:
@@ -59,7 +59,7 @@ class SizeInfo:
class _SHA1_Tensor:
"""Representing a tensor using sha1(s) from SHA1 store.
It can be either a dense one or 2 sparse one with SST and DST.
It can be either a dense one or two sparse ones (SST and DST).
"""
is_dense: bool = True
@@ -68,23 +68,35 @@ class _SHA1_Tensor:
dst_sha1: str = ""
def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any) -> None:
def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any, names: List[str]) -> None:
"""Helper function to traverse a dict recursively and apply a function to leafs.
The input `data` is a dict or a list and it should only contain dict and list.
Args:
data (dict or list):
A dict or a list and it should only contain dict and list.
fn (Any):
A call back function on each element. Signature:
fn(element: Any, names: List[str]) -> Any
names (list):
Stack of names for making the element path.
"""
if isinstance(data, list):
for i, _ in enumerate(data):
names.append(str(i))
if isinstance(data[i], (list, dict)):
_recursive_apply_to_elements(data[i], fn)
_recursive_apply_to_elements(data[i], fn, names)
else:
data[i] = fn(data[i])
data[i] = fn(data[i], names)
names.pop()
elif isinstance(data, dict):
for key in data.keys():
names.append(str(key))
if isinstance(data[key], (list, dict)):
_recursive_apply_to_elements(data[key], fn)
_recursive_apply_to_elements(data[key], fn, names)
else:
data[key] = fn(data[key])
data[key] = fn(data[key], names)
names.pop()
else:
assert False, f"Unexpected data type: {type(data)}"
@@ -250,7 +262,7 @@ class Repo:
# yet. Need to figure out a method for delta tracking.
if per_tensor:
def fn(element: Any) -> Any:
def fn(element: Any, names: List[str]) -> Any:
"""Callback on each leaf object for _recursive_apply_to_elements below."""
if isinstance(element, Tensor):
if sparsify:
@@ -258,13 +270,13 @@
# tensors with sparsity.
# Remember to update ret_state_dict
raise NotImplementedError()
sha1 = self._sha1_store.add(element, compress=gzip)
sha1 = self._sha1_store.add(element, compress=gzip, name=".".join(names))
return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
else:
return element
state_dict = torch.load(file_path)
_recursive_apply_to_elements(state_dict, fn)
_recursive_apply_to_elements(state_dict, fn, [])
file_path_or_state_dict = state_dict
# Add this top-level object.
......
@@ -6,6 +6,8 @@
import hashlib
import json
import logging
import os
from pathlib import Path
import shutil
import sys
@@ -89,6 +91,31 @@ def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize:
destf.write(buf)
class _JSON_DictContext:
"""Helper class that handles syncing of a json and a dict."""
def __init__(self, s: "SHA1_Store", readonly: bool) -> None:
self._s = s
self._readonly = readonly
def __enter__(self) -> None:
"""Load from file."""
assert self._s._json_dict is None
if self._s._metadata_file_path.exists():
with open(self._s._metadata_file_path, "r") as f:
self._s._json_dict = json.load(f)
else:
self._s._json_dict = {}
def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
"""Store back to file."""
assert isinstance(self._s._json_dict, dict)
if not self._readonly:
with open(self._s._metadata_file_path, "w", encoding="utf-8") as f:
json.dump(self._s._json_dict, f, ensure_ascii=False, indent=2)
self._s._json_dict = None
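(Context, not part of the diff.) A minimal sketch of how this context manager is used, assuming the `_JSON_DictContext` above is in scope; the stand-in object below only mimics the two attributes the context touches:

```python
from pathlib import Path
from types import SimpleNamespace

# Stand-in for a SHA1_Store: just the attributes _JSON_DictContext reads/writes.
store = SimpleNamespace(_json_dict=None, _metadata_file_path=Path("metadata.json"))

with _JSON_DictContext(store, readonly=False):  # __enter__ loads (or creates) the dict
    store._json_dict["created"] = "today"       # mutate freely inside the block
# __exit__ writes the dict back to metadata.json and resets _json_dict to None
```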
class SHA1_Store:
"""
This class represents a SHA1 checksum based storage dir for state_dict
@@ -146,16 +173,15 @@ class SHA1_Store:
) -> None:
"""Create or wrap (if already exists) a store."""
self._path = path
self._metadata_file_path = self._path.joinpath("metadata.json")
self._sha1_buf_size = sha1_buf_size
self._pgzip_threads = pgzip_threads
self._pgzip_block_size = pgzip_block_size
self._json_dict: Dict[str, Any] = {
STORE_CREATE_DATE_KEY: time.ctime(),
STORE_OS_KEY: 0,
STORE_DS_KEY: 0,
STORE_CS_KEY: 0,
}
# Metadata related.
self._metadata_file_path = self._path.joinpath("metadata.json")
self._json_dict: Optional[Dict[str, Any]] = None
self._json_ctx = _JSON_DictContext(self, readonly=False)
self._readonly_json_ctx = _JSON_DictContext(self, readonly=True)
# Initialize the store if not exist and if init is True.
if init and not self._path.exists():
@@ -165,7 +191,13 @@
sys.stderr.write(f"An exception occured while creating Sha1_store: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Create a new json file for this new store.
self._store_json_dict()
with self._json_ctx:
self._json_dict = {
STORE_CREATE_DATE_KEY: time.ctime(),
STORE_OS_KEY: 0,
STORE_DS_KEY: 0,
STORE_CS_KEY: 0,
}
# This is an internal error since the caller of this is our own wgit code.
assert (
@@ -173,8 +205,8 @@
), f"SHA1 store {self._path} does not exist and init is False"
# Make sure there is a valid metadata file.
self._load_json_dict()
assert STORE_CREATE_DATE_KEY in self._json_dict, f"Invalid SHA1 Store in {self._path}"
with self._readonly_json_ctx:
assert STORE_CREATE_DATE_KEY in self._json_dict, f"Invalid SHA1 Store in {self._path}"
# Init temp dir.
if tmp_dir:
@@ -187,16 +219,6 @@
shutil.rmtree(self._tmp_dir, ignore_errors=True)
self._tmp_dir.mkdir()
def _load_json_dict(self) -> None:
"""Loading json dict from disk."""
with open(self._metadata_file_path, "r") as f:
self._json_dict = json.load(f)
def _store_json_dict(self) -> None:
"""Storing json dict to disk."""
with open(self._metadata_file_path, "w", encoding="utf-8") as f:
json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, name: str = None) -> str:
"""Adds a file/object to this store and the sha1 references accordingly.
@@ -221,6 +243,8 @@
Optional name for this object.
Default: None
"""
start = time.time()
# Use `isinstance` not type() == Path since pathlib returns OS specific
# Path types, which inherit from the Path class.
if isinstance(file_or_obj, (Path, str)):
@@ -240,65 +264,70 @@
assert isinstance(file_path, Path), type(file_path)
sha1_hash = self._get_sha1_hash(file_path)
# Add reference.
ref_count = self._add_ref(sha1_hash, True, compress)
if ref_count == 1: # First time adding.
# Create the file dir, if needed.
repo_fdir = self._sha1_to_dir(sha1_hash)
if not repo_fdir.exists():
# Load the json once for the many metadata operations below. Repeatedly loading
# can be very slow.
with self._json_ctx:
# Add reference.
ref_count = self._add_ref(sha1_hash, True, compress)
if ref_count == 1: # First time adding.
# Create the file dir, if needed.
repo_fdir = self._sha1_to_dir(sha1_hash)
if not repo_fdir.exists():
try:
repo_fdir.mkdir(exist_ok=True, parents=True)
except FileExistsError as error:
sys.stderr.write(f"An exception occured: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Transfer the file to the store.
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
repo_fdir.mkdir(exist_ok=True, parents=True)
except FileExistsError as error:
if compress:
orig_size, comp_size = _copy_compressed(
file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size
)
else:
shutil.copy2(file_path, repo_fpath)
orig_size = comp_size = file_path.stat().st_size
except BaseException as error:
# Something went wrong, perhaps out of space, or a race condition due to lack of locking.
# TODO (Min): properly handle the error and recover when we learn more here.
sys.stderr.write(f"An exception occurred: {repr(error)}\n")
sys.exit(ExitCode.FILE_EXISTS_ERROR)
# Transfer the file to the store.
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
if compress:
orig_size, comp_size = _copy_compressed(
file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size
)
ref_count = self._add_ref(sha1_hash, False, compress)
# Update the sizes for this entry.
entry = _get_json_entry(self._json_dict[sha1_hash])
assert (
ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
d_diff = orig_size if ref_count == 1 else 0
c_diff = comp_size if ref_count == 1 else 0
entry[ENTRY_OS_KEY] += o_diff
entry[ENTRY_DS_KEY] += d_diff
entry[ENTRY_CS_KEY] += c_diff
# Update whole store's stats.
self._json_dict[STORE_OS_KEY] += o_diff
self._json_dict[STORE_DS_KEY] += d_diff
self._json_dict[STORE_CS_KEY] += c_diff
# Update the name list for this entry.
if name:
if name not in entry[ENTRY_NAMES_KEY].keys():
entry[ENTRY_NAMES_KEY][name] = 1
else:
shutil.copy2(file_path, repo_fpath)
orig_size = comp_size = file_path.stat().st_size
except BaseException as error:
# Something went wrong, perhaps out of space, or a race condition due to lack of locking.
# TODO (Min): properly handle the error and recover when we learn more here.
sys.stderr.write(f"An exception occurred: {repr(error)}\n")
ref_count = self._add_ref(sha1_hash, False, compress)
self._load_json_dict()
# Update the sizes for this entry.
entry = _get_json_entry(self._json_dict[sha1_hash])
assert (
ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
d_diff = orig_size if ref_count == 1 else 0
c_diff = comp_size if ref_count == 1 else 0
entry[ENTRY_OS_KEY] += o_diff
entry[ENTRY_DS_KEY] += d_diff
entry[ENTRY_CS_KEY] += c_diff
self._json_dict[STORE_OS_KEY] += o_diff
self._json_dict[STORE_DS_KEY] += d_diff
self._json_dict[STORE_CS_KEY] += c_diff
# Update the name list for this entry.
if name:
if name not in entry[ENTRY_NAMES_KEY].keys():
entry[ENTRY_NAMES_KEY][name] = 1
else:
entry[ENTRY_NAMES_KEY][name] += 1
self._store_json_dict()
entry[ENTRY_NAMES_KEY][name] += 1
# Clean up if needed.
if remove_tmp:
file_path.unlink()
duration = time.time() - start
if duration > 60:
logging.warning(f"Add() is taking long: {duration}s")
return sha1_hash
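(Aside, not in the commit.) The bookkeeping above counts the original size once per reference, while the deduped and compressed sizes are only counted on the first add. A small numeric sketch of that arithmetic, with hypothetical key names standing in for ENTRY_OS_KEY/ENTRY_DS_KEY/ENTRY_CS_KEY:

```python
# Adding the same 100-byte object (60 bytes gzipped) three times.
orig_size, comp_size = 100, 60
entry = {"orig_size": 0, "deduped_size": 0, "compressed_size": 0}
for ref_count in (1, 2, 3):
    o_diff = orig_size if ref_count == 1 else entry["orig_size"] // (ref_count - 1)
    d_diff = orig_size if ref_count == 1 else 0
    c_diff = comp_size if ref_count == 1 else 0
    entry["orig_size"] += o_diff
    entry["deduped_size"] += d_diff
    entry["compressed_size"] += c_diff
# entry == {"orig_size": 300, "deduped_size": 100, "compressed_size": 60}
```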
def get(self, sha1: str) -> Union[Tensor, Dict]:
@@ -326,18 +355,18 @@
#
# TODO (Min): we could also keep a stats in the meta data on how many
# times the object is read. Will add if that's needed.
self._load_json_dict()
if self._json_dict[sha1][ENTRY_COMP_KEY]:
# Compressed. Because pgzip doesn't support tell() yet, we need to
# uncompress into a temp file and return it.
tmp = self._get_tmp_file_path()
_copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
obj = torch.load(tmp)
tmp.unlink()
return obj
else:
# Uncompressed.
return torch.load(path)
with self._readonly_json_ctx:
if self._json_dict[sha1][ENTRY_COMP_KEY]:
# Compressed. Because pgzip doesn't support tell() yet, we need to
# uncompress into a temp file and return it.
tmp = self._get_tmp_file_path()
_copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
obj = torch.load(tmp)
tmp.unlink()
return obj
else:
# Uncompressed.
return torch.load(path)
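(Sketch, not in the commit.) The decompress-to-a-temp-file-then-load pattern used by `get()` above, written standalone with stdlib `gzip` instead of `pgzip`; it assumes the file at `path` is a gzipped `torch.save` output:

```python
import gzip
import os
import shutil
import tempfile

import torch

def load_compressed(path: str):
    fd, tmp = tempfile.mkstemp()
    os.close(fd)  # close the OS-level descriptor; we reopen the path below
    with gzip.open(path, "rb") as src, open(tmp, "wb") as dst:
        shutil.copyfileobj(src, dst)  # uncompress into a seekable temp file
    try:
        return torch.load(tmp)  # torch.load needs a seekable file
    finally:
        os.unlink(tmp)  # safe: no open descriptor is left on the temp file
```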
def delete(self, sha1: str) -> None:
"""Delete a SHA1
@@ -355,43 +384,41 @@
# the caller about it.
raise ValueError(f"Try to delete SHA1 {sha1} but it is not found")
self._load_json_dict()
assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
entry = _get_json_entry(self._json_dict[sha1])
with self._json_ctx:
assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
entry = _get_json_entry(self._json_dict[sha1])
assert entry[ENTRY_RF_KEY] > 0, f"ref count {entry[ENTRY_RF_KEY]} should be positive"
entry[ENTRY_RF_KEY] -= 1
if entry[ENTRY_RF_KEY] == 0:
# Since the ref count is now 0, delete the object.
path.unlink() # We may leave behind an empty dir, which is OK.
entry = {} # Below, we remove the entry because of this.
assert entry[ENTRY_RF_KEY] > 0, f"ref count {entry[ENTRY_RF_KEY]} should be positive"
entry[ENTRY_RF_KEY] -= 1
if entry[ENTRY_RF_KEY] == 0:
# Since the ref count is now 0, delete the object.
path.unlink() # We may leave behind an empty dir, which is OK.
entry = {} # Below, we remove the entry because of this.
# Put the entry back and store it or delete it.
if entry:
self._json_dict[sha1] = entry
else:
# An empty entry means this sha1 has been deleted.
del self._json_dict[sha1]
self._store_json_dict()
# Put the entry back and store it or delete it.
if entry:
self._json_dict[sha1] = entry
else:
# An empty entry means this sha1 has been deleted.
del self._json_dict[sha1]
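(Hypothetical usage, not in the diff.) How the ref counting behaves end to end, assuming `store` is an initialized `SHA1_Store` and that identical tensors hash to the same SHA1 (deterministic on torch >= 1.9 per the tests below):

```python
import torch

sha1 = store.add(torch.ones(4), name="w1")  # first add: ref count 1
store.add(torch.ones(4), name="w2")         # same content: ref count 2
store.delete(sha1)                          # ref count drops to 1, object kept
store.delete(sha1)                          # ref count hits 0: file and entry removed
store.delete(sha1)                          # raises ValueError: SHA1 not found
```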
def size_info(self, sha1: Optional[str] = None) -> Tuple[int, int, int]:
"""Return original, deduped, gzipped sizes for an entry or the store."""
self._load_json_dict()
if sha1:
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_OS_KEY], entry[ENTRY_DS_KEY], entry[ENTRY_CS_KEY]
return self._json_dict[STORE_OS_KEY], self._json_dict[STORE_DS_KEY], self._json_dict[STORE_CS_KEY]
with self._readonly_json_ctx:
if sha1:
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_OS_KEY], entry[ENTRY_DS_KEY], entry[ENTRY_CS_KEY]
return self._json_dict[STORE_OS_KEY], self._json_dict[STORE_DS_KEY], self._json_dict[STORE_CS_KEY]
def names(self, sha1: str = None) -> Dict[str, int]:
"""Return the names dict for an object."""
self._load_json_dict()
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_NAMES_KEY]
with self._readonly_json_ctx:
if sha1 not in self._json_dict.keys():
raise ValueError(f"SHA1 {sha1} not found")
entry = self._json_dict[sha1]
return entry[ENTRY_NAMES_KEY]
def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""Return the sha1 hash of a file
@@ -415,7 +442,9 @@
def _get_tmp_file_path(self) -> Path:
"""Helper to get a tmp file name under self.tmp_dir."""
return Path(tempfile.mkstemp(dir=self._tmp_dir)[1])
fd, name = tempfile.mkstemp(dir=self._tmp_dir)
os.close(fd) # Must close this FD or unlink() won't be able to release the space of the file.
return Path(name)
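(General note, beyond this diff.) `tempfile.mkstemp` returns an already-open OS-level descriptor; on POSIX the space of an unlinked file is only reclaimed once every open descriptor on it is closed, hence the explicit `os.close` above. The safe pattern in isolation:

```python
import os
import tempfile

fd, name = tempfile.mkstemp()
os.close(fd)                 # only the path is needed; close the descriptor right away
with open(name, "wb") as f:  # reopen by path when actually writing
    f.write(b"payload")
os.unlink(name)              # space is reclaimed: no descriptor is left open
```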
def _sha1_to_dir(self, sha1: str) -> Path:
"""Helper to get the internal dir for a file based on its SHA1"""
@@ -444,8 +473,6 @@
(int):
Resulting ref count.
"""
self._load_json_dict()
# Init the entry if needed.
if current_sha1_hash not in self._json_dict:
entry = {}
@@ -461,6 +488,5 @@
entry[ENTRY_COMP_KEY] = compressed
self._json_dict[current_sha1_hash] = entry
self._store_json_dict()
return entry[ENTRY_RF_KEY]
@@ -80,8 +80,8 @@ def test_sha1_add_file(sha1_store, compress):
sha1_store.add(zeros_file, compress)
# Assert the ref counts are 1,1,1,1,1 and 2
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
with sha1_store._readonly_json_ctx:
json_dict = sha1_store._json_dict
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
@@ -102,8 +102,8 @@ def test_sha1_add_state_dict(sha1_store, compress):
sha1_store.add(sd, compress)
sha1_store.add(sd, compress)
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
with sha1_store._readonly_json_ctx:
json_dict = sha1_store._json_dict
json_dict = dict(filter(lambda item: len(item[0]) == SHA1_KEY_STR_LEN, json_dict.items()))
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict
@@ -112,8 +112,8 @@
def test_sha1_add_tensor(sha1_store, compress):
os.chdir(TESTING_STORE_DIR)
sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
with sha1_store._readonly_json_ctx:
json_dict = sha1_store._json_dict
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
......