"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4581f147a6347ae790d26c9c8fbd3536e573f112"
Commit bd1beec9 authored by Jiaxu Zhu's avatar Jiaxu Zhu Committed by Facebook GitHub Bot
Browse files

Support QAT in D2Go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/408

Update the DPE training script in D2Go (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb) to support Turing DPE QAT.

Reviewed By: newstzpz

Differential Revision: D40612406

fbshipit-source-id: 9379e4be248045b995293c5a522bab05e0b13c6e
parent 588ea957
...@@ -35,7 +35,7 @@ def _has_module(model, module_type): ...@@ -35,7 +35,7 @@ def _has_module(model, module_type):
def check_for_learnable_fake_quant_ops(qat_method, model): def check_for_learnable_fake_quant_ops(qat_method, model):
"""Make sure learnable observers are used if qat method is `learnable`""" """Make sure learnable observers are used if qat method is `learnable`"""
if qat_method == "learnable": if qat_method.startswith("learnable"):
if not _has_module(model, _LearnableFakeQuantize): if not _has_module(model, _LearnableFakeQuantize):
raise Exception( raise Exception(
"No learnable fake quant is used for learnable quantzation, please use d2go.quantization.learnable_qat.get_learnable_qat_qconfig() to get proper qconfig" "No learnable fake quant is used for learnable quantzation, please use d2go.quantization.learnable_qat.get_learnable_qat_qconfig() to get proper qconfig"
...@@ -187,7 +187,7 @@ def setup_qat_get_optimizer_param_groups(model, qat_method): ...@@ -187,7 +187,7 @@ def setup_qat_get_optimizer_param_groups(model, qat_method):
"""Add a function `get_optimizer_param_groups` to the model so that it could """Add a function `get_optimizer_param_groups` to the model so that it could
return proper weight decay for learnable qat return proper weight decay for learnable qat
""" """
if qat_method != "learnable": if not qat_method.startswith("learnable"):
return model return model
assert _is_q_state_dict(model.state_dict()) assert _is_q_state_dict(model.state_dict())
......
...@@ -437,7 +437,11 @@ def setup_qat_model( ...@@ -437,7 +437,11 @@ def setup_qat_model(
enable_observer: bool = False, enable_observer: bool = False,
enable_learnable_observer: bool = False, enable_learnable_observer: bool = False,
): ):
assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in ["default", "learnable"] assert cfg.QUANTIZATION.QAT.FAKE_QUANT_METHOD in [
"default",
"learnable",
"learnable_act",
]
if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"): if hasattr(model_fp32, "_non_qat_to_qat_state_dict_map"):
raise RuntimeError("The model is already setup to be QAT, cannot setup again!") raise RuntimeError("The model is already setup to be QAT, cannot setup again!")
...@@ -467,7 +471,7 @@ def setup_qat_model( ...@@ -467,7 +471,7 @@ def setup_qat_model(
logger.info("Disabling static observer ...") logger.info("Disabling static observer ...")
model.apply(torch.ao.quantization.disable_observer) model.apply(torch.ao.quantization.disable_observer)
model.apply(learnable_qat.disable_lqat_static_observer) model.apply(learnable_qat.disable_lqat_static_observer)
if not enable_learnable_observer and qat_method == "learnable": if not enable_learnable_observer and qat_method.startswith("learnable"):
logger.info("Disabling learnable observer ...") logger.info("Disabling learnable observer ...")
model.apply(learnable_qat.disable_lqat_learnable_observer) model.apply(learnable_qat.disable_lqat_learnable_observer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment