use existing qconfig to create learnable qconfig
Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/215

Follow-up to the comment in D35631192 (https://github.com/facebookresearch/d2go/commit/3204f147d67fb2ce7ac2600c46708195c15bfa3a). The current `get_learnable_qat_qconfig` implementation mimics the default get-qconfig functions, as noted in comments like "follow `default_per_channel_weight_fake_quant`". Instead of creating a custom qconfig from scratch, this diff changes it to convert an existing qconfig into a learnable one, so that the process is transparent to orthogonal changes to the qconfig (e.g. a symmetric qscheme or a new backend).

The following shows the difference between the learnable and non-learnable `QConfig` for `qnnpack` and `fbgemm`. The actual difference is just adding `use_grad_scaling=True` and changing the fake-quantize type from `FusedMovingAvgObsFakeQuantize` to `_LearnableFakeQuantize`. (It may be easier to compare side-by-side in a text editor.)

````
qat_utils.get_learnable_qat_qconfig("qnnpack")
QConfig(
  activation=functools.partial(
    <class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=0,
    quant_max=255,
    use_grad_scaling=True,
    reduce_range=False
  ){},
  weight=functools.partial(
    <class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    use_grad_scaling=True,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False
  ){}
)

torch.ao.quantization.get_default_qat_qconfig("qnnpack")
QConfig(
  activation=functools.partial(
    <class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=0,
    quant_max=255,
    reduce_range=False
  ){},
  weight=functools.partial(
    <class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
  ){}
)

qat_utils.get_learnable_qat_qconfig("fbgemm")
QConfig(
  activation=functools.partial(
    <class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=0,
    quant_max=255,
    use_grad_scaling=True,
    reduce_range=True
  ){},
  weight=functools.partial(
    <class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    use_grad_scaling=True,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0
  ){}
)

torch.ao.quantization.get_default_qat_qconfig("fbgemm")
QConfig(
  activation=functools.partial(
    <class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
    quant_min=0,
    quant_max=255,
    reduce_range=True
  ){},
  weight=functools.partial(
    <class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
    observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric
  ){}
)
````

Reviewed By: kimishpatel

Differential Revision: D35772970

fbshipit-source-id: 0be8057e4f7ce3b315bd66d77aa88b733b676223