Unverified Commit fde9e1a0 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

update speed up doc (#4281)

parent e779d7f3
......@@ -198,3 +198,11 @@ The latency is measured on one V100 GPU and the input tensor is ``torch.randn(1
.. image:: ../../img/SA_latency_accuracy.png
User configuration for ModelSpeedup
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
.. autoclass:: nni.compression.pytorch.ModelSpeedup
......@@ -23,11 +23,7 @@ _logger.setLevel(logging.INFO)
class ModelSpeedup:
"""
This class speeds up the model with the provided weight mask.
"""
def __init__(self, model, dummy_input, masks_file, map_location=None,
batch_dim=0, confidence=8):
"""
Parameters
----------
model : pytorch model
......@@ -45,6 +41,9 @@ class ModelSpeedup:
confidence: the confidence coefficient of the sparsity inference. This value is
actually used as the batch size of the dummy_input.
"""
def __init__(self, model, dummy_input, masks_file, map_location=None,
batch_dim=0, confidence=8):
assert confidence > 1
# The auto inference will change the values of the parameters in the model,
# so we need to make a copy before the mask inference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment