Unverified commit 808b0655, authored by J-shang, committed by GitHub

[Bugbash] update doc (#5075)

parent bb8114a4
docs/img/compression_pipeline.png: image replaced (84.2 KB → 38.3 KB)
@@ -170,13 +170,13 @@ Tutorials
 .. only:: html

   .. image:: /tutorials/images/thumb/sphx_glr_pruning_bert_glue_thumb.png
-    :alt: Pruning Transformer with NNI
+    :alt: Pruning Bert on Task MNLI

   :ref:`sphx_glr_tutorials_pruning_bert_glue.py`

 .. raw:: html

-  <div class="sphx-glr-thumbnail-title">Pruning Transformer with NNI</div>
+  <div class="sphx-glr-thumbnail-title">Pruning Bert on Task MNLI</div>
 </div>
This diff is collapsed.
-7d8ff24fe5a88d208ad2ad051f060df4
\ No newline at end of file
+4935f5727dd073c91bcfab8b9f0676d7
\ No newline at end of file
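These two hex digests read like an update to one of sphinx-gallery's per-tutorial .md5 cache files: the gallery stores a checksum of each tutorial script and re-executes the script only when the checksum changes. A minimal sketch of producing such a digest (the file name here is illustrative, not taken from this diff):

    import hashlib

    # Hash the tutorial source the way a checksum cache would.
    with open('pruning_bert_glue.py', 'rb') as f:
        digest = hashlib.md5(f.read()).hexdigest()
    print(digest)  # a 32-character hex string like the ones above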
This diff is collapsed.
@@ -5,12 +5,12 @@
 Computation times
 =================
-**01:38.004** total execution time for **tutorials** files:
+**00:41.637** total execution time for **tutorials** files:

 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``)                                                   | 01:38.004 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``)                           | 00:41.637 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``)                           | 00:27.206 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``)                                                   | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``)                                           | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
This diff is collapsed.
@@ -22,7 +22,7 @@ class NaiveQuantizer(Quantizer):
     config_list : List[Dict]
         List of configurations for quantization. Supported keys:
         - quant_types : List[str]
-            Type of quantization you want to apply, currently support 'weight', 'input', 'output'.
+            Type of quantization you want to apply, currently support 'weight'.
         - quant_bits : Union[int, Dict[str, int]]
            Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
            when the type is int, all quantization types share same bits length.
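With this change the docstring advertises only 'weight' quantization for NaiveQuantizer. A minimal sketch of a config_list matching the updated docstring; the op_types key follows NNI's common compression config format, and the concrete module types are illustrative:

    # Quantize the weights of Conv2d and Linear modules to 8 bits.
    config_list = [{
        'quant_types': ['weight'],          # per the updated docstring, only 'weight' is supported
        'quant_bits': {'weight': 8},        # or a plain int, shared by all quant types
        'op_types': ['Conv2d', 'Linear'],   # illustrative choice of module types
    }]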
@@ -228,27 +228,6 @@ class QAT_Quantizer(Quantizer):
             'quant_dtype': 'uint',
             'quant_scheme': 'per_tensor_affine'
         }]
-
-    **Multi-GPU training**
-
-    QAT quantizer natively supports multi-gpu training (DataParallel and DistributedDataParallel). Note that the quantizer
-    instantiation should happen before you wrap your model with DataParallel or DistributedDataParallel. For example:
-
-    .. code-block:: python
-
-        from torch.nn.parallel import DistributedDataParallel as DDP
-        from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
-
-        model = define_your_model()
-        model = QAT_Quantizer(model, **other_params)  # <--- QAT_Quantizer instantiation
-        model = DDP(model)
-
-        for i in range(epochs):
-            train(model)
-            eval(model)
     """

     def __init__(self, model, config_list, optimizer, dummy_input=None):
@@ -175,8 +175,8 @@ class EvaluatorBasedPruner(BasicPruner):
         else:
             self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
             self.using_evaluator = False
-            warn_msg = f"The old API ...{','.join(old_api)} will be deprecated after NNI v3.0, " + \
-                       "please using the new one ...{','.join(new_api)}"
+            warn_msg = f"The old API {','.join(old_api)} will be deprecated after NNI v3.0, " + \
+                       f"please using the new one {','.join(new_api)}"
             _logger.warning(warn_msg)
         return init_kwargs
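The fix here hinges on how Python handles implicit string concatenation: the f prefix applies to each string literal separately, so the continuation line's braces were previously emitted verbatim rather than interpolated. A self-contained demonstration (the variable values are made up):

    old_api, new_api = ['trainer'], ['evaluator']

    # Only the first literal carries the f prefix, so the second line's
    # braces are NOT interpolated and appear literally in the output.
    buggy = f"The old API {','.join(old_api)} will be deprecated after NNI v3.0, " + \
            "please using the new one {','.join(new_api)}"
    print(buggy)
    # The old API trainer will be deprecated after NNI v3.0, please using the new one {','.join(new_api)}

    # Prefixing both literals with f interpolates both, as this commit does.
    fixed = f"The old API {','.join(old_api)} will be deprecated after NNI v3.0, " + \
            f"please using the new one {','.join(new_api)}"
    print(fixed)
    # The old API trainer will be deprecated after NNI v3.0, please using the new one evaluator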
@@ -760,7 +760,10 @@ class TorchEvaluator(Evaluator):
     def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]:
         assert self.model is not None
-        assert self.evaluating_func is not None
+        if self.evaluating_func is None:
+            warn_msg = f'Did not pass evaluation_func to {self.__class__.__name__}, will return None for calling evaluate()'
+            _logger.warning(warn_msg)
+            return None
         metric = self.evaluating_func(self.model)
         if isinstance(metric, dict):
             nni_used_metric = metric.get('default', None)
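The replaced assert would have raised AssertionError whenever an evaluator was built without an evaluating function; the new guard logs a warning and returns None so callers can proceed. A standalone sketch of that warn-and-return pattern (the class and names here are illustrative, not NNI's actual API):

    import logging

    _logger = logging.getLogger(__name__)

    class SketchEvaluator:
        def __init__(self, model, evaluating_func=None):
            self.model = model
            self.evaluating_func = evaluating_func

        def evaluate(self):
            # Degrade gracefully: warn and return None instead of asserting.
            if self.evaluating_func is None:
                _logger.warning('Did not pass evaluating_func to %s, '
                                'will return None for calling evaluate()',
                                self.__class__.__name__)
                return None
            return self.evaluating_func(self.model)

    print(SketchEvaluator(model=None).evaluate())  # logs a warning, prints None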