OpenDAS / nni / Commits

Commit 22165cea (unverified)
Authored Mar 21, 2022 by J-shang, committed by GitHub on Mar 21, 2022
[Doc] update compression reference (#4667)
parent de6662a4

Changes: 26. Showing 6 changed files on this page (page 1 of 2) with 157 additions and 137 deletions (+157 / -137).
Changed files:
  nni/compression/pytorch/compressor.py  (+9, -0)
  nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py  (+32, -31)
  nni/compression/pytorch/speedup/compressor.py  (+3, -0)
  nni/compression/pytorch/utils/mask_conflict.py  (+33, -32)
  nni/compression/pytorch/utils/sensitivity_analysis.py  (+45, -44)
  nni/compression/pytorch/utils/shape_dependency.py  (+35, -30)
nni/compression/pytorch/compressor.py

@@ -631,6 +631,7 @@ class Quantizer(Compressor):
        """
        quantize should overload this method to quantize weight.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        wrapper : QuantizerModuleWrapper

@@ -642,6 +643,7 @@ class Quantizer(Compressor):
        """
        quantize should overload this method to quantize output.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        output : Tensor

@@ -655,6 +657,7 @@ class Quantizer(Compressor):
        """
        quantize should overload this method to quantize input.
        This method is effectively hooked to :meth:`forward` of the model.
        Parameters
        ----------
        inputs : Tensor
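For context, these docstrings describe the hooks a custom quantizer overrides. Below is a minimal sketch (not part of this commit) of overloading the weight hook; the wrapper attribute access is an assumption about QuantizerModuleWrapper, not something this diff guarantees.

    import torch
    from nni.compression.pytorch.compressor import Quantizer

    class NaiveQuantizer(Quantizer):
        def quantize_weight(self, wrapper, **kwargs):
            # Called on every forward pass of the wrapped module (see the docstring above).
            weight = wrapper.module.weight.data          # assumed attribute layout
            scale = weight.abs().max() / 127 + 1e-12     # symmetric 8-bit scale
            wrapper.module.weight.data = torch.clamp(torch.round(weight / scale), -128, 127) * scale
            return wrapper.module.weight.data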
@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function):
    def _quantize(cls, x, scale, zero_point):
        """
        Reference function for quantizing x -- non-clamped.
        Parameters
        ----------
        x : Tensor

@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function):
            scale for quantizing x
        zero_point : Tensor
            zero_point for quantizing x
        Returns
        -------
        tensor
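The docstring above describes _quantize as a non-clamped reference quantization. As a rough illustration, here is the usual affine form, given as an assumption rather than code from this diff:

    import torch

    def reference_quantize(x, scale, zero_point):
        # q = round(x / scale) + zero_point, with no clamping to [qmin, qmax]
        return torch.round(x / scale) + zero_point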
@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function):
    def get_bits_length(cls, config, quant_type):
        """
        Get bits for quantize config
        Parameters
        ----------
        config : Dict
            the configuration for quantization
        quant_type : str
            quant type
        Returns
        -------
        int
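To illustrate the kind of lookup get_bits_length performs, here is a hedged sketch; the 'quant_bits' key and its int-or-dict form are assumptions about the compression config format, not shown in this diff.

    def get_bits_length_sketch(config, quant_type):
        bits = config.get('quant_bits', 8)
        if isinstance(bits, dict):
            return bits[quant_type]      # per-type bit widths, e.g. {'weight': 8, 'output': 8}
        return bits                      # a single bit width shared by all quant types

    print(get_bits_length_sketch({'quant_bits': {'weight': 8, 'output': 8}}, 'weight'))  # -> 8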
@@ -948,6 +955,7 @@ class QuantGrad(torch.autograd.Function):
        """
        This method should be overridden by subclasses to provide a customized backward function;
        the default implementation is the Straight-Through Estimator.
        Parameters
        ----------
        tensor : Tensor

@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function):
            quant_min for quantizing tensor
        qmax : Tensor
            quant_max for quantizing tensor
        Returns
        -------
        tensor
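For context, a minimal sketch of the Straight-Through Estimator behaviour that the docstring names as the default backward; the real QuantGrad backward takes additional quantization arguments (scale, zero_point, qmin, qmax) that are omitted here.

    import torch

    class RoundSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            return torch.round(tensor)

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through: pass the incoming gradient through the rounding op unchanged.
            return grad_output

    x = torch.tensor([0.2, 1.7], requires_grad=True)
    RoundSTE.apply(x).sum().backward()   # x.grad == [1., 1.]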
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py

@@ -228,10 +228,7 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
    return engine

class ModelSpeedupTensorRT(BaseModelSpeedup):
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
                 input_names=["actual_input_1"], output_names=["output1"]):
    r"""
    Parameters
    ----------
    model : pytorch model

@@ -262,6 +259,10 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
    output_name : list
        Output name of the onnx model, provided to torch.onnx.export to generate the onnx model
    """
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32, strict_datatype=True,
                 calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None, calibration_cache="calibration.cache", batchsize=1,
                 input_names=["actual_input_1"], output_names=["output1"]):
        super().__init__(model, config)
        self.model = model
        self.onnx_path = onnx_path
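A hedged usage sketch based only on the constructor signature shown above; the toy model is a placeholder and the commented-out compress/inference calls are assumptions about the rest of the ModelSpeedupTensorRT API, not part of this diff.

    import torch
    from nni.compression.pytorch.quantization_speedup.integrated_tensorrt import ModelSpeedupTensorRT

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    engine = ModelSpeedupTensorRT(
        model,
        input_shape=(1, 3, 224, 224),
        config=None,              # per-layer quantization config, or None
        extra_layer_bits=32,      # keep layers not covered by the config in FP32
        strict_datatype=True,
        batchsize=1,
    )
    # engine.compress() and engine.inference(...) would follow in a real run (assumed API).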
nni/compression/pytorch/speedup/compressor.py

@@ -388,6 +388,9 @@ class ModelSpeedup:
    def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
        """
        Replace the submodule according to the inferred sparsity.
        Parameters
        ----------
        unique_name: str
            The unique_name of the submodule to replace.
        reindex_dim: int
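For context, replace_submodule is an internal step of the speedup pass. A hedged sketch of the surrounding workflow follows; the mask file path is a placeholder, and the import and calls reflect the commonly documented ModelSpeedup entry point rather than anything shown in this diff.

    import torch
    from nni.compression.pytorch import ModelSpeedup

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))
    dummy_input = torch.rand(1, 3, 32, 32)
    ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()  # 'mask.pth' is a hypothetical mask file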
nni/compression/pytorch/utils/mask_conflict.py

@@ -81,7 +81,6 @@ class MaskFix:
class GroupMaskConflict(MaskFix):
    def __init__(self, masks, model, dummy_input, traced=None):
        """
        GroupMaskConflict fixes the mask conflicts between layers that
        have a group dependency on each other.

@@ -98,6 +97,7 @@ class GroupMaskConflict(MaskFix):
        the traced model of the target model; if this parameter is not None,
        we do not use the model and dummy_input to get the trace graph.
        """
    def __init__(self, masks, model, dummy_input, traced=None):
        super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)

@@ -168,7 +168,6 @@ class GroupMaskConflict(MaskFix):
class ChannelMaskConflict(MaskFix):
    def __init__(self, masks, model, dummy_input, traced=None):
        """
        ChannelMaskConflict fixes the mask conflicts between layers that
        have a channel dependency on each other.

@@ -185,6 +184,8 @@ class ChannelMaskConflict(MaskFix):
        the traced graph of the target model; if this parameter is not None,
        we do not use the model and dummy_input to get the trace graph.
        """
    def __init__(self, masks, model, dummy_input, traced=None):
        super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
        self.conv_prune_dim = detect_mask_prune_dim(masks, model)
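A hedged sketch of how these helpers are typically driven: load the pruning masks, then let the conflict fixer reconcile them across dependent layers. The fix_mask() call and the mask file are assumptions used for illustration, not taken from this diff.

    import torch
    from nni.compression.pytorch.utils.mask_conflict import GroupMaskConflict

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 8, 3, groups=8))
    dummy_input = torch.rand(1, 3, 32, 32)
    masks = torch.load('mask.pth')            # hypothetical mask file produced by a pruner

    fixer = GroupMaskConflict(masks, model, dummy_input)
    masks = fixer.fix_mask()                  # assumed MaskFix entry point for resolving conflicts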
nni/compression/pytorch/utils/sensitivity_analysis.py

@@ -18,9 +18,9 @@ logger.setLevel(logging.INFO)
class SensitivityAnalysis:
    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        """
        Perform sensitivity analysis for this model.
        Parameters
        ----------
        model : torch.nn.Module

@@ -61,8 +61,9 @@ class SensitivityAnalysis:
        early_stop_value : float
            This value is used as the threshold for the different early-stop modes.
            This value is effective only when the early_stop_mode is set.
        """
    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
        from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT

        self.model = model
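A hedged usage sketch built around the constructor shown above; val_func is a placeholder metric and the analysis() call is an assumption about how the class is driven, not something this diff documents.

    import torch
    from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

    def val_func(model):
        # Placeholder validation function: return an accuracy-like score for `model`.
        return 0.9

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU(), torch.nn.Conv2d(8, 4, 3))
    analyzer = SensitivityAnalysis(model, val_func, sparsities=[0.25, 0.5, 0.75], prune_type='l1')
    sensitivity = analyzer.analysis()   # assumed entry point: prune layer by layer and re-validate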
nni/compression/pytorch/utils/shape_dependency.py

@@ -91,10 +91,10 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
        """
        This class analyzes the channel dependencies between the conv
        layers in a model.
        Parameters
        ----------
        model : torch.nn.Module

@@ -109,6 +109,8 @@ class ChannelDependency(Dependency):
        prune the filter of the convolution layer to prune the corresponding
        channels 2) `Batchnorm`: prune the channel in the batchnorm layer
        """
    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
        self.prune_type = prune_type
        self.target_types = []
        if self.prune_type == 'Filter':
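For context, a hedged sketch of using ChannelDependency with the constructor shown above; the dependency_sets attribute is an assumption about the Dependency interface rather than something this diff shows.

    import torch
    from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.Conv2d(8, 4, 3))
    dep = ChannelDependency(model, torch.rand(1, 3, 32, 32), prune_type='Filter')
    print(dep.dependency_sets)   # groups of layers whose channel pruning must agree (assumed attribute)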
@@ -271,6 +273,7 @@ class InputChannelDependency(ChannelDependency):
        """
        This class analyzes the input channel dependencies between the conv
        layers in a model.
        Parameters
        ----------
        model : torch.nn.Module

@@ -329,10 +332,10 @@ class InputChannelDependency(ChannelDependency):
class GroupDependency(Dependency):
    def __init__(self, model, dummy_input, traced_model=None):
        """
        This class analyzes the group dependencies between the conv
        layers in a model.
        Parameters
        ----------
        model : torch.nn.Module

@@ -343,6 +346,8 @@ class GroupDependency(Dependency):
        if we already have the traced graph of the target model, we do not
        need to trace the model again.
        """
    def __init__(self, model, dummy_input, traced_model=None):
        self.min_groups = {}
        super(GroupDependency, self).__init__(model, dummy_input, traced_model)
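Similarly, a hedged sketch for GroupDependency; the dependency attribute read at the end is an assumption about the interface, used only for illustration.

    import torch
    from nni.compression.pytorch.utils.shape_dependency import GroupDependency

    model = torch.nn.Sequential(torch.nn.Conv2d(4, 8, 3), torch.nn.Conv2d(8, 8, 3, groups=4))
    dep = GroupDependency(model, torch.rand(1, 4, 32, 32))
    print(dep.dependency)   # assumed mapping: layer name -> group count its pruning mask must respect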