Commit 74bc35ea authored by Yanghan Wang, committed by Facebook GitHub Bot

split default_prepare_for_quant into eager and FX mode

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/331

- remove `example_input` from `default_prepare_for_quant` since it is now eager-mode only.
- rename `apply_prepare_for_quant` to `convert_to_fake_quant_model` to better reflect what it does.
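
For illustration, a minimal sketch of how a call site picks up the renamed entry point (the helper name, import path, and the `cfg`/`model`/`data_loader` arguments are assumptions for this example; only `convert_to_fake_quant_model` comes from this diff):

    # Hypothetical helper showing the new call; previously this was
    # `apply_prepare_for_quant(cfg, model, example_input)`.
    from d2go.quantization.modeling import convert_to_fake_quant_model

    def build_fake_quant_model(cfg, model, data_loader):
        # d2go assumes a single example_input; it is only consumed in FX graph mode
        example_input = next(iter(data_loader))
        return convert_to_fake_quant_model(cfg, model, example_input)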

Reviewed By: jerryzh168

Differential Revision: D37794085

fbshipit-source-id: e6f12098976272d979a3aac66287d9b95432dcc8
parent 6d152388
@@ -213,7 +213,7 @@ def mock_quantization_type(quant_func):
     return wrapper
 
 
-def default_prepare_for_quant(cfg, model, example_input=None):
+def default_prepare_for_quant(cfg, model):
     """
     Default implementation of preparing a model for quantization. This function will
@@ -234,36 +234,46 @@ def default_prepare_for_quant(cfg, model, example_input=None):
     Args:
         model (nn.Module): a non-quantized model.
         cfg (CfgNode): config
-        example_input (Optional[Any]): optional example_input for model,
-            if it is not provided we'll use `model.example_input` when example_input
-            is required, Note: d2go assumes we always have a single example_input
     Return:
         nn.Module: a ready model for QAT training or PTQ calibration
     """
+    assert cfg.QUANTIZATION.EAGER_MODE
     qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
+    model = fuse_utils.fuse_model(
+        model,
+        is_qat=cfg.QUANTIZATION.QAT.ENABLED,
+        inplace=True,
+    )
+    model.qconfig = qconfig
+    # TODO(future diff): move the torch.ao.quantization.prepare(...) call
+    # here, to be consistent with the FX branch
-    if cfg.QUANTIZATION.EAGER_MODE:
-        model = fuse_utils.fuse_model(
-            model,
-            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
-            inplace=True,
-        )
+    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
+    return model
-        model.qconfig = qconfig
-        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
-        # here, to be consistent with the FX branch
-    else:  # FX graph mode quantization
-        qconfig_dict = {"": qconfig}
-        if example_input is None:
-            example_input = model.example_input
-        if model.training:
-            model = prepare_qat_fx(model, qconfig_dict, (example_input,))
-        else:
-            model = prepare_fx(model, qconfig_dict, (example_input,))
-    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
+
+
+def default_custom_prepare_fx(cfg, model, example_input=None):
+    """
+    Similar to default_prepare_for_quant, but for FX graph mode.
+    Args:
+        example_input (Optional[Any]): optional example_input for model,
+            if it is not provided we'll use `model.example_input` when example_input
+            is required, Note: d2go assumes we always have a single example_input
+    """
+    assert not cfg.QUANTIZATION.EAGER_MODE
+    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
+    qconfig_dict = {"": qconfig}
+    if example_input is None:
+        example_input = model.example_input
+    if model.training:
+        model = prepare_qat_fx(model, qconfig_dict, (example_input,))
+    else:
+        model = prepare_fx(model, qconfig_dict, (example_input,))
+    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
     return model
@@ -271,7 +281,10 @@ def default_custom_convert_fx(cfg, model):
     return convert_fx(model)
 
 
-def apply_prepare_for_quant(cfg, model, example_input=None):
+def convert_to_fake_quant_model(cfg, model, example_input=None):
+    """
+    Centralized function to convert fp32 model (D2Go's MetaArch) to fake quant model.
+    """
     # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
     # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
     # `set_backend_and_create_qconfig` API.
@@ -281,7 +294,7 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
             model = model.prepare_for_quant(cfg)
         else:
             logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model, example_input)
+            model = default_prepare_for_quant(cfg, model)
         # NOTE: eager model needs to call prepare after `prepare_for_quant`
         if model.training:
             torch.ao.quantization.prepare_qat(model, inplace=True)
@@ -295,15 +308,16 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
         elif hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg, example_input)
         else:
-            logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model, example_input)
+            logger.info("Using default implementation for custom_prepare_fx")
+            model = default_custom_prepare_fx(cfg, model, example_input)
 
     return model
 
 
 def convert_to_quantized_model(cfg, fp32_model):
     """
-    Convert fake quant model (fp32 operators) to "real" quantized model (int8 operators)
+    Centralized function to convert fake quant model (fp32 operators) to "real"
+    quantized model (int8 operators).
     """
     if cfg.QUANTIZATION.EAGER_MODE:
         int8_model = convert(fp32_model, inplace=False)
@@ -326,7 +340,7 @@ def post_training_quantize(cfg, model, data_loader):
         param.requires_grad = False
 
     example_input = next(iter(data_loader))
-    model = apply_prepare_for_quant(cfg, model, example_input)
+    model = convert_to_fake_quant_model(cfg, model, example_input)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
 
     # Option for forcing running calibration on GPU, works only when the model supports
@@ -387,7 +401,7 @@ def setup_qat_model(
     model_fp32_state_dict = model_fp32.state_dict()
 
     # prepare model for qat
-    model = apply_prepare_for_quant(cfg, model_fp32)
+    model = convert_to_fake_quant_model(cfg, model_fp32)
 
     # make sure the proper qconfig are used in the model
     learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)