Unverified Commit 2566badb authored by J-shang, committed by GitHub

[Model Compression] Pruning Wrapper Refactor (#4488)

parent 8d5f643c
@@ -257,14 +257,7 @@ class Compressor:
         Dict[str, str]
             Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
         """
-        if self.is_wrapped:
-            wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
-            self._unwrap_model()
-            parameter_name_map = {name: wrapped_param_names[id(param)] for name, param in self.bound_model.named_parameters()}
-            self._wrap_model()
-            return parameter_name_map
-        else:
-            raise Exception('When only the model is wrapped can get the parameter_name_map.')
+        raise NotImplementedError()

     def _wrap_modules(self, layer: LayerInfo, config: Dict):
         """
...
@@ -6,9 +6,9 @@ from typing import Dict, List, Optional, Tuple

 import torch
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter

-from .compressor import Compressor, LayerInfo
+from .compressor import Compressor, LayerInfo, _setattr

 _logger = logging.getLogger(__name__)

@@ -27,31 +27,57 @@ class PrunerModuleWrapper(Module):
         The configurations that users specify for compression.
     module_name
         The name of the module to compress, wrapper module shares same name.
-    pruner
-        The pruner used to calculate mask.
     """

-    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
+    def __init__(self, module: Module, module_name: str, config: Dict):
         super().__init__()
         # origin layer information
         self.module = module
         self.name = module_name
-        # config and pruner
+        # config information
         self.config = config
-        self.pruner = pruner
+        self.weight = Parameter(torch.empty(self.module.weight.size()))
         # register buffer for mask
         self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+            self.bias = Parameter(torch.empty(self.module.bias.size()))
         else:
             self.register_buffer("bias_mask", None)

+    def _weight2buffer(self):
+        """
+        Before using this wrapper for inference, call `_weight2buffer()` to make the original weight untrainable.
+        The best place to call this function is in `Pruner._wrap_model()`.
+        """
+        self.weight.data = self.module.weight.data
+        delattr(self.module, 'weight')
+        self.module.register_buffer('weight', self.weight.data)
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.bias.data = self.module.bias.data
+            delattr(self.module, 'bias')
+            self.module.register_buffer('bias', self.bias.data)
+
+    def _weight2parameter(self):
+        """
+        When the score no longer needs to be recorded, or the model needs to be exported, call `_weight2parameter()` to make the original weight trainable again.
+        The best place to call this function is in `Pruner._unwrap_model()`.
+        """
+        delattr(self.module, 'weight')
+        self.module.weight = Parameter(torch.empty(self.weight.size()))
+        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            delattr(self.module, 'bias')
+            self.module.bias = Parameter(torch.empty(self.bias.size()))
+            self.module.bias.data = torch.mul(self.bias, self.bias_mask)
+
     def forward(self, *inputs):
         # apply mask to weight, bias
-        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        self.module.weight = torch.mul(self.weight, self.weight_mask)
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+            self.module.bias = torch.mul(self.bias, self.bias_mask)
         return self.module(*inputs)
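For readers following the refactor, the wrapper pattern above is easier to see in a minimal, self-contained sketch (illustrative only, not NNI code; `ToyMaskedWrapper` is a made-up name): the trainable weight lives on the wrapper as a `Parameter`, the wrapped module keeps only a `weight` buffer, and the mask is re-applied on every forward so gradients flow back to the wrapper's parameter, mirroring `_weight2buffer()` and `forward()` above.

import torch
from torch.nn import Linear, Module, Parameter

class ToyMaskedWrapper(Module):
    def __init__(self, module: Linear):
        super().__init__()
        self.module = module
        # the trainable copy of the weight lives on the wrapper
        self.weight = Parameter(module.weight.data.clone())
        self.register_buffer("weight_mask", torch.ones_like(self.weight))
        # demote the original weight to a buffer, as _weight2buffer() does
        delattr(module, 'weight')
        module.register_buffer('weight', self.weight.data)

    def forward(self, x):
        # re-apply the mask on every call; gradients flow to self.weight
        self.module.weight = torch.mul(self.weight, self.weight_mask)
        return self.module(x)

layer = Linear(4, 2)
wrapper = ToyMaskedWrapper(layer)
wrapper.weight_mask[0].zero_()   # "prune" the first output row
loss = wrapper(torch.randn(3, 4)).sum()
loss.backward()                  # the gradient lands on wrapper.weight, not on the buffer

Keeping the masked tensor on the inner module means `self.module(*inputs)` runs unmodified, while the optimizer only ever sees the wrapper's parameter.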
@@ -75,12 +101,58 @@ class Pruner(Compressor):
             The configuration for generating the mask.
         """
         _logger.debug("Module detected to compress : %s.", layer.name)
-        wrapper = PrunerModuleWrapper(layer.module, layer.name, config, self)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, config)
         assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
         # move newly registered buffers to the same device of weight
         wrapper.to(layer.module.weight.device)
         return wrapper

+    # The following `_wrap_model`, `_unwrap_model` and `get_origin2wrapped_parameter_name_map` can be merged into
+    # `Compressor` if the quantizer adopts a similar wrapper structure.
+    def _wrap_model(self):
+        """
+        Wrap all modules that need to be compressed.
+        Different from the parent function, this calls `wrapper._weight2buffer()` after replacing the original module with the wrapper.
+        """
+        if not self.is_wrapped:
+            for _, wrapper in reversed(self.get_modules_wrapper().items()):
+                _setattr(self.bound_model, wrapper.name, wrapper)
+                wrapper._weight2buffer()
+            self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        Unwrap all modules that need to be compressed.
+        Different from the parent function, this calls `wrapper._weight2parameter()` after replacing the wrapper with the original module.
+        """
+        if self.is_wrapped:
+            for _, wrapper in self.get_modules_wrapper().items():
+                _setattr(self.bound_model, wrapper.name, wrapper.module)
+                wrapper._weight2parameter()
+            self.is_wrapped = False
+
+    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
+        """
+        Get the name mapping of parameters from the original model to the wrapped model.
+
+        Returns
+        -------
+        Dict[str, str]
+            Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
+        """
+        if self.is_wrapped:
+            wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
+            self._unwrap_model()
+            parameter_name_map = {}
+            for name, param in self.bound_model.named_parameters():
+                # If the parameter under a wrapped module is named `xxx.weight` or `xxx.bias`, the name does not change after wrapping.
+                # Any other parameter `xxx.param` of a wrapped module becomes `xxx.module.param` after wrapping.
+                parameter_name_map[name] = wrapped_param_names[id(param)] if id(param) in wrapped_param_names else name
+            self._wrap_model()
+            return parameter_name_map
+        else:
+            raise Exception('The parameter_name_map can only be obtained while the model is wrapped.')
+
     def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
         """
         Load existing masks onto the wrappers. You can train the model with existing masks after loading them.
...
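As a rough illustration of the mapping documented above (a hedged example, assuming the model is currently wrapped; the exact entries depend on the model and config_list):

# Illustrative only: `weight`/`bias` of a wrapped `conv1` keep their names,
# parameters of untouched modules such as `bn1` are unchanged, and any other
# parameter of a wrapped module would appear as `conv1.module.<param>`.
name_map = pruner.get_origin2wrapped_parameter_name_map()
# e.g. {'conv1.weight': 'conv1.weight',
#       'conv1.bias': 'conv1.bias',
#       'bn1.weight': 'bn1.weight',
#       ...}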
@@ -999,7 +999,7 @@ class TaylorFOWeightPruner(BasicPruner):
         return (weight_tensor.detach() * grad.detach()).data.pow(2)

     def reset_tools(self):
-        hook_targets = {layer_info.name: layer_info.module.weight for layer_info, _ in self._detect_modules_to_compress()}
+        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
         if self.data_collector is None:
             self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
...
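The metric returned in the context line above is the first-order Taylor importance. A toy calculation (not NNI code) makes the formula concrete: for each weight element w with collected gradient g, the score is (w * g)**2, so larger scores mark weights whose removal is expected to disturb the loss more.

import torch

weight = torch.tensor([0.5, -1.0, 0.1])
grad = torch.tensor([0.2, 0.1, -3.0])
score = (weight.detach() * grad.detach()).pow(2)
print(score)  # tensor([0.0100, 0.0100, 0.0900])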
@@ -10,7 +10,7 @@ from torch import autograd, Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer, Adam

-from nni.algorithms.compression.v2.pytorch.base.compressor import Compressor, _setattr, LayerInfo
+from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
 from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
 from nni.common.serializer import Traceable

@@ -25,7 +25,7 @@ from .tools import (
 _logger = logging.getLogger(__name__)

-class PrunerScoredModuleWrapper(Module):
+class PrunerScoredModuleWrapper(PrunerModuleWrapper):
     """
     Wrap a module to enable data parallel, forward method customization and buffer registration.
     Different from `PrunerModuleWrapper`, `PrunerScoredModuleWrapper` will record the gradient.

@@ -38,56 +38,12 @@ class PrunerScoredModuleWrapper(Module):
         The configurations that users specify for compression.
     module_name
         The name of the module to compress, wrapper module shares same name.
-    pruner
-        The pruner used to calculate mask.
     """

-    def __init__(self, module: Module, module_name: str, config: Dict, pruner: Compressor):
-        super().__init__()
-        # origin layer information
-        self.module = module
-        self.name = module_name
-        # config and pruner
-        self.config = config
-        self.pruner = pruner
-        self.weight = Parameter(torch.empty(self.module.weight.size()))
+    def __init__(self, module: Module, module_name: str, config: Dict):
+        super().__init__(module, module_name, config)
         self.weight_score = Parameter(torch.empty(self.weight.size()))
         torch.nn.init.constant_(self.weight_score, val=0.0)
-        # register buffer for mask
-        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
-            self.bias = Parameter(torch.empty(self.module.bias.size()))
-        else:
-            self.register_buffer("bias_mask", None)
-
-    def _weight2buffer(self):
-        """
-        When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
-        The best place to call this function is in `Pruner._wrap_model()`.
-        """
-        self.weight.data = self.module.weight.data
-        delattr(self.module, 'weight')
-        self.module.register_buffer('weight', self.weight.data)
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.bias.data = self.module.bias.data
-            delattr(self.module, 'bias')
-            self.module.register_buffer('bias', self.bias.data)
-
-    def _weight2parameter(self):
-        """
-        When don't need to record score or need to export the model, call `_weight2parameter()` to make the original weight trainable.
-        The best place to call this function is in `Pruner._unwrap_model()`.
-        """
-        delattr(self.module, 'weight')
-        self.module.weight = Parameter(torch.empty(self.weight.size()))
-        self.module.weight.data = torch.mul(self.weight, self.weight_mask)
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            delattr(self.module, 'bias')
-            self.module.bias = Parameter(torch.empty(self.bias.size()))
-            self.module.bias.data = torch.mul(self.bias, self.bias_mask)
-
     def forward(self, *inputs):
         # apply mask to weight, bias
         self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))
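The forward above depends on `_StraightThrough`, a straight-through estimator. The sketch below shows the usual shape of such a function (an assumption about the general pattern, not necessarily NNI's exact implementation; `StraightThroughSketch` is a made-up name): the forward pass uses the mask, while the backward pass routes the incoming gradient to the score, which is how `weight_score` stays learnable even though masking itself is non-differentiable.

import torch
from torch import autograd

class StraightThroughSketch(autograd.Function):
    @staticmethod
    def forward(ctx, score, mask):
        # the (binary) mask is what the forward computation actually uses
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # pass the gradient straight through to the score; none for the mask
        return grad_output, None

score = torch.zeros(3, requires_grad=True)
mask = torch.tensor([1., 0., 1.])
out = (StraightThroughSketch.apply(score, mask) * torch.arange(3.)).sum()
out.backward()
print(score.grad)  # tensor([0., 1., 2.]) -- the gradient reaches the score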
@@ -259,28 +215,6 @@ class MovementPruner(BasicPruner):
         else:
             self.data_collector.reset()

-    def _wrap_model(self):
-        """
-        Wrap all modules that needed to be compressed.
-        Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
-        """
-        if not self.is_wrapped:
-            for _, wrapper in reversed(self.get_modules_wrapper().items()):
-                _setattr(self.bound_model, wrapper.name, wrapper)
-                wrapper._weight2buffer()
-            self.is_wrapped = True
-
-    def _unwrap_model(self):
-        """
-        Unwrap all modules that needed to be compressed.
-        Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
-        """
-        if self.is_wrapped:
-            for _, wrapper in self.get_modules_wrapper().items():
-                _setattr(self.bound_model, wrapper.name, wrapper.module)
-                wrapper._weight2parameter()
-            self.is_wrapped = False
-
     def _wrap_modules(self, layer: LayerInfo, config: Dict):
         """
         Create a wrapper module to replace the original one.
@@ -294,21 +228,12 @@ class MovementPruner(BasicPruner):
             The configuration for generating the mask.
         """
         _logger.debug("Module detected to compress : %s.", layer.name)
-        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config, self)
+        wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
         assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
         # move newly registered buffers to the same device of weight
         wrapper.to(layer.module.weight.device)
         return wrapper

-    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
-        if self.is_wrapped:
-            self._unwrap_model()
-            parameter_name_map = {name: name for name, _ in self.bound_model.named_parameters()}
-            self._wrap_model()
-            return parameter_name_map
-        else:
-            raise Exception('When only the model is wrapped can get the parameter_name_map.')
-
     def compress(self) -> Tuple[Module, Dict]:
         # sparsity grow from 0
         for _, wrapper in self.get_modules_wrapper().items():
...
@@ -384,7 +384,7 @@ class SparsityAllocator:
             weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)

         wrapper = self.pruner.get_modules_wrapper()[name]
-        weight_size = wrapper.module.weight.data.size()
+        weight_size = wrapper.weight.data.size()

         if self.dim is None:
             assert weight_mask.size() == weight_size
...
@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):

     def collect(self) -> Dict[str, Tensor]:
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data
+            data[wrapper.name] = wrapper.weight.data
         return data

@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data
+            data[wrapper.name] = wrapper.weight.data
         return data
...
@@ -132,7 +132,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
             if self.continuous_mask:
                 metric = metric * self._compress_mask(wrapper.weight_mask)
-            layer_weight_num = wrapper.module.weight.data.numel()
+            layer_weight_num = wrapper.weight.data.numel()
             total_weight_num += layer_weight_num
             expend_times = int(layer_weight_num / metric.numel())
...
@@ -83,17 +83,17 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test WeightDataCollector
         data_collector = WeightDataCollector(pruner)
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])

         # Test WeightTrainerBasedDataCollector
         def opt_after():
-            model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
-            model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)
+            model.conv1.weight.data = torch.ones(5, 1, 5, 5)
+            model.conv2.weight.data = torch.ones(10, 5, 5, 5)

         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
         data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
         data = data_collector.collect()
-        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
+        assert all(torch.equal(get_module_by_name(model, module_name)[1].weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
         assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())

         # Test SingleHookTrainerBasedDataCollector
@@ -102,7 +102,7 @@ class PruningToolsTestCase(unittest.TestCase):
                 if len(buffer) < 2:
                     buffer.append(grad.clone().detach())
             return collect_taylor
-        hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
+        hook_targets = {'conv1': model.conv1.weight, 'conv2': model.conv2.weight}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
         optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
...
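The `'tensor'` hook type passed to `HookCollectorInfo` above collects gradients of the target weight tensors; this kind of collection is typically built on PyTorch's `Tensor.register_hook`, as the minimal PyTorch-only sketch below shows (an illustration of the mechanism, not NNI's internal code).

import torch

w = torch.randn(3, requires_grad=True)
collected = []
w.register_hook(lambda grad: collected.append(grad.clone()))  # fires during backward
(w * 2).sum().backward()
print(collected[0])  # tensor([2., 2., 2.])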
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class PrunerTestCase(unittest.TestCase):
    def test_pruner_module_wrapper(self):
        model = TorchModel()
        conv1_weight = model.conv1.weight.data.clone()
        conv2_weight = model.conv2.weight.data.clone()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = L1NormPruner(model, config_list)
        _, masks = pruner.compress()
        model(torch.rand(10, 1, 28, 28))
        # the trainable weight on the wrapper keeps the original values
        assert torch.equal(model.conv1.weight.data, conv1_weight)
        assert torch.equal(model.conv2.weight.data, conv2_weight)
        # the wrapped module holds the masked weight used in forward
        assert torch.equal(model.conv1.module.weight.data, conv1_weight * masks['conv1']['weight'])
        assert torch.equal(model.conv2.module.weight.data, conv2_weight * masks['conv2']['weight'])


if __name__ == '__main__':
    unittest.main()