Commit f905958c authored by Shangdi Yu's avatar Shangdi Yu Committed by Facebook GitHub Bot
Browse files

Migrate in d2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/678

capture_pre_autograd_graph is deprecating. Migrate to use the new API.

Reviewed By: navsud, tugsbayasgalan

Differential Revision: D63859679

fbshipit-source-id: f14def6bc622cc451020d0edcc312330fa626943
parent 5b856252
...@@ -22,7 +22,6 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_ ...@@ -22,7 +22,6 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
from mobile_cv.arch.utils import fuse_utils from mobile_cv.arch.utils import fuse_utils
from mobile_cv.common.misc.iter_utils import recursive_iterate from mobile_cv.common.misc.iter_utils import recursive_iterate
from torch import nn from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import ( from torch.ao.quantization.quantize_pt2e import (
convert_pt2e, convert_pt2e,
prepare_pt2e, prepare_pt2e,
...@@ -32,6 +31,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import ( ...@@ -32,6 +31,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config, get_symmetric_quantization_config,
XNNPACKQuantizer, XNNPACKQuantizer,
) )
from torch.export import export_for_training
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2]) TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
# some tests still import prepare/convert from below. So don't remove these. # some tests still import prepare/convert from below. So don't remove these.
...@@ -39,8 +39,7 @@ if TORCH_VERSION > (1, 10): ...@@ -39,8 +39,7 @@ if TORCH_VERSION > (1, 10):
from torch.ao.quantization.quantize import convert from torch.ao.quantization.quantize import convert
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else: else:
from torch.quantization.quantize import convert pass
from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -368,7 +367,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None): ...@@ -368,7 +367,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
) )
else: else:
logger.info("Using default pt2e quantization APIs with XNNPACKQuantizer") logger.info("Using default pt2e quantization APIs with XNNPACKQuantizer")
captured_model = capture_pre_autograd_graph(model, example_input) captured_model = export_for_training(model, example_input).module()
quantizer = _get_symmetric_xnnpack_quantizer() quantizer = _get_symmetric_xnnpack_quantizer()
if is_qat: if is_qat:
model = prepare_qat_pt2e(captured_model, quantizer) model = prepare_qat_pt2e(captured_model, quantizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment