Commit 86f8c2ab authored by Tang Lang, committed by chicm-ms

pruner export (#1674)

parent 6210625b
@@ -66,6 +66,7 @@ def main():
batch_size=1000, shuffle=True)
model = Mnist()
model.to(device)
    '''You can use LevelPruner here instead of the AGP_Pruner below:
    pruner = LevelPruner(configure_list)
@@ -80,7 +81,7 @@
}]
pruner = AGP_Pruner(model, configure_list)
pruner.compress()
model = pruner.compress()
    # you can also pass the model to compress(),
    # e.g. pruner.compress(model)
@@ -90,6 +91,7 @@
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])
if __name__ == '__main__':
...
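The change above means `compress()` now returns the model that should be used for subsequent training. A minimal sketch of the resulting flow, assuming the surrounding definitions from this example file (`Mnist`, `train`, `test`, `device`, the loaders, `optimizer`, and `configure_list`):

# Minimal sketch of the updated example flow; the names below are
# taken from the example file shown in this diff.
model = Mnist()
model.to(device)
pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()  # compress() returns the model with masking hooks attached

for epoch in range(10):
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)

# New in this commit: save the pruned weights, the mask dict, and an ONNX graph.
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28])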
@@ -17,7 +17,6 @@ class LevelPruner(Pruner):
- sparsity
"""
super().__init__(model, config_list)
self.mask_list = {}
self.if_init_list = {}
def calc_mask(self, layer, config):
@@ -30,10 +29,10 @@ class LevelPruner(Pruner):
return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list.update({op_name: mask})
self.mask_dict.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
mask = self.mask_dict[op_name]
return mask
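The branch above computes a magnitude mask: the k-th smallest absolute weight becomes the threshold, and everything at or below it is zeroed. A self-contained sketch of the same computation (illustrative only, not part of the commit):

import torch

def magnitude_mask(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Zero out the `sparsity` fraction of smallest-magnitude entries,
    # mirroring the topk threshold used in calc_mask above.
    k = int(weight.numel() * sparsity)
    if k == 0:
        return torch.ones(weight.shape).type_as(weight)
    w_abs = weight.abs()
    # threshold = largest value among the k smallest magnitudes
    threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
    return torch.gt(w_abs, threshold).type_as(weight)

w = torch.randn(4, 4)
print(magnitude_mask(w, 0.5))  # roughly half the entries are zeroed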
@@ -57,7 +56,6 @@ class AGP_Pruner(Pruner):
        - frequency: to update every 2 epochs, for example, set it to 2
"""
super().__init__(model, config_list)
self.mask_list = {}
self.now_epoch = 0
self.if_init_list = {}
@@ -68,7 +66,7 @@ class AGP_Pruner(Pruner):
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and (
self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
@@ -77,10 +75,10 @@ class AGP_Pruner(Pruner):
w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list.update({op_name: new_mask})
self.mask_dict.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
new_mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
return new_mask
def compute_target_sparsity(self, config):
...
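`compute_target_sparsity` (collapsed above) ramps the sparsity from `initial_sparsity` to `final_sparsity` between `start_epoch` and `end_epoch`, updating every `frequency` epochs. A hedged sketch of the cubic AGP schedule from Zhu & Gupta (2017) that this pruner follows; the exact span rounding is an assumption about the collapsed method:

def agp_target_sparsity(initial_sparsity, final_sparsity,
                        start_epoch, end_epoch, frequency, now_epoch):
    # Cubic ramp of the AGP schedule (Zhu & Gupta, 2017); the span
    # rounding below is an assumption, not copied from the commit.
    span = ((end_epoch - start_epoch - 1) // frequency) * frequency
    if span <= 0 or now_epoch >= start_epoch + span:
        return final_sparsity
    progress = (now_epoch - start_epoch) / span
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3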
@@ -128,11 +128,23 @@ class Compressor:
expanded_op_types.append(op_type)
return expanded_op_types
class Pruner(Compressor):
"""
Abstract base PyTorch pruner
Prune to an exact pruning level specification
Attributes
----------
mask_dict : dict
        Dictionary for saving masks; `key` is the layer name and
        `value` is a tensor with the same shape as the layer's weight
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.mask_dict = {}
def calc_mask(self, layer, config):
"""
        Pruners should overload this method to provide a mask for weight tensors.
@@ -177,6 +189,48 @@ class Pruner(Compressor):
layer.module.forward = new_forward
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
"""
        Export pruned model weights, masks, and optionally an ONNX model
Parameters
----------
model_path : str
path to save pruned model state_dict
mask_path : str
(optional) path to save mask dict
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
            input shape of the model input; required when onnx_path is specified
"""
assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules():
mask = self.mask_dict.get(name)
if mask is not None:
mask_sum = mask.sum().item()
mask_num = mask.numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
print('Layer: %s Sparsity: %.2f' % (name, 1 - mask_sum / mask_num))
m.weight.data = m.weight.data.mul(mask)
else:
_logger.info('Layer: %s NOT compressed', name)
print('Layer: %s NOT compressed' % name)
torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path)
print('Model state_dict saved to %s' % model_path)
if mask_path is not None:
torch.save(self.mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path)
print('Mask dict saved to %s' % mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
            # ONNX export traces the model, so a dummy input of the
            # given shape is required (its values are arbitrary)
            input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
print('Model in onnx with input shape %s saved to %s' % (input_data.shape, onnx_path))
class Quantizer(Compressor):
"""
...
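Putting the new API together, a short usage sketch of `export_model` (the file names are placeholders):

# After training with the pruner attached:
#   model.pth  - state_dict with the masks multiplied into the weights
#   mask.pth   - the pruner's mask_dict
#   model.onnx - ONNX graph exported with a dummy input of the given shape
pruner.export_model(model_path='model.pth',
                    mask_path='mask.pth',
                    onnx_path='model.onnx',
                    input_shape=[1, 1, 28, 28])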