Unverified commit a911b856, authored by Yuge Zhang, committed by GitHub

Resolve conflicts for #4760 (#4762)

parent 14d2966b
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
-from .apply_compression import apply_compression_results
+from nni.algorithms.compression.v2.pytorch.pruning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum, EnumMeta
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from nni.common.version import TORCH_VERSION
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
\ No newline at end of file
@@ -10,7 +10,7 @@ class BaseModelSpeedup:
    Parameters
    ----------
    model : pytorch model
-        The model to speed up by quantization.
+        The model to speedup by quantization.
    config : dict
        Config recording bit number and name of layers.
    """
......
@@ -37,7 +37,7 @@ def _setattr(model, name, module):
    Parameters
    ----------
    model : pytorch model
-        The model to speed up by quantization
+        The model to speedup by quantization
    name : str
        name of pytorch module
    module : torch.nn.Module
@@ -98,7 +98,7 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
    Parameters
    ----------
    model : pytorch model
-        The model to speed up by quantization
+        The model to speedup by quantization
    config : dict
        Config recording bits number and name of layers
    input_shape : tuple
......
@@ -228,40 +228,41 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
    return engine
class ModelSpeedupTensorRT(BaseModelSpeedup):
r"""
Parameters
----------
model : pytorch model
The model to speedup by quantization.
input_shape : tuple
The input shape of model, shall pass it to torch.onnx.export.
config : dict
Config recording bits number and name of layers.
onnx_path : str
The path user want to store onnx model which is converted from pytorch model.
extra_layer_bits : int
Other layers which are not in config will be quantized to corresponding bits number.
strict_datatype : bool
Whether constrain layer bits to the number given in config or not. If true, all the layer
will be set to given bits strictly. Otherwise, these layers will be set automatically by
tensorrt.
calibrate_type : tensorrt.tensorrt.CalibrationAlgoType
The algorithm of calibrating. Please refer to https://docs.nvidia.com/deeplearning/
tensorrt/api/python_api/infer/Int8/Calibrator.html for detail
calibrate_data : numpy array
The data using to calibrate quantization model
calibration_cache : str
The path user want to store calibrate cache file
batchsize : int
The batch size of calibration and inference
input_names : list
Input name of onnx model providing for torch.onnx.export to generate onnx model
output_name : list
Output name of onnx model providing for torch.onnx.export to generate onnx model
"""
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
                 input_names=["actual_input_1"], output_names=["output1"]):
"""
Parameters
----------
model : pytorch model
The model to speed up by quantization.
input_shape : tuple
The input shape of model, shall pass it to torch.onnx.export.
config : dict
Config recording bits number and name of layers.
onnx_path : str
The path user want to store onnx model which is converted from pytorch model.
extra_layer_bits : int
Other layers which are not in config will be quantized to corresponding bits number.
strict_datatype : bool
Whether constrain layer bits to the number given in config or not. If true, all the layer
will be set to given bits strictly. Otherwise, these layers will be set automatically by
tensorrt.
calibrate_type : tensorrt.tensorrt.CalibrationAlgoType
The algorithm of calibrating. Please refer to https://docs.nvidia.com/deeplearning/
tensorrt/api/python_api/infer/Int8/Calibrator.html for detail
calibrate_data : numpy array
The data using to calibrate quantization model
calibration_cache : str
The path user want to store calibrate cache file
batchsize : int
The batch size of calibration and inference
input_names : list
Input name of onnx model providing for torch.onnx.export to generate onnx model
output_name : list
Output name of onnx model providing for torch.onnx.export to generate onnx model
"""
        super().__init__(model, config)
        self.model = model
        self.onnx_path = onnx_path
......
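For reference, a minimal usage sketch of the class above, following NNI's quantization speedup documentation (requires TensorRT; the model and bit-width config here are placeholders):

```python
import torch
import torchvision.models as models
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = models.resnet18().eval()
config = {'conv1': 8, 'fc': 8}  # hypothetical layer-name -> bit-width config

engine = ModelSpeedupTensorRT(model, input_shape=(32, 3, 224, 224),
                              config=config, batchsize=32)
engine.compress()  # exports to onnx and builds the TensorRT engine
output, latency = engine.inference(torch.randn(32, 3, 224, 224))
```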
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import ModelSpeedup
\ No newline at end of file
@@ -16,6 +16,7 @@ replace_module = {
    'MaxPool2d': lambda module, masks: no_replace(module, masks),
    'AvgPool2d': lambda module, masks: no_replace(module, masks),
    'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
+   'ZeroPad2d': lambda module, masks: no_replace(module, masks),
    'ReLU': lambda module, masks: no_replace(module, masks),
    'ReLU6': lambda module, masks: no_replace(module, masks),
    'LeakyReLU': lambda module, masks: no_replace(module, masks),
@@ -41,7 +42,8 @@ replace_module = {
    'Dropout3d': lambda module, masks: no_replace(module, masks),
    'Upsample': lambda module, masks: no_replace(module, masks),
    'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
-   'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks)
+   'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
+   'Flatten': lambda module, masks: no_replace(module, masks)
}
......
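The new 'ZeroPad2d' and 'Flatten' entries route through no_replace because those modules hold no weights tied to channel counts and simply pass pruned shapes through. A minimal sketch of what such a hook amounts to (the real helper lives in NNI's compress_modules):

```python
def no_replace(module, masks):
    # Shape-agnostic modules (ReLU, pooling, ZeroPad2d, Flatten, ...) need no
    # structural change after pruning; return them untouched.
    return module
```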
@@ -29,7 +29,7 @@ class ModelSpeedup:
    Parameters
    ----------
    model : pytorch model
-        The model user wants to speed up
+        The model user wants to speedup
    dummy_input : pytorch tensor, tuple of tensor, list of tensor
        Note: The first dimension of the dummy_input should be the batchsize.
        The dummy input for ```jit.trace```, users should put it on the right
@@ -388,6 +388,9 @@ class ModelSpeedup:
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.
+
+        Parameters
+        ----------
        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
@@ -496,7 +499,7 @@ class ModelSpeedup:
        second, replace modules.
        """
-        _logger.info("start to speed up the model")
+        _logger.info("start to speedup the model")
        self.initialize_speedup()
        training = self.bound_model.training
        # set to the evaluation mode
......
@@ -171,10 +171,14 @@ class AutoMaskInference:
        # apply the input mask
        for tid, in_tensor in enumerate(self.dummy_input):
            if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
+                # in_tensor.data = in_tensor.data * \
+                #     self.in_masks[tid] + \
+                #     (1 - self.in_masks[tid]) * self.in_constants[tid]
+                # issue-4540: when two tensors are multiplied, the constant part makes
+                # the propagation weaker and leads to shape misalignment. Currently, we
+                # do not support constant folding, so we just remove the constant here.
                in_tensor.data = in_tensor.data * \
-                    self.in_masks[tid] + \
-                    (1-self.in_masks[tid]) * self.in_constants[tid]
+                    self.in_masks[tid]

    def __apply_weight_mask(self):
        """
......
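A toy illustration of why the constant term was dropped (values are made up; only the masking pattern matters):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])      # pruned positions are zero
constant = torch.tensor([0.0, 5.0, 0.0, 5.0])  # hypothetical folded constants

old_style = x * mask + (1 - mask) * constant   # [1., 5., 3., 5.]
new_style = x * mask                           # [1., 0., 3., 0.]

# With the old style, pruned slots carry constant values; multiplying two such
# tensors mixes constants into positions that should read as pruned, which
# weakened sparsity propagation (issue #4540). Keeping pruned slots at exactly
# zero avoids the shape misalignment downstream.
print(old_style, new_style)
```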
@@ -10,16 +10,31 @@ import torch
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

+# to exclude partial
+__all__ = [
+    'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'contiguous_python',
+    'div_python', 'dropout_python', 'exp_python', 'flatten_python', 'floor_div_python', 'gelu_python',
+    'getattr_python', 'jit_to_python_function', 'matmul_python', 'mean_python',
+    'mul_python', 'num2tensor_python', 'parse_constant', 'permute_python', 'relu_inplace_python',
+    'relu_python', 'reshape_python', 'select_python', 'sigmoid_python', 'size_python', 'slice_python',
+    'softmax_python', 'squeeze_python', 'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python',
+    'translate_list', 'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python',
+    'unsqueeze_python', 'upsample_bilinear2d_python', 'view_python'
+]
def translate_list(list_node, speedup=None):
    """
    Get the list of values from the list construct node.
+
    Parameters
-    ---------
+    ----------
    list_node: Torch.C.Value
        The cpp node of the target list.
    speedup: ModelSpeedup
        The ModelSpeedup object.
+
    Returns
    -------
    values: list
@@ -45,12 +60,14 @@ def translate_list(list_node, speedup=None):
def parse_constant(cvalue, speedup):
    """
    Parse the constant values from this Node
+
    Parameters
    ----------
    cvalue: Torch.C.Value
        The cpp node of the target constant value.
    speedup: ModelSpeedup
        The Model speedup module.
+
    Returns
    -------
    value: int/float/tensor
@@ -125,6 +142,29 @@ def add_python(node, speedup):
    return new_add

+def sub_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    constant = [None, None]
+    for i in range(2):
+        input_i = inputs[i]
+        debug_name = input_i.debugName()
+        if debug_name not in speedup.internal_result:
+            # this input is a constant value
+            # TODO: what if this input is a constant tensor
+            if input_i.toIValue() is not None:
+                constant[i] = parse_constant(input_i, speedup)
+                break
+    if constant[0] is None and constant[1] is None:
+        new_sub = torch.sub
+    elif constant[0] is not None:
+        # the minuend is the constant: bind it as the first positional argument
+        new_sub = partial(torch.sub, constant[0])
+    else:
+        # the subtrahend is the constant: bind it as the keyword `other`
+        new_sub = partial(torch.sub, other=constant[1])
+    return new_sub
def floor_div_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
@@ -211,6 +251,10 @@ def gelu_python(node, speedup):
    return torch.nn.GELU()

+def silu_python(node, speedup):
+    return torch.nn.SiLU()

def avgpool2d_python(node, speedup):
    c_node = node.key_node
    inputs = list(c_node.inputs())
@@ -260,6 +304,14 @@ def unsqueeze_python(node, speedup):
    new_unsqueeze = partial(torch.unsqueeze, dim=dim)
    return new_unsqueeze

+def constant_pad_nd_python(node, speedup):
+    c_node = node.key_node
+    inputs = list(c_node.inputs())
+    pad = translate_list(inputs[1], speedup)
+    value = parse_constant(inputs[2], speedup)
+    new_constant_pad_nd = partial(torch.nn.functional.pad, pad=pad, value=value)
+    return new_constant_pad_nd
##########################################################
# Split Line
# Following module/functions cannot be translated into a
@@ -362,7 +414,7 @@ def reshape_python(node, speedup):
            logger.info('Reshape Module output size: %s', str(self.shape))

        def forward(self, *args):
-            return args[0].view(self.shape)
+            return args[0].reshape(self.shape)

    c_node = node.key_node
    inputs = list(c_node.inputs())
    shape = translate_list(inputs[1], speedup)
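The view-to-reshape swap matters because view requires contiguous storage, while reshape falls back to a copy when a view is impossible; tensors reaching this point after mask propagation are not guaranteed contiguous. A standalone PyTorch illustration:

```python
import torch

x = torch.randn(2, 3).t()   # transposing makes the tensor non-contiguous
try:
    x.view(6)               # view cannot reinterpret non-contiguous storage
except RuntimeError as e:
    print('view failed:', e)

y = x.reshape(6)            # reshape copies when a view is impossible
print(y.shape)              # torch.Size([6])
```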
@@ -488,6 +540,8 @@ def cat_python(node, speedup):
trans_from_jit_to_python = {
    'aten::add': add_python,
    'aten::add_': add_python,
+   'aten::sub': sub_python,
+   'aten::sub_': sub_python,
    'aten::mul': mul_python,
    'aten::mul_': mul_python,
    'aten::relu': relu_python,
@@ -525,6 +579,8 @@ trans_from_jit_to_python = {
    'aten::exp': exp_python,
    'aten::squeeze': squeeze_python,
    'aten::unsqueeze': unsqueeze_python,
+   'aten::constant_pad_nd': constant_pad_nd_python,
+   'aten::silu': silu_python,
    'prim::TupleUnpack': tupleunpack_python,
    'prim::ListUnpack': tupleunpack_python,
    'prim::NumToTensor': num2tensor_python,
......
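For orientation, this table is consumed by jit_to_python_function, which looks up the jit op type and calls the matching factory. A simplified sketch under that assumption (the real function also handles module nodes and may log unsupported ops rather than raise):

```python
def jit_to_python_function_sketch(node, speedup):
    # node.op_type is a jit op name such as 'aten::sub' or 'aten::silu'.
    if node.op_type not in trans_from_jit_to_python:
        raise RuntimeError(f'op {node.op_type} is not supported yet')
    # Each entry returns a plain Python callable equivalent to the jit node.
    return trans_from_jit_to_python[node.op_type](node, speedup)
```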
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+from .counter import count_flops_params
+from .mask_conflict import ChannelMaskConflict, GroupMaskConflict
from .utils import *
+from .sensitivity_analysis import SensitivityAnalysis
from .shape_dependency import *
+from .shape_dependency import ReshapeDependency
def not_safe_to_prune(model, dummy_input):
    """
......
@@ -81,23 +81,23 @@ class MaskFix:
class GroupMaskConflict(MaskFix):
+    """
+    GroupMaskConflict fixes the mask conflict between layers that
+    have group dependency with each other.
+
+    Parameters
+    ----------
+    masks : dict
+        a dict object that stores the masks
+    model : torch.nn.Module
+        model to fix the mask conflict
+    dummy_input : torch.Tensor
+        input example to trace the model
+    traced : torch._C.torch.jit.TopLevelTracedModule
+        the traced model of the target model; if this parameter is not None,
+        we do not use the model and dummy_input to get the trace graph.
+    """
+
    def __init__(self, masks, model, dummy_input, traced=None):
-        """
-        GroupMaskConflict fix the mask conflict between the layers that
-        has group dependecy with each other.
-        Parameters
-        ----------
-        masks : dict
-            a dict object that stores the masks
-        model : torch.nn.Module
-            model to fix the mask conflict
-        dummy_input : torch.Tensor
-            input example to trace the model
-        traced : torch._C.torch.jit.TopLevelTracedModule
-            the traced model of the target model, is this parameter is not None,
-            we donnot use the model and dummpy_input to get the trace graph.
-        """
        super(GroupMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
@@ -168,23 +168,24 @@ class GroupMaskConflict(MaskFix):
class ChannelMaskConflict(MaskFix):
+    """
+    ChannelMaskConflict fixes the mask conflict between layers that
+    have channel dependency with each other.
+
+    Parameters
+    ----------
+    masks : dict
+        a dict object that stores the masks
+    model : torch.nn.Module
+        model to fix the mask conflict
+    dummy_input : torch.Tensor
+        input example to trace the model
+    graph : torch._C.torch.jit.TopLevelTracedModule
+        the traced graph of the target model; if this parameter is not None,
+        we do not use the model and dummy_input to get the trace graph.
+    """
+
    def __init__(self, masks, model, dummy_input, traced=None):
-        """
-        ChannelMaskConflict fix the mask conflict between the layers that
-        has channel dependecy with each other.
-        Parameters
-        ----------
-        masks : dict
-            a dict object that stores the masks
-        model : torch.nn.Module
-            model to fix the mask conflict
-        dummy_input : torch.Tensor
-            input example to trace the model
-        graph : torch._C.torch.jit.TopLevelTracedModule
-            the traced graph of the target model, is this parameter is not None,
-            we donnot use the model and dummpy_input to get the trace graph.
-        """
        super(ChannelMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
......
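A usage sketch of the two fixers above, mirroring the order in which NNI's fix_mask_conflict helper applies them (the masks file and model are placeholders):

```python
import torch
import torchvision.models as models
from nni.compression.pytorch.utils.mask_conflict import (
    GroupMaskConflict, ChannelMaskConflict)

model = models.resnet18()
masks = torch.load('mask.pth')  # placeholder: {layer_name: {'weight': mask}}
dummy_input = torch.ones(1, 3, 224, 224)

# Resolve group constraints first, then align channel masks across
# dependent layers.
masks = GroupMaskConflict(masks, model, dummy_input).fix_mask()
masks = ChannelMaskConflict(masks, model, dummy_input).fix_mask()
```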
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def get_total_num_weights(model, op_types=['default']):
    '''
    calculate the total number of weights
......
@@ -18,51 +18,52 @@ logger.setLevel(logging.INFO)
class SensitivityAnalysis:
-    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
    """
    Perform sensitivity analysis for this model.

    Parameters
    ----------
    model : torch.nn.Module
        the model to perform sensitivity analysis
    val_func : function
        validation function for the model. Since different models may need
        different datasets/criteria, the user needs to cover this part themselves.
        In the val_func, the model should be tested on the validation dataset,
        and the validation accuracy/loss should be returned as the output of val_func.
        There are no restrictions on the input parameters of the val_func.
        Users can use the val_args, val_kwargs parameters in analysis
        to pass all the parameters that val_func needs.
    sparsities : list
        The sparsity list provided by users. This parameter is set when the user
        only wants to test some specific sparsities. In the sparsity list, each element
        is a sparsity value which means how much weight the pruner should prune. Take
        [0.25, 0.5, 0.75] for an example: SensitivityAnalysis will gradually prune
        25%, 50%, and 75% of the weights for each layer.
    prune_type : str
        The pruner type used to prune the conv layers. The default is 'l1';
        'l2' and 'fine-grained' are also supported.
    early_stop_mode : str
        If this flag is set, the sensitivity analysis
        for a conv layer will early stop when the validation metric (
        for example, accuracy/loss) has already met the threshold. We
        support four different early stop modes: minimize, maximize, dropped,
        raised. The default value is None, which means the analysis won't stop
        until all given sparsities are tested. This option should be used with
        early_stop_value together.

        minimize: The analysis stops when the validation metric returned by the val_func
        is lower than early_stop_value.
        maximize: The analysis stops when the validation metric returned by the val_func
        is larger than early_stop_value.
        dropped: The analysis stops when the validation metric has dropped by early_stop_value.
        raised: The analysis stops when the validation metric has risen by early_stop_value.
    early_stop_value : float
        This value is used as the threshold for different early stop modes.
        This value is effective only when the early_stop_mode is set.
    """
+    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT

        self.model = model
......
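A usage sketch matching the docstring above; the model and the validation metric are placeholders:

```python
import torch
import torchvision.models as models
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

net = models.resnet18(num_classes=10)

def val_func(model):
    # Placeholder metric: in practice, evaluate on a real validation set
    # and return accuracy (or loss, depending on early_stop_mode).
    with torch.no_grad():
        out = model(torch.randn(8, 3, 32, 32))
    return float(out.softmax(1).max(1).values.mean())

analyzer = SensitivityAnalysis(model=net, val_func=val_func,
                               sparsities=[0.25, 0.5, 0.75],
                               early_stop_mode='dropped', early_stop_value=0.05)
# Returns a nested dict like {layer_name: {sparsity: metric}}.
sensitivity = analyzer.analysis(val_args=[net])
```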
@@ -10,7 +10,7 @@ from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper as Pr
from .utils import get_module_by_name

-__all__ = ['ChannelDependency', 'GroupDependency',
+__all__ = ['ChannelDependency', 'GroupDependency', 'ReshapeDependency',
           'InputChannelDependency', 'AttentionWeightDependency']
@@ -91,24 +91,26 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
+    """
+    This class analyzes the channel dependencies between the conv
+    layers in a model.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model to be analyzed.
+    data : torch.Tensor
+        The example input data to trace the network architecture.
+    traced_model : torch._C.Graph
+        if we already have the traced graph of the target model, we do not
+        need to trace the model again.
+    prune_type : str
+        This parameter indicates the channel pruning type: 1) `Filter`
+        prune the filter of the convolution layer to prune the corresponding
+        channels; 2) `Batchnorm`: prune the channel in the batchnorm layer.
+    """
+
    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
-        """
-        This model analyze the channel dependencies between the conv
-        layers in a model.
-        Parameters
-        ----------
-        model : torch.nn.Module
-            The model to be analyzed.
-        data : torch.Tensor
-            The example input data to trace the network architecture.
-        traced_model : torch._C.Graph
-            if we alreay has the traced graph of the target model, we donnot
-            need to trace the model again.
-        prune_type: str
-            This parameter indicates the channel pruning type: 1) `Filter`
-            prune the filter of the convolution layer to prune the corresponding
-            channels 2) `Batchnorm`: prune the channel in the batchnorm layer
-        """
        self.prune_type = prune_type
        self.target_types = []
        if self.prune_type == 'Filter':
@@ -163,7 +165,13 @@ class ChannelDependency(Dependency):
        parent_layers = []
        # find the node that contains aten::add
        # or aten::cat operations
-        if node.op_type in ADD_TYPES:
+        if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
+            # refer to issue 4540 for more details. Multiplication actually
+            # will not introduce the channel dependency, because the misaligned
+            # channels can propagate to each other. However, when one of the input
+            # tensors comes from a skip connection (residual), the channel propagation
+            # may fail (the input is also used by another layer and cannot be
+            # pruned); in this case, we need to fix the conflict manually.
            parent_layers = self._get_parent_layers(node)
        elif node.op_type == CAT_TYPE:
            # To determine if this cat operation will introduce channel
@@ -271,6 +279,7 @@ class InputChannelDependency(ChannelDependency):
    """
    This class analyzes the input channel dependencies between the conv
    layers in a model.
+
    Parameters
    ----------
    model : torch.nn.Module
@@ -329,20 +338,22 @@ class InputChannelDependency(ChannelDependency):
class GroupDependency(Dependency):
+    """
+    This class analyzes the group dependencies between the conv
+    layers in a model.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The model to be analyzed.
+    data : torch.Tensor
+        The example input data to trace the network architecture.
+    traced_model : torch._C.Graph
+        if we already have the traced graph of the target model, we do not
+        need to trace the model again.
+    """
+
    def __init__(self, model, dummy_input, traced_model=None):
-        """
-        This model analyze the group dependencis between the conv
-        layers in a model.
-        Parameters
-        ----------
-        model : torch.nn.Module
-            The model to be analyzed.
-        data : torch.Tensor
-            The example input data to trace the network architecture.
-        traced_model : torch._C.Graph
-            if we alreay has the traced graph of the target model, we donnot
-            need to trace the model again.
-        """
        self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
......
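Finally, a sketch of how the dependency analysis above is typically driven, following the NNI docs (model and input shape are placeholders):

```python
import torch
import torchvision.models as models
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

net = models.resnet18()
dummy_input = torch.ones(1, 3, 224, 224)

# Sets of conv layers whose output channels must be pruned together,
# e.g. the two inputs of a residual add in resnet18.
dep = ChannelDependency(net, dummy_input)
print(dep.dependency_sets)
dep.export('dependency.csv')  # one dependency set per line
```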