OpenDAS / nni · Commits · 22165cea

Unverified commit 22165cea, authored Mar 21, 2022 by J-shang, committed by GitHub on Mar 21, 2022

[Doc] update compression reference (#4667)

Parent: de6662a4
Changes: 26
Showing 6 changed files with 157 additions and 137 deletions
nni/compression/pytorch/compressor.py (+9 −0)
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py (+32 −31)
nni/compression/pytorch/speedup/compressor.py (+3 −0)
nni/compression/pytorch/utils/mask_conflict.py (+33 −32)
nni/compression/pytorch/utils/sensitivity_analysis.py (+45 −44)
nni/compression/pytorch/utils/shape_dependency.py (+35 −30)
nni/compression/pytorch/compressor.py
...
...
@@ -631,6 +631,7 @@ class Quantizer(Compressor):
"""
Subclasses of Quantizer should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
wrapper : QuantizerModuleWrapper
...
...
@@ -642,6 +643,7 @@ class Quantizer(Compressor):
"""
Subclasses of Quantizer should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
output : Tensor
...
...
@@ -655,6 +657,7 @@ class Quantizer(Compressor):
"""
Subclasses of Quantizer should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
Parameters
----------
inputs : Tensor
...
...
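These three hooks (quantize_weight, quantize_output, quantize_input) are the points a concrete quantizer overrides. As a hedged illustration only, not code from this commit, a minimal subclass that fake-quantizes weights to 8 bits could look like the sketch below; the class name and the symmetric rounding scheme are assumptions, and the exact wrapper attributes can vary between NNI versions.

import torch
from nni.compression.pytorch.compressor import Quantizer

class Int8WeightQuantizer(Quantizer):
    # Illustrative sketch: symmetric per-tensor 8-bit fake quantization of weights.
    def quantize_weight(self, wrapper, **kwargs):
        weight = wrapper.module.weight.data
        scale = weight.abs().max() / 127                  # 127 positive levels for int8
        if scale == 0:
            return weight                                 # avoid division by zero for all-zero weights
        quantized = torch.round(weight / scale) * scale   # round to the int8 grid, then de-quantize
        wrapper.module.weight.data = quantized
        return quantized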
@@ -908,6 +911,7 @@ class QuantGrad(torch.autograd.Function):
def _quantize(cls, x, scale, zero_point):
"""
Reference function for quantizing x -- non-clamped.
Parameters
----------
x : Tensor
...
...
@@ -916,6 +920,7 @@ class QuantGrad(torch.autograd.Function):
scale for quantizing x
zero_point : Tensor
zero_point for quantizing x
Returns
-------
tensor
...
...
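The "reference function for quantizing x" described above is the usual non-clamped affine mapping onto the integer grid defined by scale and zero_point. As a sketch of that formula (an assumption about the exact arithmetic, not a copy of the NNI implementation):

import torch

def reference_quantize(x, scale, zero_point):
    # non-clamped affine quantization: q = round(x / scale) + zero_point
    return torch.round(x / scale) + zero_point

def reference_dequantize(q, scale, zero_point):
    # inverse mapping back to the float domain
    return (q - zero_point) * scale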
@@ -927,12 +932,14 @@ class QuantGrad(torch.autograd.Function):
def get_bits_length(cls, config, quant_type):
"""
Get bits for quantize config
Parameters
----------
config : Dict
the configuration for quantization
quant_type : str
quant type
Returns
-------
int
...
...
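get_bits_length reads the bit width for a given quant_type out of a layer's quantization config. A hedged sketch of the two config shapes such a lookup typically handles, a single int shared by all quant types or a per-type dict; the dict keys follow NNI's compression config convention, but treat the exact schema as an assumption.

# one bit width shared by every quant type
config_shared = {'quant_types': ['weight', 'output'], 'quant_bits': 8, 'op_types': ['Conv2d']}

# a per-type dict: 8-bit weights, 32-bit (effectively unquantized) outputs
config_per_type = {'quant_types': ['weight', 'output'],
                   'quant_bits': {'weight': 8, 'output': 32},
                   'op_types': ['Conv2d']}

def lookup_bits(config, quant_type):
    # sketch of the lookup: fall back to the shared int when no per-type dict is given
    bits = config['quant_bits']
    return bits if isinstance(bits, int) else bits[quant_type]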
@@ -948,6 +955,7 @@ class QuantGrad(torch.autograd.Function):
"""
This method should be overridden by subclasses to provide a customized backward function;
the default implementation is the Straight-Through Estimator.
Parameters
----------
tensor : Tensor
...
...
@@ -963,6 +971,7 @@ class QuantGrad(torch.autograd.Function):
quant_min for quantizing tensor
qmax : Tensor
quant_max for quantizing tensor
Returns
-------
tensor
...
...
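The default backward mentioned above is the Straight-Through Estimator: the rounding step is treated as the identity for gradient purposes, usually with the gradient zeroed where the value fell outside [qmin, qmax]. A minimal sketch of that idea as a standalone autograd function (an illustration of STE, not the NNI QuantGrad code):

import torch

class FakeQuantSTE(torch.autograd.Function):
    # Sketch: fake quantization forward with a Straight-Through Estimator backward.
    @staticmethod
    def forward(ctx, x, scale, zero_point, qmin, qmax):
        q = torch.round(x / scale) + zero_point
        in_range = (q >= qmin) & (q <= qmax)      # remember which values were representable
        ctx.save_for_backward(in_range)
        q = torch.clamp(q, qmin, qmax)
        return (q - zero_point) * scale           # de-quantize back to float

    @staticmethod
    def backward(ctx, grad_output):
        in_range, = ctx.saved_tensors
        # pass the gradient straight through where the value stayed in range;
        # only x receives a gradient, the other forward arguments get None
        return grad_output * in_range.to(grad_output.dtype), None, None, None, None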
nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py
...
...
@@ -228,10 +228,7 @@ def build_engine(model_file, config=None, extra_layer_bits=32, strict_datatype=F
    return engine


class ModelSpeedupTensorRT(BaseModelSpeedup):
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32,
                 strict_datatype=True, calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None,
                 calibration_cache="calibration.cache", batchsize=1, input_names=["actual_input_1"],
                 output_names=["output1"]):
"""
r
"""
Parameters
----------
model : pytorch model
...
...
@@ -262,6 +259,10 @@ class ModelSpeedupTensorRT(BaseModelSpeedup):
output_name : list
Output names of the ONNX model, passed to torch.onnx.export when generating the ONNX model
"""
    def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx", extra_layer_bits=32,
                 strict_datatype=True, calibrate_type=CalibrateType.ENTROPY2, calib_data_loader=None,
                 calibration_cache="calibration.cache", batchsize=1, input_names=["actual_input_1"],
                 output_names=["output1"]):
        super().__init__(model, config)
        self.model = model
        self.onnx_path = onnx_path
...
...
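For orientation, the constructor above is normally used roughly as follows: wrap a model, build the TensorRT engine via an ONNX export, then run inference through the engine. This is a hedged sketch based on NNI's documented usage; the resnet18 model, the batch size, and passing config=None are placeholder choices, and a real run requires TensorRT to be installed.

import torch
import torchvision
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = torchvision.models.resnet18().eval()
input_shape = (32, 3, 224, 224)

# config would normally be the calibration_config exported by a quantizer;
# None here is only a placeholder for the sketch
engine = ModelSpeedupTensorRT(model, input_shape, config=None, batchsize=32)
engine.compress()                                   # export to ONNX and build the TensorRT engine
output, latency = engine.inference(torch.randn(input_shape))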
nni/compression/pytorch/speedup/compressor.py
...
...
@@ -388,6 +388,9 @@ class ModelSpeedup:
def replace_submodule(self, unique_name, reindex_dim=None, reindex=None):
"""
Replace the submodule according to the inferred sparsity.
Parameters
----------
unique_name: str
The unique_name of the submodule to replace.
reindex_dim: int
...
...
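replace_submodule is an internal step of the speedup pass; users normally reach it through ModelSpeedup.speedup_model(), which infers the sparsity from the pruning masks and then replaces each affected submodule. A hedged sketch of that entry point (the mask file path and the resnet18 model are placeholders):

import torch
import torchvision
from nni.compression.pytorch import ModelSpeedup

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

# 'masks.pth' stands in for the masks exported by a pruner
ModelSpeedup(model, dummy_input, 'masks.pth').speedup_model()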
nni/compression/pytorch/utils/mask_conflict.py
...
...
@@ -81,7 +81,6 @@ class MaskFix:
class GroupMaskConflict(MaskFix):
    def __init__(self, masks, model, dummy_input, traced=None):
"""
GroupMaskConflict fixes the mask conflicts between layers that
have a group dependency on each other.
...
...
@@ -98,6 +97,7 @@ class GroupMaskConflict(MaskFix):
the traced model of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
def __init__(self, masks, model, dummy_input, traced=None):
    super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)
...
...
@@ -168,7 +168,6 @@ class GroupMaskConflict(MaskFix):
class ChannelMaskConflict(MaskFix):
    def __init__(self, masks, model, dummy_input, traced=None):
"""
ChannelMaskConflict fixes the mask conflicts between layers that
have a channel dependency on each other.
...
...
@@ -185,6 +184,8 @@ class ChannelMaskConflict(MaskFix):
the traced graph of the target model; if this parameter is not None,
we do not use the model and dummy_input to get the trace graph.
"""
def __init__(self, masks, model, dummy_input, traced=None):
    super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
    self.conv_prune_dim = detect_mask_prune_dim(masks, model)
...
...
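Both classes above are usually driven through the fix_mask_conflict helper in the same module, which applies the group fix followed by the channel fix. A hedged sketch of that call (the mask file path and model are placeholders, and the helper's exact signature should be checked against the installed NNI version):

import torch
import torchvision
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

# 'masks.pth' stands in for the masks exported by a pruner; the returned masks are
# adjusted so that channel/group-dependent layers share a consistent pruning pattern
fixed_masks = fix_mask_conflict('masks.pth', model, dummy_input)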
nni/compression/pytorch/utils/sensitivity_analysis.py
...
...
@@ -18,9 +18,9 @@ logger.setLevel(logging.INFO)
class SensitivityAnalysis:
    def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
"""
Perform sensitivity analysis for this model.
Parameters
----------
model : torch.nn.Module
...
...
@@ -61,8 +61,9 @@ class SensitivityAnalysis:
early_stop_value : float
This value is used as the threshold for different earlystop modes.
This value is effective only when the early_stop_mode is set.
"""
def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
    from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
    self.model = model
...
...
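A hedged usage sketch for the constructor above: val_func evaluates the model after each pruning step and returns the metric that the analysis tracks. The validation function, sparsity grid, and early-stop settings below are placeholder choices, and the shape of the returned dictionary is an assumption to be checked against the class docs.

import torchvision
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

def val_func(model):
    # placeholder: evaluate the (pruned) model on a held-out set and return the metric
    return 0.0

model = torchvision.models.resnet18()
analyzer = SensitivityAnalysis(model, val_func, sparsities=[0.25, 0.5, 0.75],
                               prune_type='l1', early_stop_mode='dropped', early_stop_value=0.1)
sensitivities = analyzer.analysis()   # assumed format: {layer_name: {sparsity: metric}}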
nni/compression/pytorch/utils/shape_dependency.py
...
...
@@ -91,10 +91,10 @@ def reshape_break_channel_dependency(op_node):
class ChannelDependency(Dependency):
    def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
"""
This class analyzes the channel dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
...
...
@@ -109,6 +109,8 @@ class ChannelDependency(Dependency):
prune the filter of the convolution layer to prune the corresponding
channels 2) `Batchnorm`: prune the channel in the batchnorm layer
"""
def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
    self.prune_type = prune_type
    self.target_types = []
    if self.prune_type == 'Filter':
...
...
@@ -271,6 +273,7 @@ class InputChannelDependency(ChannelDependency):
"""
This class analyzes the input channel dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
...
...
@@ -329,10 +332,10 @@ class InputChannelDependency(ChannelDependency):
class GroupDependency(Dependency):
    def __init__(self, model, dummy_input, traced_model=None):
"""
This class analyzes the group dependencies between the conv
layers in a model.
Parameters
----------
model : torch.nn.Module
...
...
@@ -343,6 +346,8 @@ class GroupDependency(Dependency):
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
def __init__(self, model, dummy_input, traced_model=None):
    self.min_groups = {}
    super(GroupDependency, self).__init__(model, dummy_input, traced_model)
...
...
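A hedged sketch of how these dependency analyzers are typically used: construct them from a model and a dummy input, then read the resulting dependency information. The attribute names below follow NNI's shape_dependency utilities but should be treated as assumptions.

import torch
import torchvision
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

channel_dep = ChannelDependency(model, dummy_input)
# sets of conv layers whose output channels must be pruned together
# (e.g. layers whose outputs are summed by a residual add)
print(channel_dep.dependency_sets)

group_dep = GroupDependency(model, dummy_input)
# per-layer group constraints that pruning must respect
print(group_dep.dependency_sets)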