Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
073618d8
Unverified
Commit
073618d8
authored
Jul 14, 2022
by
Min Xu
Committed by
GitHub
Jul 14, 2022
Browse files
[feat] add sha1_store get function (#1027)
Co-authored-by:
Min Xu
<
min.xu.public@gmail.com
>
parent
68af57d8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
1 deletion
+37
-1
fairscale/experimental/wgit/sha1_store.py
fairscale/experimental/wgit/sha1_store.py
+9
-1
tests/experimental/wgit/test_sha1_store.py
tests/experimental/wgit/test_sha1_store.py
+28
-0
No files found.
fairscale/experimental/wgit/sha1_store.py
View file @
073618d8
...
@@ -182,7 +182,15 @@ class SHA1_Store:
...
@@ -182,7 +182,15 @@ class SHA1_Store:
(Tensor or OrderedDict):
(Tensor or OrderedDict):
In-memory object.
In-memory object.
"""
"""
raise
NotImplementedError
()
path
=
self
.
_sha1_to_dir
(
sha1
).
joinpath
(
sha1
)
if
not
path
.
exists
():
# This is potentially valid case for the caller, we need to inform the
# the caller about it.
raise
ValueError
(
f
"Try to get SHA1
{
sha1
}
but it is not found"
)
# Directly return the object after loading it. This could be throw an
# exception but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
return
torch
.
load
(
path
)
def
delete
(
self
,
sha1
:
str
)
->
None
:
def
delete
(
self
,
sha1
:
str
)
->
None
:
"""Delete a SHA1
"""Delete a SHA1
...
...
tests/experimental/wgit/test_sha1_store.py
View file @
073618d8
...
@@ -11,6 +11,7 @@ import pytest
...
@@ -11,6 +11,7 @@ import pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
fair_dev.testing.testing
import
objects_are_equal
from
fairscale.experimental.wgit.sha1_store
import
SHA1_Store
from
fairscale.experimental.wgit.sha1_store
import
SHA1_Store
from
fairscale.internal
import
torch_version
from
fairscale.internal
import
torch_version
...
@@ -111,3 +112,30 @@ def test_sha1_add_tensor(sha1_store):
...
@@ -111,3 +112,30 @@ def test_sha1_add_tensor(sha1_store):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key
=
"71df4069a03a766eacf9f03eea50968e87eae9f8"
key
=
"71df4069a03a766eacf9f03eea50968e87eae9f8"
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
==
1
,
json_dict
assert
key
in
json_dict
.
keys
()
and
json_dict
[
key
]
==
1
,
json_dict
def
test_sha1_get
(
sha1_store
):
os
.
chdir
(
PARENT_DIR
)
# Add a file, a state dict and a tensor.
file
=
"test_get.pt"
torch
.
save
(
nn
.
Linear
(
100
,
100
).
state_dict
(),
file
)
state_dict
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
20
)).
state_dict
()
tensor
=
torch
.
ones
(
20
,
30
)
# Check that we can get them back.
file_sha1
=
sha1_store
.
add
(
file
)
sd
=
sha1_store
.
get
(
file_sha1
)
assert
objects_are_equal
(
sd
,
torch
.
load
(
file
))
sd_sha1
=
sha1_store
.
add
(
state_dict
)
sd
=
sha1_store
.
get
(
sd_sha1
)
assert
objects_are_equal
(
sd
,
state_dict
)
tensor_sha1
=
sha1_store
.
add
(
tensor
)
tensor_got
=
sha1_store
.
get
(
tensor_sha1
)
assert
objects_are_equal
(
tensor_got
,
tensor
)
# Make sure invalid sha1 cause exceptions.
with
pytest
.
raises
(
ValueError
):
sha1_store
.
get
(
tensor_sha1
[:
-
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment