Unverified Commit 16fba4c0 authored by Min Xu, committed by GitHub

[fix] original size computation (#1037)



* flip per_tensor's default

* fixed original size computation
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 2e544bd7
@@ -36,6 +36,11 @@ def main(argv: List[str] = None) -> None:
         metavar="FILE_PATH",
         help="add a file to the staged changeset (default: none)",
     )
+    add_parser.add_argument(
+        "--no_per_tensor",
+        action="store_true",
+        help="Disable per-tensor adding of a file",
+    )
     commit_parser = subparsers.add_parser("commit", description="Commits the staged changes")
     commit_parser.add_argument("commit", action="store_true", help="Commit the staged changes")
@@ -76,7 +81,7 @@ def main(argv: List[str] = None) -> None:
     if args.command == "add":
         repo = Repo(Path.cwd())
-        repo.add(args.add)
+        repo.add(args.add, per_tensor=not args.no_per_tensor)
     if args.command == "status":
         repo = Repo(Path.cwd())
...
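For readers skimming the CLI change: --no_per_tensor is a standard argparse negative switch, so the flag itself defaults to False and per-tensor adding stays on unless the user opts out; the handler then inverts it when calling Repo.add. A minimal standalone sketch of that pattern (the parser and argument names here are illustrative, not the actual wgit CLI):

import argparse

# Illustrative parser only; mirrors the inverted-flag pattern in the diff above.
parser = argparse.ArgumentParser(description="toy 'add' command")
parser.add_argument("add", metavar="FILE_PATH", help="file to add")
parser.add_argument(
    "--no_per_tensor",
    action="store_true",  # flag defaults to False, so per-tensor adding remains the default
    help="Disable per-tensor adding of a file",
)

args = parser.parse_args(["--no_per_tensor", "checkpoint_0.pt"])
per_tensor = not args.no_per_tensor  # same inversion as repo.add(args.add, per_tensor=not args.no_per_tensor)
assert per_tensor is False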
@@ -192,7 +192,7 @@ class Repo:
     def add(
         self,
         in_file_path: str,
-        per_tensor: bool = False,
+        per_tensor: bool = True,
         gzip: bool = True,
         sparsify: bool = False,
         sparsify_policy: Any = None,
@@ -209,7 +209,7 @@ class Repo:
                 Add a file in a per-tensor fashion. This enables more deduplication
                 due to tensors being identical. Deduplication cannot be disabled
                 completely because we use a content addressable SHA1_Store class.
-                Default: False
+                Default: True
             gzip (bool, optional):
                 Enable gzip based lossless compression on the object being added.
                 Default: True
...
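With the default flipped, callers of Repo.add get per-tensor deduplication unless they opt out explicitly. A hedged usage sketch (the import path and checkpoint filename are assumptions for illustration; Repo(Path.cwd()) and add(..., per_tensor=...) themselves come from the diff):

from pathlib import Path

from fairscale.experimental.wgit.repo import Repo  # import path assumed for illustration

repo = Repo(Path.cwd())
repo.add("checkpoint_0.pt")                    # per_tensor=True is now the default
repo.add("checkpoint_0.pt", per_tensor=False)  # opt out: store the whole file as one object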
@@ -273,7 +273,10 @@ class SHA1_Store:
         # Update the sizes for this entry.
         entry = _get_json_entry(self._json_dict[sha1_hash])
-        o_diff = orig_size if ref_count == 1 else entry[ENTRY_OS_KEY]
+        assert (
+            ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
+        ), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
+        o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
         d_diff = orig_size if ref_count == 1 else 0
         c_diff = comp_size if ref_count == 1 else 0
         entry[ENTRY_OS_KEY] += o_diff
...
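This hunk is the actual size fix. entry[ENTRY_OS_KEY] holds the total original bytes accumulated across all references to a stored object. The old code, for any reference after the first, added the entire running total back onto itself, so the recorded original size doubled with every extra reference instead of growing by one object's worth. The new code recovers the per-reference size by dividing the stored total by the existing reference count, and asserts that the total divides evenly, which holds when every reference has the same original size. A toy reproduction of the two accountings, with simplified bookkeeping and an assumed constant per-reference size:

def old_total(orig_size: int, refs: int) -> int:
    """Buggy accounting: adds the whole running total for every repeat reference."""
    total = 0
    for ref_count in range(1, refs + 1):
        o_diff = orig_size if ref_count == 1 else total
        total += o_diff
    return total

def new_total(orig_size: int, refs: int) -> int:
    """Fixed accounting: adds one per-reference size for every repeat reference."""
    total = 0
    for ref_count in range(1, refs + 1):
        o_diff = orig_size if ref_count == 1 else total // (ref_count - 1)
        total += o_diff
    return total

# A 100-byte object referenced 4 times: the old code reports 800 bytes, the fix reports 400.
assert old_total(100, 4) == 800
assert new_total(100, 4) == 400

The deduped and compressed counters (d_diff, c_diff) were already correct, since they only charge the object on its first reference, and the diff leaves them unchanged.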
@@ -85,7 +85,8 @@ def test_api_commit(capsys, repo):
     assert line[0].rstrip().split()[-1] == commit_msg


-def test_api_status(capsys, repo):
+@pytest.mark.parametrize("per_tensor", [True, False])
+def test_api_status(capsys, repo, per_tensor):
     # delete the repo and initialize a new one:
     shutil.rmtree(".wgit")
     repo = Repo(Path.cwd(), init=True)
@@ -96,7 +97,7 @@ def test_api_status(capsys, repo):

     # check status before after a file is added but not committed
     chkpt0 = f"checkpoint_{random.randint(0, 1)}.pt"
-    repo.add(chkpt0)
+    repo.add(chkpt0, per_tensor=per_tensor)
     out = repo.status()
     key_list = list(repo._get_metdata_files().keys())
     assert out == {key_list[0]: RepoStatus.CHANGES_ADDED_NOT_COMMITED}
@@ -107,18 +108,18 @@ def test_api_status(capsys, repo):
     assert out == {key_list[0]: RepoStatus.CLEAN}

     # check status after a new change has been made to the file
-    torch.save(nn.Linear(1, int(15e5)), chkpt0)
+    torch.save(nn.Linear(1, int(15e5)).state_dict(), chkpt0)
     out = repo.status()
     assert out == {key_list[0]: RepoStatus.CHANGES_NOT_ADDED}

     # add the new changes made to weigit
-    repo.add(chkpt0)
+    repo.add(chkpt0, per_tensor=per_tensor)
     out = repo.status()
     assert out == {key_list[0]: RepoStatus.CHANGES_ADDED_NOT_COMMITED}

     # check status after a new different file is added to be tracked by weigit
     chkpt3 = "checkpoint_3.pt"
-    repo.add(chkpt3)
+    repo.add(chkpt3, per_tensor=per_tensor)
     key_list = list(repo._get_metdata_files().keys())
     out = repo.status()
     assert out == {
...
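Parametrizing test_api_status over per_tensor makes pytest collect and run the whole status workflow twice, once per storage mode, so the flipped default is covered without duplicating the test body. A standalone toy illustration of the same parametrize mechanism (not part of the wgit test suite):

import pytest

@pytest.mark.parametrize("per_tensor", [True, False])
def test_toy_parametrize(per_tensor):
    # Collected as two tests: test_toy_parametrize[True] and test_toy_parametrize[False].
    assert isinstance(per_tensor, bool)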
@@ -35,7 +35,7 @@ def create_test_dir():
     # create random checkpoints
     size_list = [30e5, 35e5, 40e5]
     for i, size in enumerate(size_list):
-        torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
+        torch.save(nn.Linear(1, int(size)).state_dict(), f"checkpoint_{i}.pt")

     # Test init.
     cli.main(["init"])
@@ -53,7 +53,7 @@ def test_cli_init(create_test_dir, capsys):


 def test_cli_add(create_test_dir, capsys):
     chkpt0 = "checkpoint_0.pt"
-    cli.main(["add", chkpt0])
+    cli.main(["add", "--no_per_tensor", chkpt0])
     sha1_store = SHA1_Store(
         Path.cwd().joinpath(".wgit", "sha1_store"),
...
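The test fixtures also switch from pickling the whole nn.Linear module to saving its state_dict(), i.e. a plain mapping of parameter names to tensors, which is presumably the kind of checkpoint the per-tensor path splits up. A small sketch of the difference (file names are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(1, 10)

torch.save(model, "module.pt")                # pickles the entire nn.Module object
torch.save(model.state_dict(), "weights.pt")  # saves an OrderedDict of name -> tensor

loaded = torch.load("weights.pt")
assert set(loaded.keys()) == {"weight", "bias"}

The CLI test, meanwhile, passes --no_per_tensor so it continues to exercise the whole-file path that used to be the default.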