OpenDAS / fairscale · Commits

Commit d0ad08c0 (unverified)
[feat] add compression and tests to sha1 store (#1032)

Authored Jul 18, 2022 by Min Xu; committed by GitHub on Jul 18, 2022.
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Parent: c8327e1c

Showing 4 changed files with 114 additions and 28 deletions:

  fairscale/experimental/wgit/__init__.py       +15  -0
  fairscale/experimental/wgit/sha1_store.py     +72  -9
  requirements-dev.txt                           +4  -1
  tests/experimental/wgit/test_sha1_store.py    +23 -18
fairscale/experimental/wgit/__init__.py  (+15 -0)

@@ -3,8 +3,23 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 from typing import List

+# Check for user requirements before we import our code.
+try:
+    import pygit2
+except ImportError:
+    print("Error: please pip install pygit2 module to use wgit")
+    sys.exit(1)
+
+try:
+    import pgzip
+except ImportError:
+    print("Error: please pip install pgzip module to use wgit")
+    sys.exit(1)
+
 from .repo import Repo
 from .signal_sparsity import Algo, SignalSparsity
 from .version import __version_tuple__
fairscale/experimental/wgit/sha1_store.py  (+72 -9)

@@ -12,8 +12,9 @@ import shutil
 import sys
 import tempfile
 import time
-from typing import Any, Dict, Union, cast
+from typing import Any, Dict, Optional, Union, cast

+import pgzip
 import torch
 from torch import Tensor

@@ -25,6 +26,7 @@ SHA1_STORE_DIR_NAME = "sha1_store"

 # Const string keys for json file. Do not change for backward compatibilities.
 RF_KEY = "ref_count"
+COMP_KEY = "compressed"


 def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:

@@ -38,6 +40,28 @@ def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
     return d


+def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
+    """Helper to copy a file and compress it at the same time."""
+    with open(str(src), "rb") as srcf:
+        with pgzip.open(str(dest), "wb", compresslevel=5, thread=thread, blocksize=blocksize) as destf:
+            while True:
+                buf = srcf.read(blocksize)
+                if len(buf) == 0:
+                    break
+                destf.write(buf)
+
+
+def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
+    """Helper to copy a file and uncompress it at the same time."""
+    with open(str(dest), "wb") as destf:
+        with pgzip.open(str(src), "rb", thread=thread, blocksize=blocksize) as srcf:
+            while True:
+                buf = srcf.read(blocksize)
+                if len(buf) == 0:
+                    break
+                destf.write(buf)
+
+
 class SHA1_Store:
     """
     This class represents a SHA1 checksum based storage dir for state_dict

@@ -61,6 +85,12 @@ class SHA1_Store:
     to delete in a version tracking graph. The lesson here is that content
     addressibility and dependency graphs do not mix well.

+    We support multicore compression for the data to be store on per-object basis.
+    The ``torch.save()`` API uses zip format to store the data, but it appears to
+    be uncompressed. Even if it can be made compressed, it is likely a single
+    threaded compression. Therefore, we use pgzip to do parallel
+    compression/decompression on top of it to use all the cores.
+
     Args:
         parent_path (Path):
             The parent path in which a SHA1_Store will be created.
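The new docstring paragraph above motivates the design. As a standalone illustration (a minimal sketch, not part of the commit; file names and the 10 MB block size are made-up assumptions), the same pgzip calls used by the helpers can compress an effectively uncompressed torch.save() archive on all cores:

# A minimal sketch, not from the commit: parallel-compress a torch.save() file
# with pgzip, mirroring the arguments used by _copy_compressed() above.
# File names and the 10 MB block size are illustrative assumptions.
import pgzip
import torch

torch.save(torch.zeros(1000, 1000), "ckpt.pt")  # zip container, but effectively uncompressed

blocksize = 10 * 1024 * 1024  # matches the pgzip_block_size default introduced below
with open("ckpt.pt", "rb") as srcf:
    with pgzip.open("ckpt.pt.gz", "wb", compresslevel=5, thread=None, blocksize=blocksize) as destf:
        while True:
            buf = srcf.read(blocksize)
            if len(buf) == 0:
                break
            destf.write(buf)  # thread=None lets pgzip use all available cores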
@@ -75,16 +105,29 @@ class SHA1_Store:
         sha1_buf_size (int):
             Buffer size used for checksumming. Default: 100MB.
         tmp_dir (str):
-            Dir for temporary files if input is an in-memory object.
+            Dir for temporary files if input is an in-memory object or output data needs
+            to be decompressed first.
+        pgzip_threads (int, optional):
+            Number of threads (cores) used in compression. Default: None to use all cores.
+        pgzip_block_size (int):
+            Per-thread block size for compression. Default: 10MB.
     """

     def __init__(
-        self, parent_path: Path, init: bool = False, sha1_buf_size: int = 100 * 1024 * 1024, tmp_dir: str = ""
+        self,
+        parent_path: Path,
+        init: bool = False,
+        sha1_buf_size: int = 100 * 1024 * 1024,
+        tmp_dir: str = "",
+        pgzip_threads: Optional[int] = None,
+        pgzip_block_size: int = 10 * 1024 * 1024,
     ) -> None:
         """Create or wrap (if already exists) a sha1_store."""
         self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
         self._ref_file_path = self._path.joinpath("ref_count.json")
         self._sha1_buf_size = sha1_buf_size
+        self._pgzip_threads = pgzip_threads
+        self._pgzip_block_size = pgzip_block_size
         self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}

         # Initialize the sha1_store if not exist and init==True.

@@ -121,7 +164,7 @@ class SHA1_Store:
         with open(self._ref_file_path, "w", encoding="utf-8") as f:
             json.dump(self._json_dict, f, ensure_ascii=False, indent=4)

-    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
+    def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
         """
         Adds a file/object to the internal sha1_store and the sha1 references
         accordingly.

@@ -130,6 +173,9 @@ class SHA1_Store:
         in <file_or_obj> is moved within the sha1_store and the reference file is updated.
         If the input is an object, it will be store in the self._tmp_dir and then moved.

+        If compress is True, the stored file is also compressed, which is useful for tensors
+        with a lot of zeros.
+
         Args:
             file_or_obj (str or tensor or OrderedDict):
                 Path to the file to be added to the sha1_store or an in-memory object
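The compress flag described in the docstring above is exercised by the updated tests further down; a minimal usage sketch (not part of the commit, with a hypothetical store location and file name) looks like this:

# A minimal usage sketch, not part of the commit; /tmp/wgit_demo and zeros.pt are
# hypothetical, but the calls mirror the updated tests below.
from pathlib import Path

import torch
import torch.nn as nn

from fairscale.experimental.wgit.sha1_store import SHA1_Store

store = SHA1_Store(Path("/tmp/wgit_demo"), init=True)

module = nn.Linear(100, 100, bias=False)
module.weight.data = torch.zeros(100, 100)  # lots of zeros, the case compress=True targets
torch.save(module.state_dict(), "zeros.pt")

sha1 = store.add("zeros.pt", compress=True)        # stored compressed in the sha1_store
sha1_again = store.add("zeros.pt", compress=True)  # identical content: same sha1, ref count 2
assert sha1 == sha1_again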
@@ -155,7 +201,7 @@ class SHA1_Store:
         sha1_hash = self._get_sha1_hash(file_path)

         # Add reference.
-        ref_count = self._add_ref(sha1_hash, True)
+        ref_count = self._add_ref(sha1_hash, True, compress)

         if ref_count == 1:
             # First time adding

@@ -172,12 +218,15 @@ class SHA1_Store:
             # Transfer the file to the internal sha1_store
             repo_fpath = repo_fdir.joinpath(sha1_hash)
             try:
-                shutil.copy2(file_path, repo_fpath)
+                if compress:
+                    _copy_compressed(file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size)
+                else:
+                    shutil.copy2(file_path, repo_fpath)
             except BaseException as error:
                 # Something went wrong, perhaps out of space, or race condition due to lack of locking.
                 # TODO (Min): proper handle the error and recover when we learn more here.
                 sys.stderr.write(f"An exception occured: {repr(error)}\n")
-                ref_count = self._add_ref(sha1_hash, False)
+                ref_count = self._add_ref(sha1_hash, False, compress)

         # Clean up if needed.
         if remove_tmp:

@@ -210,6 +259,17 @@ class SHA1_Store:
         #
         # TODO (Min): we could also keep a stats in the meta data on how many
         # times the object is read. Will add if that's needed.
-        return torch.load(path)
+        self._load_json_dict()
+        if self._json_dict[sha1][COMP_KEY]:
+            # Compressed. Because pgzip doesn't support tell() yet, we need to
+            # uncompress into a temp file and return it.
+            tmp = self._get_tmp_file_path()
+            _copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
+            obj = torch.load(tmp)
+            tmp.unlink()
+            return obj
+        else:
+            # Uncompressed.
+            return torch.load(path)

     def delete(self, sha1: str) -> None:
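The comment added to get() above explains the temp-file detour: pgzip file objects do not support tell() yet, so they cannot be handed to torch.load() directly, and a compressed entry is first streamed into a regular, seekable file. A standalone sketch of that pattern (a minimal sketch, not the commit's code; the helper name, paths, and block size are made up):

# A minimal sketch, not part of the commit, of the decompress-then-load pattern
# get() now uses for compressed entries; the helper name and paths are made up.
import tempfile
from pathlib import Path

import pgzip
import torch


def load_pgzip_compressed(path: Path, blocksize: int = 10 * 1024 * 1024):
    # pgzip's file object can't be passed to torch.load() directly (no tell() yet),
    # so stream-decompress into a seekable temporary file first.
    with tempfile.NamedTemporaryFile(delete=False) as tmpf:
        tmp = Path(tmpf.name)
        with pgzip.open(str(path), "rb", thread=None, blocksize=blocksize) as srcf:
            while True:
                buf = srcf.read(blocksize)
                if len(buf) == 0:
                    break
                tmpf.write(buf)
    obj = torch.load(tmp)  # load from the uncompressed temp copy
    tmp.unlink()           # then clean it up
    return obj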
@@ -282,7 +342,7 @@ class SHA1_Store:
         part1, part2 = sha1[:2], sha1[2:4]
         return self._path.joinpath(part1, part2)

-    def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:
+    def _add_ref(self, current_sha1_hash: str, inc: bool, compressed: bool) -> int:
         """
         Update the reference count.

@@ -312,6 +372,9 @@ class SHA1_Store:
         entry[RF_KEY] += 1 if inc else -1
         assert entry[RF_KEY] >= 0, "negative ref count"

+        # Update compressed flag.
+        entry[COMP_KEY] = compressed
+
         self._json_dict[current_sha1_hash] = entry
         self._store_json_dict()
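With COMP_KEY now recorded by _add_ref(), each entry in ref_count.json tracks both its reference count and whether the stored blob is compressed. An illustrative snapshot of the resulting dictionary (a sketch only; the sha1 key, timestamp, and exact set of fields are assumptions):

# Illustrative only: the shape of the metadata kept in ref_count.json after this
# change; the sha1 key and timestamp are made up.
example_json_dict = {
    "created_on": "Mon Jul 18 10:00:00 2022",
    "3c6e0b8a9c15224a8228b9a98ca1531d588554ab": {
        "ref_count": 2,      # RF_KEY: incremented by add(), decremented by delete()
        "compressed": True,  # COMP_KEY: recorded from the compress flag passed to add()
    },
}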
requirements-dev.txt  (+4 -1)

@@ -34,5 +34,8 @@ numpy == 1.22.0
 # For layerwise gradient scaler
 sklearn >= 0.0

-# For weigit
+# For weigit. These are actually user requirements, not developer requirements.
+# However, due to the experimental nature of weigit, we don't expose to the
+# general users of fairscale yet. We check for them in weigit's init code.
 pygit2==1.9.2
+pgzip==0.3.1
tests/experimental/wgit/test_sha1_store.py  (+23 -18)

@@ -47,7 +47,8 @@ def sha1_store(request):
     return sha1_store


-def test_sha1_add_file(sha1_store):
+@pytest.mark.parametrize("compress", [True, False])
+def test_sha1_add_file(sha1_store, compress):
     os.chdir(PARENT_DIR)

     # Create random checkpoints

@@ -65,15 +66,15 @@ def test_sha1_add_file(sha1_store):
     # Add those 5 random files.
     for c in chkpts:
-        sha1_store.add(c)
+        sha1_store.add(c, compress)

     # Add a fixed data twice.
     module = nn.Linear(100, 100, bias=False)
     module.weight.data = torch.zeros(100, 100)
     zeros_file = "zeros.pt"
     torch.save(module.state_dict(), zeros_file)
-    sha1_store.add(zeros_file)
-    sha1_store.add(zeros_file)
+    sha1_store.add(zeros_file, compress)
+    sha1_store.add(zeros_file, compress)

     # Assert the ref counts are 1,1,1,1,1 and 2
     sha1_store._load_json_dict()

@@ -86,16 +87,17 @@ def test_sha1_add_file(sha1_store):
     assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict


-def test_sha1_add_state_dict(sha1_store):
+@pytest.mark.parametrize("compress", [True, False])
+def test_sha1_add_state_dict(sha1_store, compress):
     os.chdir(PARENT_DIR)
     # add once
     for i in range(3):
-        sha1_store.add(nn.Linear(10, 10).state_dict())
+        sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
     # add twice
     for i in range(3):
         sd = nn.Linear(8, 8).state_dict()
-        sha1_store.add(sd)
-        sha1_store.add(sd)
+        sha1_store.add(sd, compress)
+        sha1_store.add(sd, compress)

     sha1_store._load_json_dict()
     json_dict = sha1_store._json_dict

@@ -103,9 +105,10 @@ def test_sha1_add_state_dict(sha1_store):
     assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict


-def test_sha1_add_tensor(sha1_store):
+@pytest.mark.parametrize("compress", [True, False])
+def test_sha1_add_tensor(sha1_store, compress):
     os.chdir(PARENT_DIR)
-    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]))
+    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
     sha1_store._load_json_dict()
     json_dict = sha1_store._json_dict
     if torch_version() >= (1, 9, 0):

@@ -114,7 +117,8 @@ def test_sha1_add_tensor(sha1_store):
         assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict


-def test_sha1_get(sha1_store):
+@pytest.mark.parametrize("compress", [True, False])
+def test_sha1_get(sha1_store, compress):
     """Testing the get() API: normal and exception cases."""
     os.chdir(PARENT_DIR)

@@ -125,15 +129,15 @@ def test_sha1_get(sha1_store):
     tensor = torch.ones(20, 30)

     # Check that we can get them back.
-    file_sha1 = sha1_store.add(file)
+    file_sha1 = sha1_store.add(file, compress)
     sd = sha1_store.get(file_sha1)
     assert objects_are_equal(sd, torch.load(file))
-    sd_sha1 = sha1_store.add(state_dict)
+    sd_sha1 = sha1_store.add(state_dict, compress)
     sd = sha1_store.get(sd_sha1)
     assert objects_are_equal(sd, state_dict)
-    tensor_sha1 = sha1_store.add(tensor)
+    tensor_sha1 = sha1_store.add(tensor, compress)
     tensor_got = sha1_store.get(tensor_sha1)
     assert objects_are_equal(tensor_got, tensor)

@@ -142,22 +146,23 @@ def test_sha1_get(sha1_store):
         sha1_store.get(tensor_sha1[:-1])


-def test_sha1_delete(sha1_store):
+@pytest.mark.parametrize("compress", [True, False])
+def test_sha1_delete(sha1_store, compress):
     """Testing the delete() API: with ref counting behavior."""
     os.chdir(PARENT_DIR)

     # Add once and delete, second delete should throw an exception.
     tensor = torch.ones(30, 50)
-    sha1 = sha1_store.add(tensor)
+    sha1 = sha1_store.add(tensor, compress)
     sha1_store.delete(sha1)
     with pytest.raises(ValueError):
         sha1_store.delete(sha1)

     # Add multiple times and delete should match that.
     state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
-    sha1 = sha1_store.add(state_dict)
+    sha1 = sha1_store.add(state_dict, compress)
     for i in range(3):
-        new_sha1 = sha1_store.add(state_dict)
+        new_sha1 = sha1_store.add(state_dict, compress)
         assert sha1 == new_sha1, f"{sha1} vs. {new_sha1}"
     for i in range(4):
         sha1_store.delete(sha1)