Unverified commit fd7b962f, authored by Min Xu, committed by GitHub

[fix] unclosed FD and avoid loading/storing metadata many times (#1038)



* [fix] unclosed FD and avoid loading/storing metadata many times

* one more stat

* Update fairscale/experimental/wgit/sha1_store.py

* add name to the objects when added

* dict key can be int from a state_dict

* removed top_level_objects key; it should be added into repo, not sha1_store
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent b0c3fe1e
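
The main performance change below replaces the per-call `_load_json_dict()` / `_store_json_dict()` pattern with a context manager (`_JSON_DictContext`) so that each public operation reads the metadata JSON once on entry and writes it back at most once on exit. A minimal standalone sketch of that pattern; the class and file names here are illustrative, not the fairscale API:

    import json
    from pathlib import Path
    from typing import Any, Dict, Optional

    class MetadataContext:
        """Load a JSON file into a dict on enter; write it back once on exit."""

        def __init__(self, path: Path, readonly: bool = False) -> None:
            self.path = path
            self.readonly = readonly
            self.data: Optional[Dict[str, Any]] = None

        def __enter__(self) -> Dict[str, Any]:
            # Read the whole metadata file once.
            self.data = json.loads(self.path.read_text()) if self.path.exists() else {}
            return self.data

        def __exit__(self, *exc: Any) -> None:
            # Write it back once, only if the context is not read-only.
            if not self.readonly:
                self.path.write_text(json.dumps(self.data, indent=2))
            self.data = None

    # Many updates now cost one read and one write instead of one pair per update.
    with MetadataContext(Path("metadata.json")) as meta:
        for i in range(1000):
            meta[f"key_{i}"] = i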
@@ -28,7 +28,7 @@ REL_PATH_KEY = "file_path"  # this will be removed from the json since it is red
 class RepoStatus(Enum):
-    """Collections of Repo Statuses"""
+    """Repo Statuses"""

     CLEAN = 1
     CHANGES_NOT_ADDED = 2
@@ -39,7 +39,7 @@ class RepoStatus(Enum):
 class SizeInfo:
     """Size info for a file or the repo in bytes.

-    Deduped size can't be disabled. So it will always be there.
+    Deduped size can't be disabled. So it is always performed.

     Both sparsified and gzipped are optional. They are applied in the following
     order if both are enabled:
@@ -59,7 +59,7 @@ class SizeInfo:
 class _SHA1_Tensor:
     """Representing a tensor using sha1(s) from SHA1 store.

-    It can be either a dense one or 2 sparse one with SST and DST.
+    It can be either a dense one or two sparse one (SST and DST).
     """

     is_dense: bool = True
@@ -68,23 +68,35 @@ class _SHA1_Tensor:
     dst_sha1: str = ""


-def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any) -> None:
+def _recursive_apply_to_elements(data: Union[List[Any], Dict[str, Any]], fn: Any, names: List[str]) -> None:
     """Helper function to traverse a dict recursively and apply a function to leafs.

-    The input `data` is a dict or a list and it should only contain dict and list.
+    Args:
+        data (dict or list):
+            A dict or a list and it should only contain dict and list.
+        fn (Any):
+            A call back function on each element. Signature:
+                fn(element: Any, names: List[str]) -> Any
+        names (list):
+            Stack of names for making the element path.
     """
     if isinstance(data, list):
         for i, _ in enumerate(data):
+            names.append(str(i))
             if isinstance(data[i], (list, dict)):
-                _recursive_apply_to_elements(data[i], fn)
+                _recursive_apply_to_elements(data[i], fn, names)
             else:
-                data[i] = fn(data[i])
+                data[i] = fn(data[i], names)
+            names.pop()
     elif isinstance(data, dict):
         for key in data.keys():
+            names.append(str(key))
             if isinstance(data[key], (list, dict)):
-                _recursive_apply_to_elements(data[key], fn)
+                _recursive_apply_to_elements(data[key], fn, names)
             else:
-                data[key] = fn(data[key])
+                data[key] = fn(data[key], names)
+            names.pop()
     else:
         assert False, f"Unexpected data type: {type(data)}"
@@ -250,7 +262,7 @@ class Repo:
         # yet. Need to figure out a method for delta tracking.
         if per_tensor:

-            def fn(element: Any) -> Any:
+            def fn(element: Any, names: List[str]) -> Any:
                 """Callback on each leaf object for _recursive_apply_to_elements below."""
                 if isinstance(element, Tensor):
                     if sparsify:
@@ -258,13 +270,13 @@
                         # tensors with sparsity.
                         # Remember to update ret_state_dict
                         raise NotImplementedError()
-                    sha1 = self._sha1_store.add(element, compress=gzip)
+                    sha1 = self._sha1_store.add(element, compress=gzip, name=".".join(names))
                     return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
                 else:
                     return element

             state_dict = torch.load(file_path)
-            _recursive_apply_to_elements(state_dict, fn)
+            _recursive_apply_to_elements(state_dict, fn, [])
             file_path_or_state_dict = state_dict

         # Add this top-level object.
...
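
The `names` argument threaded through `_recursive_apply_to_elements` above is a stack of dict keys and list indices, which `Repo` joins into a dotted per-tensor name (`name=".".join(names)`). A small usage sketch, assuming the function above is in scope; the keys and tensors are made-up examples:

    import torch

    paths = []

    def fn(element, names):
        # Record the dotted path of this leaf; Repo builds the per-tensor name the same way.
        paths.append(".".join(names))
        return element  # leave the leaf value unchanged

    data = {"layer1": [{"weight": torch.zeros(2)}, {"bias": torch.zeros(2)}]}
    _recursive_apply_to_elements(data, fn, [])
    assert paths == ["layer1.0.weight", "layer1.1.bias"]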
@@ -6,6 +6,8 @@
 import hashlib
 import json
+import logging
+import os
 from pathlib import Path
 import shutil
 import sys
@@ -89,6 +91,31 @@ def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize:
             destf.write(buf)


+class _JSON_DictContext:
+    """Helper class that handles syncing of a json and a dict."""
+
+    def __init__(self, s: "SHA1_Store", readonly: bool) -> None:
+        self._s = s
+        self._readonly = readonly
+
+    def __enter__(self) -> None:
+        """Load from file."""
+        assert self._s._json_dict is None
+        if self._s._metadata_file_path.exists():
+            with open(self._s._metadata_file_path, "r") as f:
+                self._s._json_dict = json.load(f)
+        else:
+            self._s._json_dict = {}
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
+        """Store back to file."""
+        assert isinstance(self._s._json_dict, dict)
+        if not self._readonly:
+            with open(self._s._metadata_file_path, "w", encoding="utf-8") as f:
+                json.dump(self._s._json_dict, f, ensure_ascii=False, indent=2)
+        self._s._json_dict = None
+
+
 class SHA1_Store:
     """
     This class represents a SHA1 checksum based storage dir for state_dict
@@ -146,16 +173,15 @@ class SHA1_Store:
     ) -> None:
         """Create or wrap (if already exists) a store."""
         self._path = path
-        self._metadata_file_path = self._path.joinpath("metadata.json")
         self._sha1_buf_size = sha1_buf_size
         self._pgzip_threads = pgzip_threads
         self._pgzip_block_size = pgzip_block_size
-        self._json_dict: Dict[str, Any] = {
-            STORE_CREATE_DATE_KEY: time.ctime(),
-            STORE_OS_KEY: 0,
-            STORE_DS_KEY: 0,
-            STORE_CS_KEY: 0,
-        }
+
+        # Metadata related.
+        self._metadata_file_path = self._path.joinpath("metadata.json")
+        self._json_dict: Optional[Dict[str, Any]] = None
+        self._json_ctx = _JSON_DictContext(self, readonly=False)
+        self._readonly_json_ctx = _JSON_DictContext(self, readonly=True)

         # Initialize the store if not exist and if init is True.
         if init and not self._path.exists():
@@ -165,7 +191,13 @@ class SHA1_Store:
                 sys.stderr.write(f"An exception occured while creating Sha1_store: {repr(error)}\n")
                 sys.exit(ExitCode.FILE_EXISTS_ERROR)
             # Create a new json file for this new store.
-            self._store_json_dict()
+            with self._json_ctx:
+                self._json_dict = {
+                    STORE_CREATE_DATE_KEY: time.ctime(),
+                    STORE_OS_KEY: 0,
+                    STORE_DS_KEY: 0,
+                    STORE_CS_KEY: 0,
+                }

         # This is an internal error since caller of this our own wgit code.
         assert (
@@ -173,8 +205,8 @@
         ), f"SHA1 store {self._path} does not exist and init is False"

         # Make sure there is a valid metadata file.
-        self._load_json_dict()
-        assert STORE_CREATE_DATE_KEY in self._json_dict, f"Invalid SHA1 Store in {self._path}"
+        with self._readonly_json_ctx:
+            assert STORE_CREATE_DATE_KEY in self._json_dict, f"Invalid SHA1 Store in {self._path}"

         # Init temp dir.
         if tmp_dir:
@@ -187,16 +219,6 @@
             shutil.rmtree(self._tmp_dir, ignore_errors=True)
         self._tmp_dir.mkdir()

-    def _load_json_dict(self) -> None:
-        """Loading json dict from disk."""
-        with open(self._metadata_file_path, "r") as f:
-            self._json_dict = json.load(f)
-
-    def _store_json_dict(self) -> None:
-        """Storing json dict to disk."""
-        with open(self._metadata_file_path, "w", encoding="utf-8") as f:
-            json.dump(self._json_dict, f, ensure_ascii=False, indent=4)
-
     def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, name: str = None) -> str:
         """Adds a file/object to this store and the sha1 references accordingly.
@@ -221,6 +243,8 @@
                 Optional name for this object.
                 Default: None
         """
+        start = time.time()
+
         # Use `isinstance` not type() == Path since pathlib returns OS specific
         # Path types, which inherit from the Path class.
         if isinstance(file_or_obj, (Path, str)):
@@ -240,65 +264,70 @@
         assert isinstance(file_path, Path), type(file_path)
         sha1_hash = self._get_sha1_hash(file_path)

-        # Add reference.
-        ref_count = self._add_ref(sha1_hash, True, compress)
-
-        if ref_count == 1:  # First time adding.
-            # Create the file dir, if needed.
-            repo_fdir = self._sha1_to_dir(sha1_hash)
-            if not repo_fdir.exists():
-                try:
-                    repo_fdir.mkdir(exist_ok=True, parents=True)
-                except FileExistsError as error:
-                    sys.stderr.write(f"An exception occured: {repr(error)}\n")
-                    sys.exit(ExitCode.FILE_EXISTS_ERROR)
-
-            # Transfer the file to the store.
-            repo_fpath = repo_fdir.joinpath(sha1_hash)
-            try:
-                if compress:
-                    orig_size, comp_size = _copy_compressed(
-                        file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size
-                    )
-                else:
-                    shutil.copy2(file_path, repo_fpath)
-                    orig_size = comp_size = file_path.stat().st_size
-            except BaseException as error:
-                # Something went wrong, perhaps out of space, or race condition due to lack of locking.
-                # TODO (Min): proper handle the error and recover when we learn more here.
-                sys.stderr.write(f"An exception occured: {repr(error)}\n")
-                ref_count = self._add_ref(sha1_hash, False, compress)
-
-        self._load_json_dict()
-
-        # Update the sizes for this entry.
-        entry = _get_json_entry(self._json_dict[sha1_hash])
-        assert (
-            ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
-        ), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
-        o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
-        d_diff = orig_size if ref_count == 1 else 0
-        c_diff = comp_size if ref_count == 1 else 0
-        entry[ENTRY_OS_KEY] += o_diff
-        entry[ENTRY_DS_KEY] += d_diff
-        entry[ENTRY_CS_KEY] += c_diff
-        self._json_dict[STORE_OS_KEY] += o_diff
-        self._json_dict[STORE_DS_KEY] += d_diff
-        self._json_dict[STORE_CS_KEY] += c_diff
-
-        # Update the name list for this entry.
-        if name:
-            if name not in entry[ENTRY_NAMES_KEY].keys():
-                entry[ENTRY_NAMES_KEY][name] = 1
-            else:
-                entry[ENTRY_NAMES_KEY][name] += 1
-
-        self._store_json_dict()
+        # Load json for many meta data operations below. Repeatedly loading
+        # can be very slow.
+        with self._json_ctx:
+
+            # Add reference.
+            ref_count = self._add_ref(sha1_hash, True, compress)
+
+            if ref_count == 1:  # First time adding.
+                # Create the file dir, if needed.
+                repo_fdir = self._sha1_to_dir(sha1_hash)
+                if not repo_fdir.exists():
+                    try:
+                        repo_fdir.mkdir(exist_ok=True, parents=True)
+                    except FileExistsError as error:
+                        sys.stderr.write(f"An exception occured: {repr(error)}\n")
+                        sys.exit(ExitCode.FILE_EXISTS_ERROR)
+
+                # Transfer the file to the store.
+                repo_fpath = repo_fdir.joinpath(sha1_hash)
+                try:
+                    if compress:
+                        orig_size, comp_size = _copy_compressed(
+                            file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size
+                        )
+                    else:
+                        shutil.copy2(file_path, repo_fpath)
+                        orig_size = comp_size = file_path.stat().st_size
+                except BaseException as error:
+                    # Something went wrong, perhaps out of space, or race condition due to lack of locking.
+                    # TODO (Min): proper handle the error and recover when we learn more here.
+                    sys.stderr.write(f"An exception occured: {repr(error)}\n")
+                    ref_count = self._add_ref(sha1_hash, False, compress)
+
+            # Update the sizes for this entry.
+            entry = _get_json_entry(self._json_dict[sha1_hash])
+            assert (
+                ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
+            ), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
+            o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
+            d_diff = orig_size if ref_count == 1 else 0
+            c_diff = comp_size if ref_count == 1 else 0
+            entry[ENTRY_OS_KEY] += o_diff
+            entry[ENTRY_DS_KEY] += d_diff
+            entry[ENTRY_CS_KEY] += c_diff
+
+            # Update whole store's stats.
+            self._json_dict[STORE_OS_KEY] += o_diff
+            self._json_dict[STORE_DS_KEY] += d_diff
+            self._json_dict[STORE_CS_KEY] += c_diff
+
+            # Update the name list for this entry.
+            if name:
+                if name not in entry[ENTRY_NAMES_KEY].keys():
+                    entry[ENTRY_NAMES_KEY][name] = 1
+                else:
+                    entry[ENTRY_NAMES_KEY][name] += 1

         # Clean up if needed.
         if remove_tmp:
             file_path.unlink()

+        duration = time.time() - start
+        if duration > 60:
+            logging.warning(f"Add() is taking long: {duration}s")
+
         return sha1_hash

     def get(self, sha1: str) -> Union[Tensor, Dict]:
@@ -326,18 +355,18 @@
         #
         # TODO (Min): we could also keep a stats in the meta data on how many
         #             times the object is read. Will add if that's needed.
-        self._load_json_dict()
-        if self._json_dict[sha1][ENTRY_COMP_KEY]:
-            # Compressed. Because pgzip doesn't support tell() yet, we need to
-            # uncompress into a temp file and return it.
-            tmp = self._get_tmp_file_path()
-            _copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
-            obj = torch.load(tmp)
-            tmp.unlink()
-            return obj
-        else:
-            # Uncompressed.
-            return torch.load(path)
+        with self._readonly_json_ctx:
+            if self._json_dict[sha1][ENTRY_COMP_KEY]:
+                # Compressed. Because pgzip doesn't support tell() yet, we need to
+                # uncompress into a temp file and return it.
+                tmp = self._get_tmp_file_path()
+                _copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
+                obj = torch.load(tmp)
+                tmp.unlink()
+                return obj
+            else:
+                # Uncompressed.
+                return torch.load(path)

     def delete(self, sha1: str) -> None:
         """Delete a SHA1
@@ -355,43 +384,41 @@
             # the caller about it.
             raise ValueError(f"Try to delete SHA1 {sha1} but it is not found")

-        self._load_json_dict()
-
-        assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
-        entry = _get_json_entry(self._json_dict[sha1])
-
-        assert entry[ENTRY_RF_KEY] > 0, f"ref count {entry[ENTRY_RF_KEY]} should be positive"
-        entry[ENTRY_RF_KEY] -= 1
-        if entry[ENTRY_RF_KEY] == 0:
-            # Now, since ref count is 0 now deleting the object.
-            path.unlink()  # We may leave behind an empty dir, which is OK.
-            entry = {}  # Below, we remove the entry because of this.
-
-        # Put the entry back and store it or delete it.
-        if entry:
-            self._json_dict[sha1] = entry
-        else:
-            # empty entry, it means this sha1 is deleted.
-            del self._json_dict[sha1]
-
-        self._store_json_dict()
+        with self._json_ctx:
+            assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
+            entry = _get_json_entry(self._json_dict[sha1])
+
+            assert entry[ENTRY_RF_KEY] > 0, f"ref count {entry[ENTRY_RF_KEY]} should be positive"
+            entry[ENTRY_RF_KEY] -= 1
+            if entry[ENTRY_RF_KEY] == 0:
+                # Now, since ref count is 0 now deleting the object.
+                path.unlink()  # We may leave behind an empty dir, which is OK.
+                entry = {}  # Below, we remove the entry because of this.
+
+            # Put the entry back and store it or delete it.
+            if entry:
+                self._json_dict[sha1] = entry
+            else:
+                # empty entry, it means this sha1 is deleted.
+                del self._json_dict[sha1]

     def size_info(self, sha1: Optional[str] = None) -> Tuple[int, int, int]:
         """Return original, deduped, gzipped sizes for an entry or the store."""
-        self._load_json_dict()
-        if sha1:
-            if sha1 not in self._json_dict.keys():
-                raise ValueError(f"SHA1 {sha1} not found")
-            entry = self._json_dict[sha1]
-            return entry[ENTRY_OS_KEY], entry[ENTRY_DS_KEY], entry[ENTRY_CS_KEY]
-        return self._json_dict[STORE_OS_KEY], self._json_dict[STORE_DS_KEY], self._json_dict[STORE_CS_KEY]
+        with self._readonly_json_ctx:
+            if sha1:
+                if sha1 not in self._json_dict.keys():
+                    raise ValueError(f"SHA1 {sha1} not found")
+                entry = self._json_dict[sha1]
+                return entry[ENTRY_OS_KEY], entry[ENTRY_DS_KEY], entry[ENTRY_CS_KEY]
+            return self._json_dict[STORE_OS_KEY], self._json_dict[STORE_DS_KEY], self._json_dict[STORE_CS_KEY]

     def names(self, sha1: str = None) -> Dict[str, int]:
         """Return the names dict for an object."""
-        self._load_json_dict()
-        if sha1 not in self._json_dict.keys():
-            raise ValueError(f"SHA1 {sha1} not found")
-        entry = self._json_dict[sha1]
-        return entry[ENTRY_NAMES_KEY]
+        with self._readonly_json_ctx:
+            if sha1 not in self._json_dict.keys():
+                raise ValueError(f"SHA1 {sha1} not found")
+            entry = self._json_dict[sha1]
+            return entry[ENTRY_NAMES_KEY]

     def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
         """Return the sha1 hash of a file
@@ -415,7 +442,9 @@ class SHA1_Store:
     def _get_tmp_file_path(self) -> Path:
         """Helper to get a tmp file name under self.tmp_dir."""
-        return Path(tempfile.mkstemp(dir=self._tmp_dir)[1])
+        fd, name = tempfile.mkstemp(dir=self._tmp_dir)
+        os.close(fd)  # Must close this FD or unlink() won't be able to release the space of the file.
+        return Path(name)

     def _sha1_to_dir(self, sha1: str) -> Path:
         """Helper to get the internal dir for a file based on its SHA1"""
@@ -444,8 +473,6 @@
             (int):
                 Resulting ref count.
         """
-        self._load_json_dict()
-
         # Init the entry if needed.
         if current_sha1_hash not in self._json_dict:
             entry = {}
@@ -461,6 +488,5 @@
         entry[ENTRY_COMP_KEY] = compressed

         self._json_dict[current_sha1_hash] = entry
-        self._store_json_dict()

         return entry[ENTRY_RF_KEY]
@@ -80,8 +80,8 @@ def test_sha1_add_file(sha1_store, compress):
     sha1_store.add(zeros_file, compress)

     # Assert the ref counts are 1,1,1,1,1 and 2
-    sha1_store._load_json_dict()
-    json_dict = sha1_store._json_dict
+    with sha1_store._readonly_json_ctx:
+        json_dict = sha1_store._json_dict
     if torch_version() >= (1, 9, 0):
         # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
         key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
@@ -102,8 +102,8 @@ def test_sha1_add_state_dict(sha1_store, compress):
     sha1_store.add(sd, compress)
     sha1_store.add(sd, compress)

-    sha1_store._load_json_dict()
-    json_dict = sha1_store._json_dict
+    with sha1_store._readonly_json_ctx:
+        json_dict = sha1_store._json_dict
     json_dict = dict(filter(lambda item: len(item[0]) == SHA1_KEY_STR_LEN, json_dict.items()))
     assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict
@@ -112,8 +112,8 @@ def test_sha1_add_state_dict(sha1_store, compress):
 def test_sha1_add_tensor(sha1_store, compress):
     os.chdir(TESTING_STORE_DIR)
     sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
-    sha1_store._load_json_dict()
-    json_dict = sha1_store._json_dict
+    with sha1_store._readonly_json_ctx:
+        json_dict = sha1_store._json_dict
     if torch_version() >= (1, 9, 0):
         # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
         key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
...
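
On the `_get_tmp_file_path()` change above: `tempfile.mkstemp()` returns an already-open OS-level file descriptor along with the file name, and the old code discarded that descriptor without closing it. While a process still holds an open descriptor, unlinking the path does not release the file's disk space, so temp files leaked both descriptors and space. A minimal sketch of the corrected pattern; this is illustrative only, not the store's own code:

    import os
    import tempfile
    from pathlib import Path

    fd, name = tempfile.mkstemp()
    os.close(fd)  # the fix: close the descriptor as soon as only the path is needed
    tmp = Path(name)
    tmp.write_bytes(b"temporary payload")
    tmp.unlink()  # with no open descriptor left, the space is actually freed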