Unverified Commit 22165cea authored by J-shang, committed by GitHub

[Doc] update compression reference (#4667)

parent de6662a4
@@ -631,6 +631,7 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize weight.
         This method is effectively hooked to :meth:`forward` of the model.
+
         Parameters
         ----------
         wrapper : QuantizerModuleWrapper
@@ -642,6 +643,7 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize output.
         This method is effectively hooked to :meth:`forward` of the model.
+
         Parameters
         ----------
         output : Tensor
@@ -655,6 +657,7 @@ class Quantizer(Compressor):
         """
         quantize should overload this method to quantize input.
         This method is effectively hooked to :meth:`forward` of the model.
+
         Parameters
         ----------
         inputs : Tensor
@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function):
     def _quantize(cls, x, scale, zero_point):
         """
         Reference function for quantizing x -- non-clamped.
+
         Parameters
         ----------
         x : Tensor
@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function):
             scale for quantizing x
         zero_point : Tensor
             zero_point for quantizing x
+
         Returns
         -------
         tensor
@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function):
     def get_bits_length(cls, config, quant_type):
         """
         Get bits for quantize config
+
         Parameters
         ----------
         config : Dict
             the configuration for quantization
         quant_type : str
             quant type
+
         Returns
         -------
         int
@@ -948,6 +955,7 @@ class QuantGrad(torch.autograd.Function):
         """
         This method should be overridden by subclasses to provide a customized backward function;
         the default implementation is the Straight-Through Estimator.
+
         Parameters
         ----------
         tensor : Tensor
@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function):
             quant_min for quantizing tensor
         qmax : Tensor
             quant_max for quantizing tensor
+
         Returns
         -------
         tensor
...
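For readers skimming this diff: the `_quantize` and `quant_backward` docstrings above describe a non-clamped reference quantization and a Straight-Through Estimator backward. Below is a minimal sketch of that pattern; the affine formula and the identity gradient are assumptions based on the docstring wording, not the exact NNI implementation.

```python
import torch

class RefQuantGrad(torch.autograd.Function):
    """Toy sketch: non-clamped affine quantization with an STE backward."""

    @staticmethod
    def forward(ctx, x, scale, zero_point):
        # Non-clamped reference quantization: map x onto the integer grid.
        return torch.round(x / scale) + zero_point

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-Through Estimator: pass the gradient through unchanged for x;
        # scale and zero_point receive no gradient in this sketch.
        return grad_output, None, None

x = torch.randn(4, requires_grad=True)
q = RefQuantGrad.apply(x, torch.tensor(0.1), torch.tensor(0.0))
q.sum().backward()   # x.grad is all ones under STE
```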
@@ -228,10 +228,7 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
     return engine

 class ModelSpeedupTensorRT(BaseModelSpeedup):
-    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
-                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1,
-                 input_names=["actual_input_1"], output_names=["output1"]):
-        """
+    r"""
     Parameters
     ----------
     model : pytorch model
@@ -262,6 +259,10 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
     output_name : list
         Output name of the onnx model, provided to torch.onnx.export to generate the onnx model
     """
+    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
+                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache = "calibration.cache", batchsize=1,
+                 input_names=["actual_input_1"], output_names=["output1"]):
         super().__init__(model, config)
         self.model = model
         self.onnx_path = onnx_path
...
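The relocated docstring above documents the `ModelSpeedupTensorRT` constructor. A usage sketch in the spirit of the NNI quantization speedup examples follows; the import path and the `compress()` / `inference()` calls are assumptions and may differ between NNI versions.

```python
import torch
import torchvision
# import path is an assumption; it may differ between NNI versions
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = torchvision.models.resnet18().eval()
input_shape = (32, 3, 224, 224)

# config would normally be the calibration config exported by a quantizer;
# None falls back to the extra_layer_bits / strict_datatype defaults documented above.
engine = ModelSpeedupTensorRT(model, input_shape, config=None, batchsize=32)
engine.compress()                                             # assumed API: export to ONNX and build the TensorRT engine
output, latency = engine.inference(torch.randn(input_shape))  # assumed API
```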
@@ -388,6 +388,9 @@ class ModelSpeedup:
     def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
         """
         Replace the submodule according to the inferred sparsity.
+
+        Parameters
+        ----------
         unique_name: str
             The unique_name of the submodule to replace.
         reindex_dim: int
...
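As a rough illustration of what the `reindex_dim` / `reindex` arguments refer to when a submodule is replaced, here is a toy re-indexing of a linear layer along dimension 0; this is not NNI's replacement logic, only the underlying idea.

```python
import torch
import torch.nn as nn

# Keep only the output channels whose mask survived pruning, i.e. re-index the
# weight and bias along dim 0 (the role reindex_dim plays in the docstring above).
full = nn.Linear(8, 4)
kept = torch.tensor([0, 2])            # indices inferred from the sparsity mask
compact = nn.Linear(8, len(kept))
with torch.no_grad():
    compact.weight.copy_(full.weight.index_select(0, kept))
    compact.bias.copy_(full.bias.index_select(0, kept))
```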
@@ -81,7 +81,6 @@ class MaskFix:
 class GroupMaskConflict(MaskFix):
-    def __init__(self, masks, model, dummy_input, traced=None):
     """
     GroupMaskConflict fixes the mask conflicts between layers that
     have group dependency with each other.
@@ -98,6 +97,7 @@ class GroupMaskConflict(MaskFix):
         the traced model of the target model; if this parameter is not None,
         we do not use the model and dummy_input to get the trace graph.
     """
+    def __init__(self, masks, model, dummy_input, traced=None):
         super(GroupMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
@@ -168,7 +168,6 @@ class GroupMaskConflict(MaskFix):
 class ChannelMaskConflict(MaskFix):
-    def __init__(self, masks, model, dummy_input, traced=None):
     """
     ChannelMaskConflict fixes the mask conflicts between layers that
     have channel dependency with each other.
@@ -185,6 +184,8 @@ class ChannelMaskConflict(MaskFix):
         the traced graph of the target model; if this parameter is not None,
         we do not use the model and dummy_input to get the trace graph.
     """
+    def __init__(self, masks, model, dummy_input, traced=None):
         super(ChannelMaskConflict, self).__init__(
             masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
...
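A usage sketch for the mask-conflict fixers documented above; the import path and the `fix_mask()` entry point are assumptions based on how the `MaskFix` utilities are typically driven, and may differ from the actual API.

```python
import torch
import torchvision
# import path is an assumption; it may differ between NNI versions
from nni.compression.pytorch.utils.mask_conflict import ChannelMaskConflict

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
# masks would normally come from a pruner; a single all-ones mask as a placeholder
masks = {'conv1': {'weight': torch.ones_like(model.conv1.weight)}}

fixer = ChannelMaskConflict(masks, model, dummy_input)
fixed_masks = fixer.fix_mask()   # assumed method: resolves channel-dependency conflicts between layers
```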
@@ -18,9 +18,9 @@ logger.setLevel(logging.INFO)
 class SensitivityAnalysis:
-    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
     """
     Perform sensitivity analysis for this model.
+
     Parameters
     ----------
     model : torch.nn.Module
@@ -61,8 +61,9 @@ class SensitivityAnalysis:
     early_stop_value : float
         This value is used as the threshold for the different early stop modes.
         This value is effective only when the early_stop_mode is set.
     """
+    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
         from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
         self.model = model
...
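A usage sketch for `SensitivityAnalysis`; the `analysis()` and `export()` calls and the import path are assumptions based on the NNI sensitivity-analysis examples.

```python
import torchvision
# import path is an assumption; it may differ between NNI versions
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

model = torchvision.models.resnet18()

def val_func(m):
    # stand-in for a real validation loop; must return a scalar metric
    return 0.7

analyzer = SensitivityAnalysis(model, val_func, sparsities=[0.25, 0.5, 0.75])
sensitivity = analyzer.analysis()       # assumed API: per-layer metric under increasing sparsity
analyzer.export('sensitivity.csv')      # assumed API
```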
@@ -91,10 +91,10 @@ def reshape_break_channel_dependency(op_node):
 class ChannelDependency(Dependency):
-    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
     """
     This class analyzes the channel dependencies between the conv
     layers in a model.
+
     Parameters
     ----------
     model : torch.nn.Module
@@ -109,6 +109,8 @@ class ChannelDependency(Dependency):
         prune the filter of the convolution layer to prune the corresponding
         channels 2) `Batchnorm`: prune the channel in the batchnorm layer
     """
+    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
         self.prune_type = prune_type
         self.target_types = []
         if self.prune_type == 'Filter':
@@ -271,6 +273,7 @@ class InputChannelDependency(ChannelDependency):
     """
     This class analyzes the input channel dependencies between the conv
     layers in a model.
+
     Parameters
     ----------
     model : torch.nn.Module
@@ -329,10 +332,10 @@ class InputChannelDependency(ChannelDependency):
 class GroupDependency(Dependency):
-    def __init__(self, model, dummy_input, traced_model=None):
     """
     This class analyzes the group dependencies between the conv
     layers in a model.
+
     Parameters
     ----------
     model : torch.nn.Module
@@ -343,6 +346,8 @@ class GroupDependency(Dependency):
         if we already have the traced graph of the target model, we do not
         need to trace the model again.
     """
+    def __init__(self, model, dummy_input, traced_model=None):
         self.min_groups = {}
         super(GroupDependency, self).__init__(model, dummy_input, traced_model)
...
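Finally, a usage sketch for the dependency analyzers touched above; `min_groups` appears in the diff itself, while the import path and `dependency_sets` are assumptions that may differ between NNI versions.

```python
import torch
import torchvision
# import path is an assumption; it may differ between NNI versions
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

channel_dep = ChannelDependency(model, dummy_input)
print(channel_dep.dependency_sets)   # assumed attribute: sets of layers whose channels must be pruned together

group_dep = GroupDependency(model, dummy_input)
print(group_dep.min_groups)          # per-layer minimum group constraint (attribute visible in the diff above)
```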