Commit 9584b934 authored by Yanghan Wang, committed by Facebook GitHub Bot

use existing qconfig to create learnable qconfig

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/215

Follow-up to the comment in D35631192 (https://github.com/facebookresearch/d2go/commit/3204f147d67fb2ce7ac2600c46708195c15bfa3a).

The current `get_learnable_qat_qconfig` implementation mimics the default get-qconfig functions, as its comments note ("follow `default_per_channel_weight_fake_quant`", etc.). Instead of creating a custom qconfig from scratch, this diff changes it to convert an existing qconfig into a learnable one, so the process stays transparent to orthogonal changes to the qconfig (e.g. a symmetric qscheme or a new backend).

The following shows the difference between the learnable and non-learnable `QConfig` for `qnnpack` and `fbgemm`. The only actual differences are adding `use_grad_scaling=True` and changing the FakeQuant type from `FusedMovingAvgObsFakeQuantize` to `_LearnableFakeQuantize`. (It may be easier to compare side by side after copying the output into a text editor.)
```
qat_utils.get_learnable_qat_qconfig("qnnpack")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,
		use_grad_scaling=True,
		reduce_range=False
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,
		use_grad_scaling=True,
		qscheme=torch.per_tensor_symmetric,
		reduce_range=False
	){}
)

torch.ao.quantization.get_default_qat_qconfig("qnnpack")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,

		reduce_range=False
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,

		qscheme=torch.per_tensor_symmetric,

	){}
)

qat_utils.get_learnable_qat_qconfig("fbgemm")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,
		use_grad_scaling=True,
		reduce_range=True
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization._learnable_fake_quantize._LearnableFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,
		use_grad_scaling=True,
		qscheme=torch.per_channel_symmetric,
		reduce_range=False,
		ch_axis=0
	){}
)

torch.ao.quantization.get_default_qat_qconfig("fbgemm")
QConfig(
	activation=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>,
		quant_min=0,
		quant_max=255,

		reduce_range=True
	){},
	weight=functools.partial(
		<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>,
		observer=<class 'torch.ao.quantization.observer.MovingAveragePerChannelMinMaxObserver'>,
		quant_min=-128,
		quant_max=127,
		dtype=torch.qint8,

		qscheme=torch.per_channel_symmetric

	){}
)
```
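
For a quick sanity check, the two differences can be asserted directly on the wrapped partials. This is a minimal sketch, not part of the diff; the `d2go.utils` import path below is an assumption:

```python
# Illustration only: verifies the two differences listed above.
# The d2go import path is an assumption and may differ in the repo.
import torch
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
from d2go.utils import qat_utils  # hypothetical import path

default = torch.ao.quantization.get_default_qat_qconfig("qnnpack")
learnable = qat_utils.convert_to_learnable_qconfig(default)

# with_args() wrappers expose a functools.partial via `.p`; after conversion the
# target class is swapped to _LearnableFakeQuantize and use_grad_scaling=True is
# added, while all other keywords are kept as-is.
for entry in (learnable.activation, learnable.weight):
    assert entry.p.func is _LearnableFakeQuantize
    assert entry.p.keywords["use_grad_scaling"] is True
```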

Reviewed By: kimishpatel

Differential Revision: D35772970

fbshipit-source-id: 0be8057e4f7ce3b315bd66d77aa88b733b676223
parent c055a84f
```diff
@@ -287,12 +287,11 @@ def _smart_set_backend_and_create_qconfig(cfg, *, is_train):
     qat_method = cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD
     assert qat_method in ["default", "learnable"]

+    qconfig = holistic_get_qconfig(
+        backend=backend, is_qat=is_train, use_symmetric=is_symmetric
+    )
     if is_train and qat_method == "learnable":
-        qconfig = qat_utils.get_learnable_qat_qconfig(backend)
-    else:
-        qconfig = holistic_get_qconfig(
-            backend=backend, is_qat=is_train, use_symmetric=is_symmetric
-        )
+        qconfig = qat_utils.convert_to_learnable_qconfig(qconfig)

     return qconfig
```
```diff
@@ -57,53 +57,40 @@ def iterate_module_named_parameters(model, check_requires_grad=True):
             yield module_name, module, module_param_name, value


-def get_learnable_qat_qconfig(backend):
-    assert backend in ["qnnpack", "fbgemm"]
-
-    ACT_CONFIGS = {
-        # follow `get_default_qat_qconfig()`
-        # fbcode/caffe2/torch/quantization/qconfig.py
-        "fbgemm": {
-            "reduce_range": True,
-        },
-        "qnnpack": {
-            "reduce_range": False,
-        },
-    }
-    WEIGHT_CONFIGS = {
-        # follow `default_per_channel_weight_fake_quant`
-        # fbcode/caffe2/torch/quantization/fake_quantize.py
-        "fbgemm": {
-            "observer": torch.quantization.MovingAveragePerChannelMinMaxObserver,
-            "qscheme": torch.per_channel_symmetric,
-            "reduce_range": False,
-            "ch_axis": 0,
-        },
-        # follow `default_weight_fake_quant`
-        # fbcode/caffe2/torch/quantization/fake_quantize.py
-        "qnnpack": {
-            "observer": torch.quantization.MovingAverageMinMaxObserver,
-            "qscheme": torch.per_tensor_symmetric,
-            "reduce_range": False,
-        },
-    }
-
-    act = _LearnableFakeQuantize.with_args(
-        observer=torch.quantization.MovingAverageMinMaxObserver,
-        quant_min=0,
-        quant_max=255,
-        use_grad_scaling=True,
-        **ACT_CONFIGS[backend],
-    )
-    weight = _LearnableFakeQuantize.with_args(
-        quant_min=-128,
-        quant_max=127,
-        dtype=torch.qint8,
-        use_grad_scaling=True,
-        **WEIGHT_CONFIGS[backend],
-    )
-    return torch.quantization.QConfig(activation=act, weight=weight)
+def convert_to_learnable_qconfig(qconfig):
+    """
+    Convert a QConfig to its learnable counterpart.
+    """
+
+    def _update_fused_moving_avg_obs_fake_quantize(keywords):
+        # requires setting use_grad_scaling to True, all other parameters are the same
+        # as default setting of FusedMovingAvgObsFakeQuantize (both qnnpack and fbgemm).
+        assert "use_grad_scaling" not in keywords
+        keywords["use_grad_scaling"] = True
+        return keywords
+
+    _OVERWRITE_PARAMS = {
+        # map from supported FakeQuant type to its new parameters in order to convert
+        # it to a learnable FakeQuant
+        torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize: _update_fused_moving_avg_obs_fake_quantize
+    }
+
+    def _update_to_learnable(wrapper):
+        assert isinstance(
+            wrapper, torch.ao.quantization.observer._PartialWrapper
+        ), wrapper
+        assert isinstance(wrapper.p, partial), wrapper
+
+        keywords_updater = _OVERWRITE_PARAMS[wrapper.p.func]
+        keywords = keywords_updater(wrapper.p.keywords)
+
+        new_p = partial(_LearnableFakeQuantize, *wrapper.p.args, **keywords)
+        wrapper.p = new_p
+        return wrapper
+
+    activation = _update_to_learnable(qconfig.activation)
+    weight = _update_to_learnable(qconfig.weight)
+    return torch.quantization.QConfig(activation=activation, weight=weight)


 def get_world_size() -> int:
```
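
As a usage note (not part of the commit): the converted qconfig builds `_LearnableFakeQuantize` modules whose scale and zero point are trainable `nn.Parameter`s, which is what makes the QAT learnable. A minimal sketch, assuming a torch version whose default QAT qconfig uses `FusedMovingAvgObsFakeQuantize` as shown above and that the new helper is importable (the path below is hypothetical):

```python
# Sketch under stated assumptions; the import path of convert_to_learnable_qconfig
# is hypothetical.
import torch
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
from d2go.utils.qat_utils import convert_to_learnable_qconfig  # hypothetical path

learnable = convert_to_learnable_qconfig(
    torch.ao.quantization.get_default_qat_qconfig("qnnpack")
)

# Instantiating the activation entry yields a learnable fake-quant module whose
# scale / zero_point are nn.Parameters and thus receive gradients during QAT.
act_fq = learnable.activation()
assert isinstance(act_fq, _LearnableFakeQuantize)
assert isinstance(act_fq.scale, torch.nn.Parameter)
assert isinstance(act_fq.zero_point, torch.nn.Parameter)
```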