"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "0abb41c70d1dde72b7566cc005b738867d5a44b4"
Commit ecf832da authored by Mircea Cimpoi, committed by Facebook GitHub Bot

Add BoltNN conversion to d2go exporter

Summary:
Added predictor_type `boltnn_int8` to export to BoltNN via torch delegate.

- `int8` needs to be in the predictor type name; otherwise, post-training quantization will not happen.

```
cfg.QUANTIZATION.BACKEND = "qnnpack"
# cfg.QUANTIZATION.CUSTOM_QSCHEME = "per_tensor_affine"
```

It seems that `QUANTIZATION.CUSTOM_QSCHEME = "per_tensor_affine"` is not needed; it is likely already covered by the "qnnpack" backend.
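
For illustration, a minimal sketch of how the new predictor type might be requested through `convert_and_export_predictor` (whose signature appears in the diff below); the `cfg`, `pytorch_model`, `output_dir`, and `data_loader` objects are assumed to come from the usual d2go setup and are not constructed here.

```python
# Hypothetical usage sketch; cfg, pytorch_model, output_dir, and data_loader
# are assumed to be produced by the normal d2go runner/setup flow.
cfg.QUANTIZATION.BACKEND = "qnnpack"
# cfg.QUANTIZATION.CUSTOM_QSCHEME = "per_tensor_affine"  # appears unnecessary with qnnpack

predictor_path = convert_and_export_predictor(
    cfg,
    pytorch_model,
    "boltnn_int8",  # "int8" in the type name triggers post-training quantization
    output_dir,
    data_loader,
)
```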

Reviewed By: wat3rBro

Differential Revision: D29106043

fbshipit-source-id: 865ac5af86919fe7b4530b48433a1bd11e295bf4
parent abf2f327
```diff
@@ -92,7 +92,11 @@ class PredictorExportConfig(NamedTuple):
 def convert_and_export_predictor(
-    cfg, pytorch_model, predictor_type, output_dir, data_loader
+    cfg,
+    pytorch_model,
+    predictor_type,
+    output_dir,
+    data_loader,
 ):
     """
     Entry point for convert and export model. This involves two steps:
@@ -101,6 +105,7 @@ def convert_and_export_predictor(
         - export: exporting the converted `pytorch_model` to predictor. This step
             should not alter the behaviour of model.
     """
     if "int8" in predictor_type:
         if not cfg.QUANTIZATION.QAT.ENABLED:
             logger.info(
@@ -111,14 +116,15 @@ def convert_and_export_predictor(
             # only check bn exists in ptq as qat still has bn inside fused ops
             assert not fuse_utils.check_bn_exist(pytorch_model)
         logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
         if cfg.QUANTIZATION.EAGER_MODE:
-            # TODO(future diff): move this logic to prepare_for_quant_convert
+            # TODO(T93870278): move this logic to prepare_for_quant_convert
             pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
         else:  # FX graph mode quantization
             if hasattr(pytorch_model, "prepare_for_quant_convert"):
                 pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
             else:
-                # TODO(future diff): move this to a default function
+                # TODO(T93870381): move this to a default function
                 pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
         logger.info("Quantized Model:\n{}".format(pytorch_model))
```
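
For reference, the `torch.quantization.convert` call in the eager-mode branch above is the final step of PyTorch's standard post-training quantization flow (prepare, calibrate, convert). A generic sketch of that flow, independent of d2go and with placeholder `model` and `calibration_loader` objects:

```python
import torch

# Generic eager-mode post-training quantization sketch (not d2go code);
# `model` and `calibration_loader` are placeholders.
model.eval()
model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
prepared = torch.quantization.prepare(model, inplace=False)  # insert observers
with torch.no_grad():
    for inputs in calibration_loader:  # run calibration data through the observers
        prepared(inputs)
quantized = torch.quantization.convert(prepared, inplace=False)  # same call as in the diff
```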
```diff
@@ -169,7 +169,14 @@ def tracing_adapter_wrap_load(old_f):
 @ModelExportMethodRegistry.register("torchscript_mobile_int8")
 class DefaultTorchscriptExport(ModelExportMethod):
     @classmethod
-    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
+    def export(
+        cls,
+        model: nn.Module,
+        input_args: Tuple[Tuple[torch.Tensor]],
+        save_path: str,
+        export_method: Optional[str],
+        **export_kwargs
+    ):
         if export_method is not None:
             # update export_kwargs based on export_method
             assert isinstance(export_method, str)
```
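
The `ModelExportMethodRegistry.register` decorator above is the extension point that the new `boltnn_int8` type plugs into. A purely hypothetical sketch of registering an additional int8 export method (the class name, save logic, and return value below are assumptions, not the actual BoltNN exporter from this commit):

```python
import os
import torch

# Hypothetical registration sketch; ModelExportMethod and ModelExportMethodRegistry
# are the classes shown in the diff above, everything else is assumed.
@ModelExportMethodRegistry.register("my_backend_int8")  # keep "int8" in the name
class MyBackendInt8Export(ModelExportMethod):
    @classmethod
    def export(cls, model, input_args, save_path, export_method, **export_kwargs):
        ts_model = torch.jit.trace(model, input_args)
        ts_model.save(os.path.join(save_path, "model.jit"))
        return {}  # assumed: info consumed by the matching load() implementation
```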
```diff
@@ -57,7 +57,11 @@ def main(
     pytorch_model = copy.deepcopy(model)
     try:
         predictor_path = convert_and_export_predictor(
-            cfg, pytorch_model, typ, output_dir, data_loader
+            cfg,
+            pytorch_model,
+            typ,
+            output_dir,
+            data_loader,
         )
         logger.info(f"Predictor type {typ} has been exported to {predictor_path}")
         predictor_paths[typ] = predictor_path
@@ -112,8 +116,10 @@ def get_parser():
     )
     return parser


 def cli():
     run_with_cmdline_args(get_parser().parse_args())


 if __name__ == "__main__":
     cli()
```
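
The `cli()` wrapper added above lets the exporter be wired up as a packaged console script in addition to being run with `python -m`. A hypothetical `setup.py` fragment (the script name and module path are assumptions):

```python
from setuptools import setup

setup(
    name="d2go",
    entry_points={
        "console_scripts": [
            # assumed module path for the exporter's cli() entry point
            "d2go.exporter = d2go.tools.exporter:cli",
        ]
    },
)
```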