Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b7062b5d
Unverified
Commit
b7062b5d
authored
Apr 09, 2021
by
liuzhe-lz
Committed by
GitHub
Apr 09, 2021
Browse files
[Model Compression / TensorFlow] Support exporting pruned model (#3487)
parent
f0e3c584
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
117 additions
and
6 deletions
+117
-6
nni/compression/tensorflow/compressor.py
nni/compression/tensorflow/compressor.py
+86
-4
test/ut/sdk/test_compressor_tf.py
test/ut/sdk/test_compressor_tf.py
+31
-2
No files found.
nni/compression/tensorflow/compressor.py
View file @
b7062b5d
...
...
@@ -87,6 +87,18 @@ class Compressor:
return
layer
def
_uninstrument
(
self
,
layer
):
# note that ``self._wrappers`` cache is not cleared here,
# so the same wrapper objects will be recovered in next ``self._instrument()`` call
if
isinstance
(
layer
,
LayerWrapper
):
layer
.
_instrumented
=
False
return
self
.
_uninstrument
(
layer
.
layer
)
if
isinstance
(
layer
,
tf
.
keras
.
Sequential
):
return
self
.
_uninstrument_sequential
(
layer
)
if
isinstance
(
layer
,
tf
.
keras
.
Model
):
return
self
.
_uninstrument_model
(
layer
)
return
layer
def
_instrument_sequential
(
self
,
seq
):
layers
=
list
(
seq
.
layers
)
# seq.layers is read-only property
need_rebuild
=
False
...
...
@@ -97,6 +109,16 @@ class Compressor:
need_rebuild
=
True
return
tf
.
keras
.
Sequential
(
layers
)
if
need_rebuild
else
seq
def
_uninstrument_sequential
(
self
,
seq
):
layers
=
list
(
seq
.
layers
)
rebuilt
=
False
for
i
,
layer
in
enumerate
(
layers
):
orig_layer
=
self
.
_uninstrument
(
layer
)
if
orig_layer
is
not
layer
:
layers
[
i
]
=
orig_layer
rebuilt
=
True
return
tf
.
keras
.
Sequential
(
layers
)
if
rebuilt
else
seq
def
_instrument_model
(
self
,
model
):
for
key
,
value
in
list
(
model
.
__dict__
.
items
()):
# avoid "dictionary keys changed during iteration"
if
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
...
...
@@ -109,6 +131,17 @@ class Compressor:
value
[
i
]
=
self
.
_instrument
(
item
)
return
model
def
_uninstrument_model
(
self
,
model
):
for
key
,
value
in
list
(
model
.
__dict__
.
items
()):
if
isinstance
(
value
,
tf
.
keras
.
layers
.
Layer
):
orig_layer
=
self
.
_uninstrument
(
value
)
if
orig_layer
is
not
value
:
setattr
(
model
,
key
,
orig_layer
)
elif
isinstance
(
value
,
list
):
for
i
,
item
in
enumerate
(
value
):
if
isinstance
(
item
,
tf
.
keras
.
layers
.
Layer
):
value
[
i
]
=
self
.
_uninstrument
(
item
)
return
model
def
_select_config
(
self
,
layer
):
# Find the last matching config block for given layer.
...
...
@@ -129,6 +162,17 @@ class Compressor:
return
last_match
class
LayerWrapper
(
tf
.
keras
.
Model
):
"""
Abstract base class of layer wrappers.
Concrete layer wrapper classes must inherit this to support ``isinstance`` check.
"""
def
__init__
(
self
):
super
().
__init__
()
self
.
_instrumented
=
True
class
Pruner
(
Compressor
):
"""
Base class for pruning algorithms.
...
...
@@ -167,6 +211,43 @@ class Pruner(Compressor):
self
.
_update_mask
()
return
self
.
compressed_model
def
export_model
(
self
,
model_path
,
mask_path
=
None
):
"""
Export pruned model and optionally mask tensors.
Parameters
----------
model_path : path-like
The path passed to ``Model.save()``.
You can use ".h5" extension name to export HDF5 format.
mask_path : path-like or None
Export masks to the path when set.
Because Keras cannot save tensors without a ``Model``,
this will create a model, set all masks as its weights, and then save that model.
Masks in saved model will be named by corresponding layer name in compressed model.
Returns
-------
None
"""
_logger
.
info
(
'Saving model to %s'
,
model_path
)
input_shape
=
self
.
compressed_model
.
_build_input_shape
# cannot find a public API
model
=
self
.
_uninstrument
(
self
.
compressed_model
)
if
input_shape
:
model
.
build
(
input_shape
)
model
.
save
(
model_path
)
self
.
_instrument
(
model
)
if
mask_path
is
not
None
:
_logger
.
info
(
'Saving masks to %s'
,
mask_path
)
# can't find "save raw weights" API in tensorflow, so build a simple model
mask_model
=
tf
.
keras
.
Model
()
for
wrapper
in
self
.
wrappers
:
setattr
(
mask_model
,
wrapper
.
layer
.
name
,
wrapper
.
masks
)
mask_model
.
save_weights
(
mask_path
)
_logger
.
info
(
'Done'
)
def
calc_masks
(
self
,
wrapper
,
**
kwargs
):
"""
Abstract method to be overridden by algorithm. End users should ignore it.
...
...
@@ -199,7 +280,7 @@ class Pruner(Compressor):
wrapper
.
masks
=
masks
class
PrunerLayerWrapper
(
tf
.
keras
.
Model
):
class
PrunerLayerWrapper
(
LayerWrapper
):
"""
Instrumented TF layer.
...
...
@@ -210,8 +291,6 @@ class PrunerLayerWrapper(tf.keras.Model):
Attributes
----------
layer_info : LayerInfo
All static information of the original layer.
layer : tf.keras.layers.Layer
The original layer.
config : JSON object
...
...
@@ -233,6 +312,10 @@ class PrunerLayerWrapper(tf.keras.Model):
_logger
.
info
(
'Layer detected to compress: %s'
,
self
.
layer
.
name
)
def
call
(
self
,
*
inputs
):
self
.
_update_weights
()
return
self
.
layer
(
*
inputs
)
def
_update_weights
(
self
):
new_weights
=
[]
for
weight
in
self
.
layer
.
weights
:
mask
=
self
.
masks
.
get
(
weight
.
name
)
...
...
@@ -243,7 +326,6 @@ class PrunerLayerWrapper(tf.keras.Model):
if
new_weights
and
not
hasattr
(
new_weights
[
0
],
'numpy'
):
raise
RuntimeError
(
'NNI: Compressed model can only run in eager mode'
)
self
.
layer
.
set_weights
([
weight
.
numpy
()
for
weight
in
new_weights
])
return
self
.
layer
(
*
inputs
)
# TODO: designed to replace `patch_optimizer`
...
...
test/ut/sdk/test_compressor_tf.py
View file @
b7062b5d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
pathlib
import
Path
import
tempfile
import
unittest
import
numpy
as
np
...
...
@@ -27,6 +29,9 @@ import tensorflow as tf
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10
=
tf
.
constant
([[
1.0
]
*
10
])
# This tensor is used as input of CNN models
image_tensor
=
tf
.
zeros
([
1
,
10
,
10
,
3
])
@
unittest
.
skipIf
(
tf
.
__version__
[
0
]
!=
'2'
,
'Skip TF 1.x setup'
)
class
TfCompressorTestCase
(
unittest
.
TestCase
):
...
...
@@ -42,13 +47,37 @@ class TfCompressorTestCase(unittest.TestCase):
layer_types
=
sorted
(
type
(
wrapper
.
layer
).
__name__
for
wrapper
in
pruner
.
wrappers
)
assert
layer_types
==
[
'Conv2D'
,
'Dense'
,
'Dense'
],
layer_types
def
test_level_pruner
(
self
):
def
test_level_pruner
_and_export_correctness
(
self
):
# prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
model
=
build_naive_model
()
pruners
[
'level'
](
model
).
compress
()
pruner
=
pruners
[
'level'
](
model
)
model
=
pruner
.
compress
()
x
=
model
(
tensor1x10
)
assert
x
.
numpy
()
==
94.5
temp_dir
=
Path
(
tempfile
.
gettempdir
())
pruner
.
export_model
(
temp_dir
/
'model'
,
temp_dir
/
'mask'
)
# because exporting will uninstrument and re-instrument the model,
# we must test the model again
x
=
model
(
tensor1x10
)
assert
x
.
numpy
()
==
94.5
# load and test exported model
exported_model
=
tf
.
keras
.
models
.
load_model
(
temp_dir
/
'model'
)
x
=
exported_model
(
tensor1x10
)
assert
x
.
numpy
()
==
94.5
def
test_export_not_crash
(
self
):
for
model
in
[
CnnModel
(),
build_sequential_model
()]:
pruner
=
pruners
[
'level'
](
model
)
model
=
pruner
.
compress
()
# cannot use model.build(image_tensor.shape) here
# it fails even without compression
# seems TF's bug, not ours
model
(
image_tensor
)
pruner
.
export_model
(
tempfile
.
TemporaryDirectory
().
name
)
try
:
from
tensorflow.keras
import
Model
,
Sequential
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment