Commit 9584b934 authored by Yanghan Wang, committed by Facebook GitHub Bot

use existing qconfig to create learnable qconfig

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/215

Follow-up to the comment in D35631192 (https://github.com/facebookresearch/d2go/commit/3204f147d67fb2ce7ac2600c46708195c15bfa3a).

The current `get_learnable_qat_qconfig` implementation mimics the default qconfig getters, as its comments note (e.g. "follow `default_per_channel_weight_fake_quant`"). Instead of creating a custom qconfig from scratch, this diff changes it to convert an existing qconfig into a learnable one, so that the process is transparent to orthogonal changes to the qconfig (e.g. a symmetric qscheme or a new backend).

The following shows the difference between the learnable and non-learnable `QConfig` for `qnnpack` and `fbgemm`. The actual difference is just adding `use_grad_scaling=True` and changing the FakeQuant type from `FusedMovingAvgObsFakeQuantize` to `_LearnableFakeQuantize`. (It may be easier to copy the output into a text editor and compare side by side.)
```
qat_utils.get_learnable_qat_qconfig("qnnpack")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,
		use_grad_scaling=True,
		reduce_range=False
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,
		use_grad_scaling=True,
		qscheme=torch.per_tensor_symmetric,
		reduce_range=False
	){}
)

torch.ao.quantization.get_default_qat_qconfig("qnnpack")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,

		reduce_range=False
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,

		qscheme=torch.per_tensor_symmetric,

	){}
)

qat_utils.get_learnable_qat_qconfig("fbgemm")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,
		use_grad_scaling=True,
		reduce_range=True
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,
		use_grad_scaling=True,
		qscheme=torch.per_channel_symmetric,
		reduce_range=False,
		ch_axis=0
	){}
)

torch.ao.quantization.get_default_qat_qconfig("fbgemm")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,

		reduce_range=True
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,

		qscheme=torch.per_channel_symmetric

	){}
)
```
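
As a rough sanity check, something like the following should reproduce the printouts above (the `d2go.utils.qat_utils` import path is an assumption; the module contents are shown in the diff below). Note that, as implemented, the conversion rewrites the qconfig's partial wrappers in place, so a fresh default qconfig is fetched for each printout:
```python
import torch
from d2go.utils import qat_utils  # assumed import path for the module in this diff

for backend in ["qnnpack", "fbgemm"]:
    # convert_to_learnable_qconfig mutates the wrapped functools.partial in
    # place, so two independent default qconfigs are created here; reusing one
    # object would make both printouts show the learnable version.
    default = torch.ao.quantization.get_default_qat_qconfig(backend)
    learnable = qat_utils.convert_to_learnable_qconfig(
        torch.ao.quantization.get_default_qat_qconfig(backend)
    )
    print(learnable)
    print(default)
```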

Reviewed By: kimishpatel

Differential Revision: D35772970

fbshipit-source-id: 0be8057e4f7ce3b315bd66d77aa88b733b676223
parent c055a84f
```diff
@@ -287,12 +287,11 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
     qat_method = cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
     assert qat_method in ["default", "learnable"]
-    if is_train and qat_method == "learnable":
-        qconfig = qat_utils.get_learnable_qat_qconfig(backend)
-    else:
-        qconfig = holistic_get_qconfig(
-            backend=backend, is_qat=is_train, use_symmetric=is_symmetric
-        )
+    qconfig = holistic_get_qconfig(
+        backend=backend, is_qat=is_train, use_symmetric=is_symmetric
+    )
+    if is_train and qat_method == "learnable":
+        qconfig = qat_utils.convert_to_learnable_qconfig(qconfig)
     return qconfig
```
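
With this refactor the learnable path no longer constructs its own qconfig: `holistic_get_qconfig` stays the single source of backend- and symmetry-dependent settings, and the learnable conversion is applied afterwards as an orthogonal final step.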
```diff
@@ -57,53 +57,40 @@ def iterate_module_named_parameters(model, check_requires_grad=True):
             yield module_name, module, module_param_name, value


-def get_learnable_qat_qconfig(backend):
-    assert backend in ["qnnpack", "fbgemm"]
-    ACT_CONFIGS = {
-        # follow `get_default_qat_qconfig()`
-        # fbcode/caffe2/torch/quantization/qconfig.py
-        "fbgemm": {
-            "reduce_range": True,
-        },
-        "qnnpack": {
-            "reduce_range": False,
-        },
-    }
-
-    WEIGHT_CONFIGS = {
-        # follow `default_per_channel_weight_fake_quant`
-        # fbcode/caffe2/torch/quantization/fake_quantize.py
-        "fbgemm": {
-            "observer": torch.quantization.MovingAveragePerChannelMinMaxObserver,
-            "qscheme": torch.per_channel_symmetric,
-            "reduce_range": False,
-            "ch_axis": 0,
-        },
-        # follow `default_weight_fake_quant`
-        # fbcode/caffe2/torch/quantization/fake_quantize.py
-        "qnnpack": {
-            "observer": torch.quantization.MovingAverageMinMaxObserver,
-            "qscheme": torch.per_tensor_symmetric,
-            "reduce_range": False,
-        },
-    }
-
-    act = _LearnableFakeQuantize.with_args(
-        observer=torch.quantization.MovingAverageMinMaxObserver,
-        quant_min=0,
-        quant_max=255,
-        use_grad_scaling=True,
-        **ACT_CONFIGS[backend],
-    )
-    weight = _LearnableFakeQuantize.with_args(
-        quant_min=-128,
-        quant_max=127,
-        dtype=torch.qint8,
-        use_grad_scaling=True,
-        **WEIGHT_CONFIGS[backend],
-    )
-    return torch.quantization.QConfig(activation=act, weight=weight)
+def convert_to_learnable_qconfig(qconfig):
+    """
+    Convert a QConfig to its learnable counterpart.
+    """
+
+    def _update_fused_moving_avg_obs_fake_quantize(keywords):
+        # requires setting use_grad_scaling to True, all other parameters are the same
+        # as default setting of FusedMovingAvgObsFakeQuantize (both qnnpack and fbgemm).
+        assert "use_grad_scaling" not in keywords
+        keywords["use_grad_scaling"] = True
+        return keywords
+
+    _OVERWRITE_PARAMS = {
+        # map from supported FakeQuant type to its new parameters in order to
+        # convert it to a learnable FakeQuant
+        torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize: _update_fused_moving_avg_obs_fake_quantize
+    }
+
+    def _update_to_learnable(wrapper):
+        assert isinstance(
+            wrapper, torch.ao.quantization.observer._PartialWrapper
+        ), wrapper
+        assert isinstance(wrapper.p, partial), wrapper
+
+        keywords_updater = _OVERWRITE_PARAMS[wrapper.p.func]
+        keywords = keywords_updater(wrapper.p.keywords)
+
+        new_p = partial(_LearnableFakeQuantize, *wrapper.p.args, **keywords)
+        wrapper.p = new_p
+        return wrapper
+
+    activation = _update_to_learnable(qconfig.activation)
+    weight = _update_to_learnable(qconfig.weight)
+
+    return torch.quantization.QConfig(activation=activation, weight=weight)


 def get_world_size() -> int:
```
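
For context, the `activation`/`weight` entries of a `QConfig` built via `with_args` are `torch.ao.quantization.observer._PartialWrapper` objects whose `.p` attribute is a plain `functools.partial`; `_update_to_learnable` above works by swapping that partial's target class and keywords. A minimal sketch of the mechanics (assuming a PyTorch version with the `torch.ao.quantization` namespace):
```python
import functools

from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize

# with_args returns a _PartialWrapper whose `.p` is a functools.partial, so
# the FakeQuant class (`.func`) and its arguments (`.keywords`) can be read
# and, as in _update_to_learnable, replaced wholesale.
wrapper = FusedMovingAvgObsFakeQuantize.with_args(quant_min=0, quant_max=255)
assert isinstance(wrapper.p, functools.partial)
print(wrapper.p.func)      # <class '...FusedMovingAvgObsFakeQuantize'>
print(wrapper.p.keywords)  # {'quant_min': 0, 'quant_max': 255}
```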