"src/targets/vscode:/vscode.git/clone" did not exist on "d2c25a07250644e1e8c6efbc5fe9254fc0c30e5d"
Unverified Commit 22165cea authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Doc] update compression reference (#4667)

parent de6662a4
@@ -631,6 +631,7 @@ class Quantizer(Compressor):
        """
        Subclasses should overload this method to quantize weight.
        This method is effectively hooked to :meth:`forward` of the model.
+
        Parameters
        ----------
        wrapper : QuantizerModuleWrapper
@@ -642,6 +643,7 @@ class Quantizer(Compressor):
        """
        Subclasses should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.
+
        Parameters
        ----------
        output : Tensor
@@ -655,6 +657,7 @@ class Quantizer(Compressor):
        """
        Subclasses should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.
+
        Parameters
        ----------
        inputs : Tensor
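Taken together, these three hooks let a quantization algorithm intercept weights, inputs, and outputs when the wrapped module's :meth:`forward` runs. A minimal subclass sketch; the hook signature and the `wrapper.module` access pattern are assumptions inferred from the docstrings above, not the canonical API:

```python
import torch
from nni.compression.pytorch.compressor import Quantizer  # import path assumed

class Int8WeightQuantizer(Quantizer):
    """Hypothetical subclass: symmetric 8-bit fake-quantization of weights."""

    def quantize_weight(self, wrapper, **kwargs):
        weight = wrapper.module.weight.data
        scale = weight.abs().max() / 127 + 1e-12        # avoid division by zero
        q = torch.clamp(torch.round(weight / scale), -128, 127)
        wrapper.module.weight.data = q * scale          # dequantized ("fake") weights
        return wrapper.module.weight
```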
@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function):
    def _quantize(cls, x, scale, zero_point):
        """
        Reference function for quantizing x -- non-clamped.
+
        Parameters
        ----------
        x : Tensor
@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function):
            scale for quantizing x
        zero_point : Tensor
            zero_point for quantizing x
+
        Returns
        -------
        tensor
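The non-clamped reference quantization described above corresponds to the usual affine mapping. A short sketch; the exact rounding convention used by `_quantize` is an assumption here:

```python
import torch

def reference_quantize(x: torch.Tensor, scale: torch.Tensor,
                       zero_point: torch.Tensor) -> torch.Tensor:
    # Affine quantization without clamping, per the docstring above.
    return torch.round(x / scale + zero_point)
```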
@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function):
    def get_bits_length(cls, config, quant_type):
        """
        Get the number of quantization bits for `quant_type` from the config.
+
        Parameters
        ----------
        config : Dict
            the configuration for quantization
        quant_type : str
            the type of quantization
+
        Returns
        -------
        int
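A plausible reading of this helper, assuming `quant_bits` in the config may be either a single int applied to every quant type or a dict keyed by quant type (an assumed schema, not confirmed by this excerpt):

```python
def get_bits_length(config: dict, quant_type: str) -> int:
    # 'quant_bits' may be a bare int, or a dict such as
    # {'weight': 8, 'output': 8} (assumed schema).
    bits = config.get('quant_bits', 8)
    return bits if isinstance(bits, int) else bits[quant_type]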
@@ -948,6 +955,7 @@ class QuantGrad(torch.autograd.Function):
        """
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is the Straight-Through Estimator.
+
        Parameters
        ----------
        tensor : Tensor
@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function):
            quant_min for quantizing tensor
        qmax : Tensor
            quant_max for quantizing tensor
+
        Returns
        -------
        tensor
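The default backward pass is the Straight-Through Estimator (STE): the gradient flows through the non-differentiable round/clamp steps as if they were the identity. A self-contained sketch of the pattern; the class name, clamp range, and argument layout are illustrative assumptions, not NNI's exact `QuantGrad`:

```python
import torch

class SteFakeQuant(torch.autograd.Function):
    """Minimal sketch of the Straight-Through Estimator pattern."""

    @staticmethod
    def forward(ctx, x, scale, zero_point):
        # Fake-quantize: round/clamp to the int8 range, then dequantize.
        q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pretend round/clamp were the identity; pass the gradient through.
        return grad_output, None, None

x = torch.randn(4, requires_grad=True)
y = SteFakeQuant.apply(x, torch.tensor(0.1), torch.tensor(0.0))
y.sum().backward()   # x.grad is all ones: the gradient passed straight through
```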
@@ -228,10 +228,7 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
    return engine

class ModelSpeedupTensorRT(BaseModelSpeedup):
-    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
-                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
-                 input_names=["actual_input_1"], output_names=["output1"]):
-        """
+    r"""
        Parameters
        ----------
        model : pytorch model
@@ -262,6 +259,10 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
        output_names : list
            Output names of the onnx model, passed to torch.onnx.export when generating the onnx model
        """
+
+    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
+                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
+                 input_names=["actual_input_1"], output_names=["output1"]):
        super().__init__(model, config)
        self.model = model
        self.onnx_path = onnx_path
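Based on the constructor parameters above, a typical usage sketch; the `compress()` and `inference()` method names follow the NNI quantization-speedup documentation and should be treated as assumptions here, as should the input shape and config:

```python
import torch

# `model` is a quantization-aware-trained torch.nn.Module;
# `calibration_config` is the config exported by the quantizer (assumed).
engine = ModelSpeedupTensorRT(
    model,
    input_shape=(32, 1, 28, 28),   # illustrative MNIST-like shape
    config=calibration_config,
    batchsize=32,
)
engine.compress()                  # export to ONNX, then build the TensorRT engine
dummy = torch.randn(32, 1, 28, 28)
output, latency = engine.inference(dummy)
```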
@@ -388,6 +388,9 @@ class ModelSpeedup:
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.
+
        Parameters
        ----------
        unique_name : str
            The unique_name of the submodule to replace.
        reindex_dim : int
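For orientation, a hypothetical call sequence; the `ModelSpeedup` constructor arguments, the unique-name format, and the reindex values are all illustrative assumptions:

```python
import torch

speedup = ModelSpeedup(model, dummy_input=torch.randn(1, 3, 224, 224),
                       masks_file='mask.pth')
# Replace one layer using the inferred sparsity as-is:
speedup.replace_submodule('conv1')
# Replace another layer, re-indexing its output channels (dim 0) so that
# only the surviving channels 0, 2 and 3 are kept (illustrative values):
speedup.replace_submodule('fc', reindex_dim=0, reindex=torch.tensor([0, 2, 3]))
```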
@@ -81,7 +81,6 @@ class MaskFix:

class GroupMaskConflict(MaskFix):
-    def __init__(self, masks, model, dummy_input, traced=None):
        """
        GroupMaskConflict fixes the mask conflicts between layers that
        have group dependency on each other.
@@ -98,6 +97,7 @@ class GroupMaskConflict(MaskFix):
            the traced model of the target model; if this parameter is not None,
            we do not use the model and dummy_input to get the trace graph.
        """
+    def __init__(self, masks, model, dummy_input, traced=None):
        super(GroupMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
@@ -168,7 +168,6 @@ class GroupMaskConflict(MaskFix):

class ChannelMaskConflict(MaskFix):
-    def __init__(self, masks, model, dummy_input, traced=None):
        """
        ChannelMaskConflict fixes the mask conflicts between layers that
        have channel dependency on each other.
@@ -185,6 +184,8 @@ class ChannelMaskConflict(MaskFix):
            the traced graph of the target model; if this parameter is not None,
            we do not use the model and dummy_input to get the trace graph.
        """
+
+    def __init__(self, masks, model, dummy_input, traced=None):
        super(ChannelMaskConflict, self).__init__(
            masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
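Both fixers share the same MaskFix flow: align the masks of group-dependent layers first, then of channel-dependent ones. A hedged usage sketch; the `fix_mask()` method name follows the MaskFix convention used elsewhere in NNI and is an assumption here:

```python
import torch

dummy_input = torch.randn(1, 3, 224, 224)
# Resolve group-dependency conflicts, then channel-dependency conflicts.
masks = GroupMaskConflict(masks, model, dummy_input).fix_mask()
masks = ChannelMaskConflict(masks, model, dummy_input).fix_mask()
```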
@@ -18,9 +18,9 @@ logger.setLevel(logging.INFO)

class SensitivityAnalysis:
-    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        """
        Perform sensitivity analysis for this model.
+
        Parameters
        ----------
        model : torch.nn.Module
@@ -61,8 +61,9 @@ class SensitivityAnalysis:
        early_stop_value : float
            This value is used as the threshold for the chosen early-stop mode.
            It is effective only when early_stop_mode is set.
        """
+    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
        self.model = model
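A usage sketch built from the constructor above; the `analysis()` method, its `val_args` convention, and the early-stop mode name follow the NNI sensitivity docs and should be treated as assumptions:

```python
import torch

def val_func(model_to_eval: torch.nn.Module) -> float:
    """User-supplied: run validation and return the metric (e.g. accuracy)."""
    ...

analyzer = SensitivityAnalysis(model, val_func, prune_type='l1',
                               early_stop_mode='dropped', early_stop_value=0.05)
sensitivity = analyzer.analysis(val_args=[model])
# Assumed result layout: {layer_name: {sparsity: validation metric}}
```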
@@ -91,10 +91,10 @@ def reshape_break_channel_dependency(op_node):

class ChannelDependency(Dependency):
-    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
        """
        This class analyzes the channel dependencies between the conv
        layers in a model.
+
        Parameters
        ----------
        model : torch.nn.Module
@@ -109,6 +109,8 @@ class ChannelDependency(Dependency):
            prune the filters of the convolution layer to prune the corresponding
            channels; 2) `Batchnorm`: prune the channels in the batchnorm layer
        """
+
+    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
        self.prune_type = prune_type
        self.target_types = []
        if self.prune_type == 'Filter':
@@ -271,6 +273,7 @@ class InputChannelDependency(ChannelDependency):
    """
    This class analyzes the input channel dependencies between the conv
    layers in a model.
+
    Parameters
    ----------
    model : torch.nn.Module
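A usage sketch per the constructor above; the `dependency_sets` property and `export()` method follow the NNI dependency-analysis docs and are assumptions here:

```python
import torch

dummy_input = torch.randn(1, 3, 224, 224)
dep = ChannelDependency(model, dummy_input, prune_type='Filter')
for dep_set in dep.dependency_sets:
    # Layers in the same set must keep identical channel masks.
    print(dep_set)
dep.export('channel_dependency.csv')
```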
@@ -329,10 +332,10 @@ class InputChannelDependency(ChannelDependency):

class GroupDependency(Dependency):
-    def __init__(self, model, dummy_input, traced_model=None):
        """
        This class analyzes the group dependencies between the conv
        layers in a model.
+
        Parameters
        ----------
        model : torch.nn.Module
@@ -343,6 +346,8 @@ class GroupDependency(Dependency):
            if we already have the traced graph of the target model, we do not
            need to trace the model again.
        """
+
+    def __init__(self, model, dummy_input, traced_model=None):
        self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
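GroupDependency mirrors the ChannelDependency API. A brief sketch; the layout of the `dependency` mapping is an assumption based on the Dependency base class, not confirmed by this excerpt:

```python
import torch

group_dep = GroupDependency(model, torch.randn(1, 3, 224, 224))
# Assumed layout: each conv layer name maps to the number of groups
# its pruned filters must stay aligned with.
for layer_name, num_groups in group_dep.dependency.items():
    print(layer_name, '->', num_groups)
```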