"tools/python/vscode:/vscode.git/clone" did not exist on "14d34ec99733c59482872aff910d6000844caba6"
Commit ecf832da authored by Mircea Cimpoi, committed by Facebook GitHub Bot

Add BoltNN conversion to d2go exporter

Summary:
Added predictor_type `boltnn_int8` to export to BoltNN via torch delegate.

- `int8` needs to be in the predictor type name, otherwise post-train quantization won't happen;

```
cfg.QUANTIZATION.BACKEND = "qnnpack"
# cfg.QUANTIZATION.CUSTOM_QSCHEME = "per_tensor_affine"
```

Setting `QUANTIZATION.CUSTOM_QSCHEME` to `per_tensor_affine` does not seem to be needed; it is likely already covered by the `qnnpack` backend.
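
For illustration, a minimal sketch of driving the new export path from Python. The function signature and the config key come from the diff below; the import path `d2go.export.api` and the surrounding variables (`cfg`, `pytorch_model`, `output_dir`, `data_loader`) are assumptions:

```python
from d2go.export.api import convert_and_export_predictor  # assumed import path

cfg.QUANTIZATION.BACKEND = "qnnpack"  # per the summary, no custom qscheme needed

predictor_path = convert_and_export_predictor(
    cfg,
    pytorch_model,   # trained model, assumed to be in scope
    "boltnn_int8",   # "int8" in the name enables post-train quantization
    output_dir,
    data_loader,     # used during conversion (e.g. calibration)
)
```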

Reviewed By: wat3rBro

Differential Revision: D29106043

fbshipit-source-id: 865ac5af86919fe7b4530b48433a1bd11e295bf4
parent abf2f327
```diff
@@ -92,7 +92,11 @@ class PredictorExportConfig(NamedTuple):
 def convert_and_export_predictor(
-    cfg, pytorch_model, predictor_type, output_dir, data_loader
+    cfg,
+    pytorch_model,
+    predictor_type,
+    output_dir,
+    data_loader,
 ):
     """
     Entry point for convert and export model. This involves two steps:
@@ -101,6 +105,7 @@ def convert_and_export_predictor(
     - export: exporting the converted `pytorch_model` to predictor. This step
       should not alter the behaviour of model.
     """
     if "int8" in predictor_type:
         if not cfg.QUANTIZATION.QAT.ENABLED:
             logger.info(
@@ -111,14 +116,15 @@ def convert_and_export_predictor(
         # only check bn exists in ptq as qat still has bn inside fused ops
         assert not fuse_utils.check_bn_exist(pytorch_model)
         logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
         if cfg.QUANTIZATION.EAGER_MODE:
-            # TODO(future diff): move this logic to prepare_for_quant_convert
+            # TODO(T93870278): move this logic to prepare_for_quant_convert
             pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
         else:  # FX graph mode quantization
             if hasattr(pytorch_model, "prepare_for_quant_convert"):
                 pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
             else:
-                # TODO(future diff): move this to a default function
+                # TODO(T93870381): move this to a default function
                 pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
         logger.info("Quantized Model:\n{}".format(pytorch_model))
```
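
As an aside on the hunk above: the `"int8" in predictor_type` substring check is the gate the summary refers to, which is why the new type must carry `int8` in its name. A standalone restatement (the helper name here is hypothetical):

```python
def is_int8_predictor(predictor_type: str) -> bool:
    # Hypothetical helper restating the gate in convert_and_export_predictor:
    # any predictor type containing "int8" goes through quantized conversion.
    return "int8" in predictor_type

assert is_int8_predictor("boltnn_int8")      # new type: post-train quantization runs
assert not is_int8_predictor("torchscript")  # non-int8 types skip conversion
```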
```diff
@@ -169,7 +169,14 @@ def tracing_adapter_wrap_load(old_f):
 @ModelExportMethodRegistry.register("torchscript_mobile_int8")
 class DefaultTorchscriptExport(ModelExportMethod):
     @classmethod
-    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
+    def export(
+        cls,
+        model: nn.Module,
+        input_args: Tuple[Tuple[torch.Tensor]],
+        save_path: str,
+        export_method: Optional[str],
+        **export_kwargs
+    ):
         if export_method is not None:
             # update export_kwargs based on export_method
             assert isinstance(export_method, str)
```
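
The `ModelExportMethodRegistry.register` decorator above is presumably also how the new predictor type hooks in; a hypothetical sketch (class name and body are illustrative only, not the implementation from this commit):

```python
@ModelExportMethodRegistry.register("boltnn_int8")  # hypothetical registration
class BoltNNExport(ModelExportMethod):
    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        # Per the summary, the real implementation exports to BoltNN via a
        # torch delegate; the body is intentionally left as a placeholder.
        raise NotImplementedError("illustrative sketch only")
```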
```diff
@@ -57,7 +57,11 @@ def main(
     pytorch_model = copy.deepcopy(model)
     try:
         predictor_path = convert_and_export_predictor(
-            cfg, pytorch_model, typ, output_dir, data_loader
+            cfg,
+            pytorch_model,
+            typ,
+            output_dir,
+            data_loader,
         )
         logger.info(f"Predictor type {typ} has been exported to {predictor_path}")
         predictor_paths[typ] = predictor_path
@@ -112,8 +116,10 @@ def get_parser():
     )
     return parser
 
+
 def cli():
     run_with_cmdline_args(get_parser().parse_args())
 
+
 if __name__ == "__main__":
     cli()
```