Unverified Commit a911b856 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Resolve conflicts for #4760 (#4762)

parent 14d2966b
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .apply_compression import apply_compression_results
from nni.algorithms.compression.v2.pytorch.pruning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum, EnumMeta
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from torch.quantization import default_weight_observer, default_histogram_observer
from torch.quantization import RecordingObserver as _RecordingObserver
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional
from .literal import QuantDtype, QuantType, QuantScheme
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from nni.common.version import TORCH_VERSION
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .integrated_tensorrt import CalibrateType, ModelSpeedupTensorRT
\ No newline at end of file
......@@ -10,7 +10,7 @@ class BaseModelSpeedup:
Parameters
----------
model : pytorch model
The model to speed up by quantization.
The model to speedup by quantization.
config : dict
Config recording bit number and name of layers.
"""
......
......@@ -37,7 +37,7 @@ def _setattr(model, name, module):
Parameters
----------
model : pytorch model
The model to speed up by quantization
The model to speedup by quantization
name : str
name of pytorch module
module : torch.nn.Module
......@@ -98,7 +98,7 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na
Parameters
----------
model : pytorch model
The model to speed up by quantization
The model to speedup by quantization
config : dict
Config recording bits number and name of layers
input_shape : tuple
......
......@@ -228,14 +228,11 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
return engine
class ModelSpeedupTensorRT(BaseModelSpeedup):
def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1,
input_names=["actual_input_1"], output_names=["output1"]):
"""
r"""
Parameters
----------
model : pytorch model
The model to speed up by quantization.
The model to speedup by quantization.
input_shape : tuple
The input shape of model, shall pass it to torch.onnx.export.
config : dict
......@@ -262,6 +259,10 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
output_name : list
Output name of onnx model providing for torch.onnx.export to generate onnx model
"""
def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1,
input_names=["actual_input_1"], output_names=["output1"]):
super().__init__(model, config)
self.model = model
self.onnx_path = onnx_path
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .compressor import ModelSpeedup
\ No newline at end of file
......@@ -16,6 +16,7 @@ replace_module = {
'MaxPool2d': lambda module, masks: no_replace(module, masks),
'AvgPool2d': lambda module, masks: no_replace(module, masks),
'AdaptiveAvgPool2d': lambda module, masks: no_replace(module, masks),
'ZeroPad2d': lambda module, masks: no_replace(module, masks),
'ReLU': lambda module, masks: no_replace(module, masks),
'ReLU6': lambda module, masks: no_replace(module, masks),
'LeakyReLU': lambda module, masks: no_replace(module, masks),
......@@ -41,7 +42,8 @@ replace_module = {
'Dropout3d': lambda module, masks: no_replace(module, masks),
'Upsample': lambda module, masks: no_replace(module, masks),
'LayerNorm': lambda module, masks: replace_layernorm(module, masks),
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks)
'ConvTranspose2d': lambda module, masks: replace_convtranspose2d(module, masks),
'Flatten': lambda module, masks: no_replace(module, masks)
}
......
......@@ -29,7 +29,7 @@ class ModelSpeedup:
Parameters
----------
model : pytorch model
The model user wants to speed up
The model user wants to speedup
dummy_input : pytorch tensor, tuple of tensor, list of tensor
Note: The first dimension of the dummy_input should be the batchsize.
The dummy input for ```jit.trace```, users should put it on the right
......@@ -388,6 +388,9 @@ class ModelSpeedup:
def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
"""
Replace the submodule according to the inferred sparsity.
Parameters
----------
unique_name: str
The unique_name of the submodule to replace.
reindex_dim: int
......@@ -496,7 +499,7 @@ class ModelSpeedup:
second, replace modules.
"""
_logger.info("start to speed up the model")
_logger.info("start to speedup the model")
self.initialize_speedup()
training = self.bound_model.training
# set to the evaluation mode
......
......@@ -171,10 +171,14 @@ class AutoMaskInference:
# apply the input mask
for tid, in_tensor in enumerate(self.dummy_input):
if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
# in_tensor.data = in_tensor.data * \
# self.in_masks[tid] + \
# (1-self.in_masks[tid]) * self.in_constants[tid]
# issue-4540: when two tensors are multiplied, the constants part makes
# the propagation weaker, and leads to shape misalignment. Currently, we
# do not support constant folding, so we just remove the constant here
in_tensor.data = in_tensor.data * \
self.in_masks[tid] + \
(1-self.in_masks[tid]) * self.in_constants[tid]
self.in_masks[tid]
def __apply_weight_mask(self):
"""
......
......@@ -10,16 +10,31 @@ import torch
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# to exclude partial
# NOTE: keep this list in sync with the translation functions defined below;
# sub_python, silu_python and constant_pad_nd_python were added to the module
# but were missing here, so `from ... import *` silently dropped them.
__all__ = [
    'adaptive_avgpool_python', 'add_python', 'avgpool2d_python', 'cat_python', 'constant_pad_nd_python',
    'contiguous_python', 'div_python', 'dropout_python', 'exp_python', 'flatten_python',
    'floor_div_python', 'gelu_python', 'getattr_python', 'jit_to_python_function', 'matmul_python',
    'mean_python', 'mul_python', 'num2tensor_python', 'parse_constant', 'permute_python',
    'relu_inplace_python', 'relu_python', 'reshape_python', 'select_python', 'sigmoid_python',
    'silu_python', 'size_python', 'slice_python', 'softmax_python', 'squeeze_python', 'sub_python',
    'to_python', 'toint_python', 'torch', 'trans_from_jit_to_python', 'translate_list',
    'transpose2_python', 'transpose_python', 'tupleunpack_python', 'typeas_python',
    'unsqueeze_python', 'upsample_bilinear2d_python', 'view_python'
]
def translate_list(list_node, speedup=None):
"""
Get the list of values from the list construct node.
Parameters
---------
----------
list_node: Torch.C.Value
The cpp node of the target list.
speedup: ModuleSpeed
The Module speedup module.
Returns
-------
values: list
......@@ -45,12 +60,14 @@ def translate_list(list_node, speedup=None):
def parse_constant(cvalue, speedup):
"""
Parse the constant values from this Node
Parameters
----------
cvalue: Torch.C.Value
The cpp node of the target constant value.
speedup: ModelSpeedup
The Model speedup module.
Returns
-------
value: int/float/tensor
......@@ -125,6 +142,29 @@ def add_python(node, speedup):
return new_add
def sub_python(node, speedup):
    """
    Translate an ``aten::sub``/``aten::sub_`` jit node into a python callable.

    If one of the two operands is a constant (its debug name is not found in
    ``speedup.internal_result``), the constant is baked into the returned
    callable with ``functools.partial`` so the callable only takes the
    remaining tensor operand; otherwise plain ``torch.sub`` is returned.

    Parameters
    ----------
    node : NodePyGroup
        The node that wraps the aten::sub operation.
    speedup : ModelSpeedup
        The speedup object holding the intermediate tensor results.

    Returns
    -------
    callable
        A function performing the subtraction with any constant operand bound.
    """
    c_node = node.key_node
    inputs = list(c_node.inputs())
    constant = [None, None]
    for i in range(2):
        input_i = inputs[i]
        debug_name = input_i.debugName()
        if debug_name not in speedup.internal_result:
            # this input is a constant value
            # TODO: what if this input is a constant tensor
            if input_i.toIValue() is not None:
                constant[i] = parse_constant(input_i, speedup)
            break
    if constant[0] is None and constant[1] is None:
        new_sub = torch.sub
    elif constant[0] is not None:
        # First operand is constant. Bind it positionally: binding it as
        # ``input=`` would collide with the positional tensor passed at
        # call time (fix: the original bound the whole ``constant`` list
        # instead of ``constant[0]``).
        new_sub = partial(torch.sub, constant[0])
    else:
        # Second operand is constant; bind it as ``other`` so the runtime
        # tensor becomes the minuend (fix: was the whole ``constant`` list).
        new_sub = partial(torch.sub, other=constant[1])
    return new_sub
def floor_div_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
......@@ -211,6 +251,10 @@ def gelu_python(node, speedup):
return torch.nn.GELU()
def silu_python(node, speedup):
    """
    Translate an ``aten::silu`` jit node into an equivalent python module.

    Both parameters exist only to keep a uniform signature with the other
    translation helpers; neither is inspected, since SiLU needs no
    configuration from the graph.

    Returns
    -------
    torch.nn.SiLU
        A fresh SiLU activation module.
    """
    activation = torch.nn.SiLU()
    return activation
def avgpool2d_python(node, speedup):
c_node = node.key_node
inputs = list(c_node.inputs())
......@@ -260,6 +304,14 @@ def unsqueeze_python(node, speedup):
new_unsqueeze = partial(torch.unsqueeze, dim=dim)
return new_unsqueeze
def constant_pad_nd_python(node, speedup):
    """
    Translate an ``aten::constant_pad_nd`` jit node into a python callable.

    The padding sizes and the fill value are read from the node's constant
    inputs and bound into ``torch.nn.functional.pad``, so the returned
    callable only needs the input tensor.

    Returns
    -------
    callable
        ``torch.nn.functional.pad`` with ``pad`` and ``value`` pre-bound.
    """
    cpp_node = node.key_node
    node_inputs = list(cpp_node.inputs())
    padding = translate_list(node_inputs[1], speedup)
    fill_value = parse_constant(node_inputs[2], speedup)
    return partial(torch.nn.functional.pad, pad=padding, value=fill_value)
##########################################################
# Split Line
# Following module/functions cannot be translated into a
......@@ -362,7 +414,7 @@ def reshape_python(node, speedup):
logger.info('Reshape Module output size: %s', str(self.shape))
def forward(self, *args):
return args[0].view(self.shape)
return args[0].reshape(self.shape)
c_node = node.key_node
inputs = list(c_node.inputs())
shape = translate_list(inputs[1], speedup)
......@@ -488,6 +540,8 @@ def cat_python(node, speedup):
trans_from_jit_to_python = {
'aten::add': add_python,
'aten::add_': add_python,
'aten::sub': sub_python,
'aten::sub_': sub_python,
'aten::mul': mul_python,
'aten::mul_': mul_python,
'aten::relu': relu_python,
......@@ -525,6 +579,8 @@ trans_from_jit_to_python = {
'aten::exp': exp_python,
'aten::squeeze': squeeze_python,
'aten::unsqueeze': unsqueeze_python,
'aten::constant_pad_nd': constant_pad_nd_python,
'aten::silu': silu_python,
'prim::TupleUnpack': tupleunpack_python,
'prim::ListUnpack': tupleunpack_python,
'prim::NumToTensor': num2tensor_python,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .counter import count_flops_params
from .mask_conflict import ChannelMaskConflict, GroupMaskConflict
from .utils import *
from .sensitivity_analysis import SensitivityAnalysis
from .shape_dependency import *
from .shape_dependency import ReshapeDependency
def not_safe_to_prune(model, dummy_input):
"""
......
......@@ -81,7 +81,6 @@ class MaskFix:
class GroupMaskConflict(MaskFix):
def __init__(self, masks, model, dummy_input, traced=None):
"""
GroupMaskConflict fixes the mask conflicts between the layers that
have group dependency with each other.
......@@ -98,6 +97,7 @@ class GroupMaskConflict(MaskFix):
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
def __init__(self, masks, model, dummy_input, traced=None):
super(GroupMaskConflict, self).__init__(
masks, model, dummy_input, traced)
......@@ -168,7 +168,6 @@ class GroupMaskConflict(MaskFix):
class ChannelMaskConflict(MaskFix):
def __init__(self, masks, model, dummy_input, traced=None):
"""
ChannelMaskConflict fixes the mask conflicts between the layers that
have channel dependency with each other.
......@@ -185,6 +184,8 @@ class ChannelMaskConflict(MaskFix):
the traced graph of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
def __init__(self, masks, model, dummy_input, traced=None):
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
def get_total_num_weights(model, op_types=['default']):
'''
calculate the total number of weights
......
......@@ -18,9 +18,9 @@ logger.setLevel(logging.INFO)
class SensitivityAnalysis:
def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
"""
Perform sensitivity analysis for this model.
Parameters
----------
model : torch.nn.Module
......@@ -61,8 +61,9 @@ class SensitivityAnalysis:
early_stop_value : float
This value is used as the threshold for different earlystop modes.
This value is effective only when the early_stop_mode is set.
"""
def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
self.model = model
......
......@@ -10,7 +10,7 @@ from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper as Pr
from .utils import get_module_by_name
__all__ = ['ChannelDependency', 'GroupDependency',
__all__ = ['ChannelDependency', 'GroupDependency', 'ReshapeDependency',
'InputChannelDependency', 'AttentionWeightDependency']
......@@ -91,10 +91,10 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
"""
This model analyzes the channel dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
......@@ -109,6 +109,8 @@ class ChannelDependency(Dependency):
prune the filter of the convolution layer to prune the corresponding
channels 2) `Batchnorm`: prune the channel in the batchnorm layer
"""
def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
self.prune_type = prune_type
self.target_types = []
if self.prune_type == 'Filter':
......@@ -163,7 +165,13 @@ class ChannelDependency(Dependency):
parent_layers = []
# find the node that contains aten::add
# or aten::cat operations
if node.op_type in ADD_TYPES:
if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
# refer to issue 4540 for more details. Multiplication actually
# will not introduce a channel dependency, because the misaligned
# channels can propagate to each other. However, when one of the input
# tensors comes from a skip connection (residual), the channel propagation
# may fail (the input is also used by another layer and cannot be
# pruned); in this case, we need to fix the conflict manually.
parent_layers = self._get_parent_layers(node)
elif node.op_type == CAT_TYPE:
# To determine if this cat operation will introduce channel
......@@ -271,6 +279,7 @@ class InputChannelDependency(ChannelDependency):
"""
This model analyze the input channel dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
......@@ -329,10 +338,10 @@ class InputChannelDependency(ChannelDependency):
class GroupDependency(Dependency):
def __init__(self, model, dummy_input, traced_model=None):
"""
This model analyzes the group dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
......@@ -343,6 +352,8 @@ class GroupDependency(Dependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
def __init__(self, model, dummy_input, traced_model=None):
self.min_groups = {}
super(GroupDependency, self).__init__(model, dummy_input, traced_model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment