Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
c75d1896
Unverified
Commit
c75d1896
authored
Jul 14, 2022
by
Min Xu
Committed by
GitHub
Jul 14, 2022
Browse files
[feat] add sha1_store delete function (#1028)
Co-authored-by:
Min Xu
<
min.xu.public@gmail.com
>
parent
073618d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
9 deletions
+83
-9
fairscale/experimental/wgit/sha1_store.py
fairscale/experimental/wgit/sha1_store.py
+55
-5
tests/experimental/wgit/test_sha1_store.py
tests/experimental/wgit/test_sha1_store.py
+28
-4
No files found.
fairscale/experimental/wgit/sha1_store.py
View file @
c75d1896
...
...
@@ -23,6 +23,20 @@ from .utils import ExitCode
# for backward compatibility reasons.
SHA1_STORE_DIR_NAME
=
"sha1_store"
# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY
=
"ref_count"
def
_get_json_entry
(
d
:
Dict
[
str
,
Any
])
->
Dict
[
str
,
Any
]:
"""Get a dict from a json entry.
This fills in any missing entries in case we load an older version
json file from the disk.
"""
if
RF_KEY
not
in
d
.
keys
():
d
[
RF_KEY
]
=
0
return
d
class
SHA1_Store
:
"""
...
...
@@ -181,6 +195,9 @@ class SHA1_Store:
Returns:
(Tensor or OrderedDict):
In-memory object.
Throws:
ValueError if sha1 is not found.
"""
path
=
self
.
_sha1_to_dir
(
sha1
).
joinpath
(
sha1
)
if
not
path
.
exists
():
...
...
@@ -190,6 +207,9 @@ class SHA1_Store:
# Directly return the object after loading it. This could throw an
# exception but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
#
# TODO (Min): we could also keep a stats in the meta data on how many
# times the object is read. Will add if that's needed.
return
torch
.
load
(
path
)
def
delete
(
self
,
sha1
:
str
)
->
None
:
...
...
@@ -199,8 +219,34 @@ class SHA1_Store:
sha1 (str):
SHA1 of the object to delete.
Throws:
ValueError if sha1 is not found.
"""
raise
NotImplementedError
()
path
=
self
.
_sha1_to_dir
(
sha1
).
joinpath
(
sha1
)
if
not
path
.
exists
():
# This is potentially a valid case for the caller, so we need to inform
# the caller about it.
raise
ValueError
(
f
"Try to delete SHA1
{
sha1
}
but it is not found"
)
self
.
_load_json_dict
()
assert
sha1
in
self
.
_json_dict
.
keys
(),
"internal error: sha1 not found in json"
entry
=
_get_json_entry
(
self
.
_json_dict
[
sha1
])
assert
entry
[
RF_KEY
]
>
0
,
f
"ref count
{
entry
[
RF_KEY
]
}
should be positive"
entry
[
RF_KEY
]
-=
1
if
entry
[
RF_KEY
]
==
0
:
# The ref count is now 0, so delete the object.
path
.
unlink
()
# We may leave behind an empty dir, which is OK.
entry
=
{}
# Below, we remove the entry because of this.
# Put the entry back and store it or delete it.
if
entry
:
self
.
_json_dict
[
sha1
]
=
entry
else
:
# empty entry, it means this sha1 is deleted.
del
self
.
_json_dict
[
sha1
]
self
.
_store_json_dict
()
def
_get_sha1_hash
(
self
,
file_path
:
Union
[
str
,
Path
])
->
str
:
"""Return the sha1 hash of a file
...
...
@@ -257,12 +303,16 @@ class SHA1_Store:
# Init the entry if needed.
if
current_sha1_hash
not
in
self
.
_json_dict
:
self
.
_json_dict
[
current_sha1_hash
]
=
0
entry
=
{}
else
:
entry
=
self
.
_json_dict
[
current_sha1_hash
]
entry
=
_get_json_entry
(
entry
)
# Update the ref count.
self
.
_json_dict
[
current_sha1_hash
]
+=
1
if
inc
else
-
1
assert
self
.
_json_dict
[
current_sha1_hash
]
>=
0
,
"negative ref count"
entry
[
RF_KEY
]
+=
1
if
inc
else
-
1
assert
entry
[
RF_KEY
]
>=
0
,
"negative ref count"
self
.
_json_dict
[
current_sha1_hash
]
=
entry
self
.
_store_json_dict
()
return
self
.
_json_dict
[
current_sha1_hash
]
return
entry
[
RF_KEY
]
tests/experimental/wgit/test_sha1_store.py
View file @
c75d1896
...
...
@@ -81,9 +81,9 @@ def test_sha1_add_file(sha1_store):
if
torch_version
()
>=
(
1
,
9
,
0
):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key
=
"da3e19590de8f77fcf7a09c888c526b0149863a0"
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
==
2
,
json_dict
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
[
"ref_count"
]
==
2
,
json_dict
del
json_dict
[
"created_on"
]
assert
sorted
(
json_dict
.
values
())
==
[
1
,
1
,
1
,
1
,
1
,
2
],
json_dict
assert
sorted
(
map
(
lambda
x
:
x
[
"ref_count"
],
json_dict
.
values
())
)
==
[
1
,
1
,
1
,
1
,
1
,
2
],
json_dict
def
test_sha1_add_state_dict
(
sha1_store
):
...
...
@@ -100,7 +100,7 @@ def test_sha1_add_state_dict(sha1_store):
sha1_store
.
_load_json_dict
()
json_dict
=
sha1_store
.
_json_dict
del
json_dict
[
"created_on"
]
assert
sorted
(
json_dict
.
values
())
==
[
1
,
1
,
1
,
2
,
2
,
2
],
json_dict
assert
sorted
(
map
(
lambda
x
:
x
[
"ref_count"
],
json_dict
.
values
())
)
==
[
1
,
1
,
1
,
2
,
2
,
2
],
json_dict
def
test_sha1_add_tensor
(
sha1_store
):
...
...
@@ -111,10 +111,11 @@ def test_sha1_add_tensor(sha1_store):
if
torch_version
()
>=
(
1
,
9
,
0
):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key
=
"71df4069a03a766eacf9f03eea50968e87eae9f8"
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
==
1
,
json_dict
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
[
"ref_count"
]
==
1
,
json_dict
def
test_sha1_get
(
sha1_store
):
"""Testing the get() API: normal and exception cases."""
os
.
chdir
(
PARENT_DIR
)
# Add a file, a state dict and a tensor.
...
...
@@ -139,3 +140,26 @@ def test_sha1_get(sha1_store):
# Make sure invalid sha1 cause exceptions.
with
pytest
.
raises
(
ValueError
):
sha1_store
.
get
(
tensor_sha1
[:
-
1
])
def test_sha1_delete(sha1_store):
    """Testing the delete() API: with ref counting behavior."""
    os.chdir(PARENT_DIR)

    # One add followed by one delete; deleting the same sha1 again
    # is expected to raise.
    single_sha1 = sha1_store.add(torch.ones(30, 50))
    sha1_store.delete(single_sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(single_sha1)

    # Add the same state dict four times in total; every add must report
    # the same sha1.
    weights = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
    first_sha1 = sha1_store.add(weights)
    for _ in range(3):
        repeat_sha1 = sha1_store.add(weights)
        assert first_sha1 == repeat_sha1, f"{first_sha1} vs. {repeat_sha1}"

    # Four deletes match the four adds; a fifth delete is expected to raise.
    for _ in range(4):
        sha1_store.delete(first_sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(first_sha1)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment