"tests/vscode:/vscode.git/clone" did not exist on "d185b5ed5f23c5912918ee81881a3c03f9359523"
Unverified commit b7062b5d authored by liuzhe-lz, committed by GitHub

[Model Compression / TensorFlow] Support exporting pruned model (#3487)

parent f0e3c584
...@@ -87,6 +87,18 @@ class Compressor:
        return layer

    def _uninstrument(self, layer):
        # note that the ``self._wrappers`` cache is not cleared here,
        # so the same wrapper objects will be recovered in the next ``self._instrument()`` call
        if isinstance(layer, LayerWrapper):
            layer._instrumented = False
            return self._uninstrument(layer.layer)
        if isinstance(layer, tf.keras.Sequential):
            return self._uninstrument_sequential(layer)
        if isinstance(layer, tf.keras.Model):
            return self._uninstrument_model(layer)
        return layer

    def _instrument_sequential(self, seq):
        layers = list(seq.layers)  # seq.layers is read-only property
        need_rebuild = False
...@@ -97,6 +109,16 @@ class Compressor:
                need_rebuild = True
        return tf.keras.Sequential(layers) if need_rebuild else seq

    def _uninstrument_sequential(self, seq):
        layers = list(seq.layers)
        rebuilt = False
        for i, layer in enumerate(layers):
            orig_layer = self._uninstrument(layer)
            if orig_layer is not layer:
                layers[i] = orig_layer
                rebuilt = True
        return tf.keras.Sequential(layers) if rebuilt else seq

    def _instrument_model(self, model):
        for key, value in list(model.__dict__.items()):  # avoid "dictionary keys changed during iteration"
            if isinstance(value, tf.keras.layers.Layer):
...@@ -109,6 +131,17 @@ class Compressor:
                        value[i] = self._instrument(item)
        return model

    def _uninstrument_model(self, model):
        for key, value in list(model.__dict__.items()):
            if isinstance(value, tf.keras.layers.Layer):
                orig_layer = self._uninstrument(value)
                if orig_layer is not value:
                    setattr(model, key, orig_layer)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if isinstance(item, tf.keras.layers.Layer):
                        value[i] = self._uninstrument(item)
        return model
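
For illustration only, here is a minimal self-contained sketch of the same wrap/unwrap traversal that ``_instrument`` / ``_uninstrument`` perform; ``ToyWrapper`` and ``toy_uninstrument`` are hypothetical names, not part of this commit.

import tensorflow as tf

class ToyWrapper(tf.keras.Model):
    # Hypothetical stand-in for a concrete LayerWrapper: it only forwards calls.
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def call(self, *inputs):
        return self.layer(*inputs)

def toy_uninstrument(seq):
    # Same idea as _uninstrument_sequential above: swap each wrapper back for its
    # original layer and rebuild the Sequential only if something actually changed.
    layers = [l.layer if isinstance(l, ToyWrapper) else l for l in seq.layers]
    rebuilt = any(new is not old for new, old in zip(layers, seq.layers))
    return tf.keras.Sequential(layers) if rebuilt else seq

wrapped = tf.keras.Sequential([ToyWrapper(tf.keras.layers.Dense(4))])
plain = toy_uninstrument(wrapped)  # plain.layers[0] is the original Dense layer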

    def _select_config(self, layer):
        # Find the last matching config block for given layer.
...@@ -129,6 +162,17 @@ class Compressor:
        return last_match

class LayerWrapper(tf.keras.Model):
    """
    Abstract base class of layer wrappers.
    Concrete layer wrapper classes must inherit this to support ``isinstance`` checks.
    """
    def __init__(self):
        super().__init__()
        self._instrumented = True
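
For illustration, a hedged sketch of what the shared base class buys: any traversal can detect instrumented layers with a single ``isinstance`` check (``find_wrappers`` is a hypothetical helper, not part of this commit).

def find_wrappers(model):
    # Hypothetical helper: collect instrumented layers attached as attributes of a
    # subclassed model, relying only on the LayerWrapper base class defined above.
    return [v for v in model.__dict__.values() if isinstance(v, LayerWrapper)]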

class Pruner(Compressor):
    """
    Base class for pruning algorithms.
...@@ -167,6 +211,43 @@ class Pruner(Compressor):
        self._update_mask()
        return self.compressed_model

    def export_model(self, model_path, mask_path=None):
        """
        Export the pruned model and, optionally, the mask tensors.

        Parameters
        ----------
        model_path : path-like
            The path passed to ``Model.save()``.
            Use the ".h5" extension to export in HDF5 format.
        mask_path : path-like or None
            If set, masks are exported to this path.
            Because Keras cannot save tensors without a ``Model``,
            this will create a model, set all masks as its weights, and then save that model.
            Masks in the saved model are named after the corresponding layers in the compressed model.

        Returns
        -------
        None
        """
        _logger.info('Saving model to %s', model_path)
        input_shape = self.compressed_model._build_input_shape  # cannot find a public API for this
        model = self._uninstrument(self.compressed_model)
        if input_shape:
            model.build(input_shape)
        model.save(model_path)
        self._instrument(model)

        if mask_path is not None:
            _logger.info('Saving masks to %s', mask_path)
            # can't find a "save raw weights" API in TensorFlow, so build a simple model to hold the masks
            mask_model = tf.keras.Model()
            for wrapper in self.wrappers:
                setattr(mask_model, wrapper.layer.name, wrapper.masks)
            mask_model.save_weights(mask_path)

        _logger.info('Done')
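
A hedged usage sketch of ``export_model``, reusing the helpers from the unit test below (``pruners``, ``build_naive_model``, ``tensor1x10``); the output names are arbitrary.

import tempfile
from pathlib import Path

out_dir = Path(tempfile.gettempdir())
pruner = pruners['level'](build_naive_model())   # factory and model builder from the test file
compressed = pruner.compress()
compressed(tensor1x10)                           # run once in eager mode so the model is built
pruner.export_model(out_dir / 'pruned_model', out_dir / 'pruned_masks')
reloaded = tf.keras.models.load_model(out_dir / 'pruned_model')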

    def calc_masks(self, wrapper, **kwargs):
        """
        Abstract method to be overridden by algorithm. End users should ignore it.
...@@ -199,7 +280,7 @@ class Pruner(Compressor):
                wrapper.masks = masks

class PrunerLayerWrapper(LayerWrapper):
    """
    Instrumented TF layer.
...@@ -210,8 +291,6 @@ class PrunerLayerWrapper(tf.keras.Model):
    Attributes
    ----------
    layer_info : LayerInfo
        All static information of the original layer.
    layer : tf.keras.layers.Layer
        The original layer.
    config : JSON object
...@@ -233,6 +312,10 @@ class PrunerLayerWrapper(tf.keras.Model):
        _logger.info('Layer detected to compress: %s', self.layer.name)

    def call(self, *inputs):
        self._update_weights()
        return self.layer(*inputs)

    def _update_weights(self):
        new_weights = []
        for weight in self.layer.weights:
            mask = self.masks.get(weight.name)
...@@ -243,7 +326,6 @@ class PrunerLayerWrapper(tf.keras.Model):
        if new_weights and not hasattr(new_weights[0], 'numpy'):
            raise RuntimeError('NNI: Compressed model can only run in eager mode')
        self.layer.set_weights([weight.numpy() for weight in new_weights])
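
For intuition, the masking step elided from this hunk can be thought of as an elementwise product of weight and mask; a toy example of that arithmetic (illustrative values only, assuming elementwise multiplication):

import tensorflow as tf

weight = tf.constant([[1.0, 2.0], [3.0, 4.0]])
mask = tf.constant([[1.0, 0.0], [1.0, 0.0]])   # zero entries mark pruned weights
pruned = weight * mask                         # [[1.0, 0.0], [3.0, 0.0]]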

# TODO: designed to replace `patch_optimizer`
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
import tempfile
import unittest

import numpy as np
...@@ -27,6 +29,9 @@ import tensorflow as tf
# This tensor is used as input of 10x10 linear layer, the first dimension is batch size
tensor1x10 = tf.constant([[1.0] * 10])
# This tensor is used as input of CNN models
image_tensor = tf.zeros([1, 10, 10, 3])

@unittest.skipIf(tf.__version__[0] != '2', 'Skip TF 1.x setup')
class TfCompressorTestCase(unittest.TestCase):
...@@ -42,13 +47,37 @@ class TfCompressorTestCase(unittest.TestCase):
        layer_types = sorted(type(wrapper.layer).__name__ for wrapper in pruner.wrappers)
        assert layer_types == ['Conv2D', 'Dense', 'Dense'], layer_types

    def test_level_pruner_and_export_correctness(self):
        # prune 90% : 9.0 + 9.1 + ... + 9.9 = 94.5
        model = build_naive_model()
        pruner = pruners['level'](model)
        model = pruner.compress()
        x = model(tensor1x10)
        assert x.numpy() == 94.5
        temp_dir = Path(tempfile.gettempdir())
        pruner.export_model(temp_dir / 'model', temp_dir / 'mask')

        # because exporting will uninstrument and re-instrument the model,
        # we must test the model again
        x = model(tensor1x10)
        assert x.numpy() == 94.5

        # load and test exported model
        exported_model = tf.keras.models.load_model(temp_dir / 'model')
        x = exported_model(tensor1x10)
        assert x.numpy() == 94.5

    def test_export_not_crash(self):
        for model in [CnnModel(), build_sequential_model()]:
            pruner = pruners['level'](model)
            model = pruner.compress()
            # cannot use model.build(image_tensor.shape) here;
            # it fails even without compression, which seems to be a TF bug, not ours
            model(image_tensor)
            pruner.export_model(tempfile.TemporaryDirectory().name)

try:
    from tensorflow.keras import Model, Sequential
......