Unverified commit e93d2c25, authored by liuzhe-lz, committed by GitHub
Browse files

Merge model compression dev branch to master (#1571)

* [Proposal] demo compressor (#1402)

model compression

* update doc for model compression (#1509)

* Update Overview.md

* Change Doc (#1510)

* refactor compression sdk (#1562)

* refactor compression sdk

* bugfix

* bugfix

* update ut

* Sync model compression doc and implementation (#1575)

* update doc

* formatting

* bugfix

* add import to examples
parent 3274ca30
import torch
import logging
from . import default_layers
# Module-level logger shared by the compressor classes below.
_logger = logging.getLogger(__name__)
class LayerInfo:
    """Lightweight record describing one named sub-module of a model."""

    def __init__(self, name, module):
        # Dotted name of the module within its parent model.
        self.name = name
        # The module object itself.
        self.module = module
        # Class name of the module, e.g. 'Conv2d' or 'Linear'.
        self.type = type(module).__name__
        # Original forward(); populated when the layer is instrumented.
        self._forward = None
class Compressor:
    """
    Abstract base PyTorch compressor.

    Holds the user's config list and walks a bound model's modules,
    instrumenting every layer whose type/name matches one of the configs.
    """

    def __init__(self, config_list):
        self._bound_model = None
        self._config_list = config_list

    def __call__(self, model):
        self.compress(model)
        return model

    def compress(self, model):
        """
        Compress the model with algorithm implemented by subclass.
        The model will be instrumented and user should never edit it after calling this method.
        """
        assert self._bound_model is None, "Each NNI compressor instance can only compress one model"
        self._bound_model = model
        self.bind_model(model)
        for name, module in model.named_modules():
            layer = LayerInfo(name, module)
            matched = self._select_config(layer)
            if matched is not None:
                self._instrument_layer(layer, matched)

    def bind_model(self, model):
        """
        Called once when a model is bound to this compressor.
        Subclasses may override to do model-specific initialization;
        it is guaranteed that only one model is bound per compressor instance.
        """
        pass

    def update_epoch(self, epoch):
        """Override to update the compressor's state once per epoch."""
        pass

    def step(self):
        """Override to update the compressor's state once per training step."""
        pass

    def _instrument_layer(self, layer, config):
        # Subclasses hook the layer's forward() here.
        raise NotImplementedError()

    def _select_config(self, layer):
        """Return the last config entry matching `layer`, or None (excluded layers also yield None)."""
        matched = None
        for cfg in self._config_list:
            allowed_types = cfg.get('op_types')
            if allowed_types == 'default':
                # 'default' expands to every module type that carries a weight.
                allowed_types = default_layers.weighted_modules
            type_ok = not allowed_types or layer.type in allowed_types
            allowed_names = cfg.get('op_names')
            name_ok = not allowed_names or layer.name in allowed_names
            if type_ok and name_ok:
                matched = cfg  # later entries override earlier ones
        if matched is None or matched.get('exclude'):
            return None
        return matched
class Pruner(Compressor):
    """
    Abstract base PyTorch pruner.

    Subclasses implement calc_mask(); this class wires the mask into each
    matched layer's forward() so pruning is applied on every forward pass.
    """

    def __init__(self, config_list):
        super().__init__(config_list)

    def calc_mask(self, weight, config, op, op_type, op_name):
        """
        Pruners should overload this method to provide mask for weight tensors.
        The mask must have the same shape and type comparing to the weight.
        It will be applied with `mul()` operation.
        This method is effectively hooked to `forward()` method of the model.
        """
        raise NotImplementedError("Pruners must overload calc_mask()")

    def _instrument_layer(self, layer, config):
        # TODO: support multiple weight tensors
        # Replace the layer's forward() with a wrapper that masks the weight,
        # runs the original forward, then restores the unmasked weight.
        assert layer._forward is None, 'Each model can only be compressed once'
        if not _check_weight(layer.module):
            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
            return
        layer._forward = layer.module.forward

        # FIX: renamed *input (shadowed builtin) and added **kwargs so keyword
        # arguments reach the wrapped forward; restore the weight in `finally`
        # so an exception inside forward() cannot leave the masked weight behind.
        def new_forward(*inputs, **kwargs):
            # apply mask to weight
            old_weight = layer.module.weight.data
            mask = self.calc_mask(old_weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
            layer.module.weight.data = old_weight.mul(mask)
            try:
                # calculate forward with the masked weight
                return layer._forward(*inputs, **kwargs)
            finally:
                # recover original weight
                layer.module.weight.data = old_weight

        layer.module.forward = new_forward
class Quantizer(Compressor):
    """
    Base quantizer for pytorch quantizer.

    Subclasses implement quantize_weight(); this class wires it into each
    matched layer's forward() so the weight is quantized before every call.
    """

    # NOTE: the redundant __call__ override (identical to Compressor.__call__)
    # was removed; the inherited implementation is used instead.

    def __init__(self, config_list):
        super().__init__(config_list)

    def quantize_weight(self, weight, config, op, op_type, op_name):
        """
        user should know where dequantize goes and implement it in quantize method
        we now do not provide dequantize method
        """
        raise NotImplementedError("Quantizer must overload quantize_weight()")

    def _instrument_layer(self, layer, config):
        # Replace the layer's forward() with a wrapper that quantizes the
        # weight in place before delegating to the original forward.
        assert layer._forward is None, 'Each model can only be compressed once'
        if not _check_weight(layer.module):
            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
            return
        layer._forward = layer.module.forward

        # FIX: renamed *input (shadowed builtin) and added **kwargs so keyword
        # arguments reach the wrapped forward.
        def new_forward(*inputs, **kwargs):
            weight = layer.module.weight.data
            new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
            layer.module.weight.data = new_weight
            return layer._forward(*inputs, **kwargs)

        layer.module.forward = new_forward
def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
# Types of torch.nn modules that carry a learnable `weight` parameter.
# A config entry whose op_types is 'default' expands to this list.
weighted_modules = [
    'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
    'Linear', 'Bilinear',
    'PReLU',
    'Embedding', 'EmbeddingBag',
]
from unittest import TestCase, main
import nni.compression.tensorflow as tf_compressor
import nni.compression.torch as torch_compressor
import torch
import torch.nn.functional as F
import tensorflow as tf
def weight_variable(shape):
    """Create a TF weight Variable initialized from a truncated normal (stddev 0.1)."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    """Create a TF bias Variable filled with the constant 0.1."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)
def conv2d(x_input, w_matrix):
    """2-D convolution with stride 1 and SAME padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x_input, w_matrix, strides=unit_strides, padding='SAME')
def max_pool(x_input, pool_size):
    """Non-overlapping max pooling with a pool_size x pool_size window."""
    window = [1, pool_size, pool_size, 1]
    return tf.nn.max_pool(x_input, ksize=window, strides=window, padding='SAME')
class TfMnist:
    """LeNet-style MNIST network built into the TF1 default graph.

    NOTE(review): construction order matters — every op is registered on the
    implicit default graph as a side effect of __init__, and the placeholder
    names ('input_x', 'input_y', 'keep_prob') form the graph's external
    interface. Handles to selected tensors are kept as attributes for tests.
    """
    def __init__(self):
        # Placeholders: flattened 28x28 images, one-hot labels, dropout keep prob.
        images = tf.placeholder(tf.float32, [ None, 784 ], name = 'input_x')
        labels = tf.placeholder(tf.float32, [ None, 10 ], name = 'input_y')
        keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        self.images = images
        self.labels = labels
        self.keep_prob = keep_prob
        # Filled in below once the corresponding ops are created.
        self.train_step = None
        self.accuracy = None
        self.w1 = None
        self.b1 = None
        self.fcw1 = None
        self.cross = None
        with tf.name_scope('reshape'):
            x_image = tf.reshape(images, [ -1, 28, 28, 1 ])
        with tf.name_scope('conv1'):
            # 5x5 conv, 1 input channel -> 32 feature maps.
            w_conv1 = weight_variable([ 5, 5, 1, 32 ])
            self.w1 = w_conv1
            b_conv1 = bias_variable([ 32 ])
            self.b1 = b_conv1
            h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
        with tf.name_scope('pool1'):
            h_pool1 = max_pool(h_conv1, 2)
        with tf.name_scope('conv2'):
            # 5x5 conv, 32 -> 64 feature maps.
            w_conv2 = weight_variable([ 5, 5, 32, 64 ])
            b_conv2 = bias_variable([ 64 ])
            h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
        with tf.name_scope('pool2'):
            h_pool2 = max_pool(h_conv2, 2)
        with tf.name_scope('fc1'):
            # After two 2x2 pools a 28x28 image is 7x7 with 64 channels.
            w_fc1 = weight_variable([ 7 * 7 * 64, 1024 ])
            self.fcw1 = w_fc1
            b_fc1 = bias_variable([ 1024 ])
            h_pool2_flat = tf.reshape(h_pool2, [ -1, 7 * 7 * 64 ])
            h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
        with tf.name_scope('dropout'):
            # NOTE(review): hard-coded 0.5 keep rate; the keep_prob placeholder
            # defined above is not wired into this op — confirm intended.
            h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
        with tf.name_scope('fc2'):
            w_fc2 = weight_variable([ 1024, 10 ])
            b_fc2 = bias_variable([ 10 ])
            y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
        with tf.name_scope('loss'):
            cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = y_conv))
            self.cross = cross_entropy
        with tf.name_scope('adam_optimizer'):
            self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
        with tf.name_scope('accuracy'):
            correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
class TorchMnist(torch.nn.Module):
    """LeNet-style CNN for MNIST: two conv/pool stages, then two FC layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)

    def forward(self, x):
        # Two conv -> relu -> 2x2 max-pool stages.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        # Flatten to (batch, 4*4*50) and run the classifier head.
        flat = x.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)
class CompressorTestCase(TestCase):
    """Smoke tests: each compressor must instrument its model without raising."""

    def test_tf_pruner(self):
        # TfMnist() registers its ops on the TF default graph as a side effect.
        model = TfMnist()
        pruner_config = [{'sparsity': 0.8, 'op_types': 'default'}]
        tf_compressor.LevelPruner(pruner_config).compress_default_graph()

    def test_tf_quantizer(self):
        model = TfMnist()
        tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph()

    def test_torch_pruner(self):
        pruner_config = [{'sparsity': 0.8, 'op_types': 'default'}]
        torch_compressor.LevelPruner(pruner_config).compress(TorchMnist())

    def test_torch_quantizer(self):
        torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(TorchMnist())
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment