Commit e93de5d2 authored by Dark Knight and committed by Facebook GitHub Bot

Revert D38436675: Multisect successfully blamed D38436675 for test or build failures

Summary:
This diff reverts D38436675 (https://github.com/facebookresearch/d2go/commit/d98f74aa467a6576c9905accbf4a2b6279599f9c), which has been identified as causing the following test or build failures:
Tests affected:
- https://www.internalfb.com/intern/test/844425001950025/
- https://www.internalfb.com/intern/test/844425001950027/

Here's the Multisect link:
https://www.internalfb.com/intern/testinfra/multisect/1259258
Here are the tasks that are relevant to this breakage:
T120995919: 51 tests started failing for oncall d2go in the last 2 weeks
We're generating a revert to back out the changes in this diff. Please note that the backout may land if someone accepts it.

Reviewed By: wat3rBro

Differential Revision: D39594147

fbshipit-source-id: 56c489bb9feea2d60a2a5f0e89941ed7c0f3f675
parent 1af59d41
@@ -73,7 +73,7 @@ def convert_quantized_model(
     logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
     # convert the fake-quantized model to int8 model
-    pytorch_model = convert_to_quantized_model(cfg, pytorch_model, data_loader)
+    pytorch_model = convert_to_quantized_model(cfg, pytorch_model)
     logger.info(f"Quantized Model:\n{pytorch_model}")
     return pytorch_model

-#!/usr/bin/env python3
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-
-from typing import Tuple
-
-import torch
-from mobile_cv.common.misc.oss_utils import fb_overwritable
-
-TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
-if TORCH_VERSION > (1, 10):
-    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-else:
-    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-
-
-@fb_overwritable()
-def get_prepare_fx_fn(cfg, is_qat):
-    return prepare_qat_fx if is_qat else prepare_fx
-
-
-@fb_overwritable()
-def get_convert_fx_fn(cfg, example_inputs):
-    return convert_fx
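
For context, fb_overwritable from mobile_cv marks a function whose implementation can be swapped out in Facebook-internal builds. A minimal sketch of the idea, assuming the OSS build simply uses the decorated function unchanged (this is an illustration, not the actual mobile_cv.common.misc.oss_utils implementation):

# Sketch only: in an OSS build the decorator acts as an identity wrapper,
# so the decorated function is used as-is; internal builds may substitute
# an FB-specific implementation instead.
def fb_overwritable():
    def deco(fn):
        return fn
    return deco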
@@ -17,7 +17,6 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_stat
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate
-from .fx import get_convert_fx_fn, get_prepare_fx_fn
 from .qconfig import set_backend_and_create_qconfig, smart_decode_backend

 TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
@@ -290,12 +289,12 @@ def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
             " their own MetaArch."
         )
-    prepare_fn = get_prepare_fx_fn(cfg, is_qat)
-    model = prepare_fn(
-        model,
-        qconfig_mapping=qconfig_dict,
-        example_inputs=(example_input,),
-    )
+    if is_qat:
+        model = prepare_qat_fx(model, qconfig_dict, (example_input,))
+    else:
+        model = prepare_fx(model, qconfig_dict, (example_input,))
     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
     return model
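
The block removed above used the keyword-style FX prepare API (qconfig_mapping= and example_inputs=) from newer PyTorch releases, while the restored code passes a qconfig dict positionally, as older releases expect. A hedged sketch of the newer call shape on a toy model (exact signatures vary across torch versions):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(8, 8)).eval()  # toy model
example_input = torch.randn(1, 8)

# Newer API (roughly torch >= 1.13): a QConfigMapping plus example_inputs.
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example_input,))

# Older releases instead took a qconfig dict positionally, e.g.:
# prepared = prepare_fx(model, qconfig_dict, (example_input,))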
@@ -344,7 +343,7 @@ def prepare_fake_quant_model(cfg, model, is_qat, example_input=None):
     return model


-def convert_to_quantized_model(cfg, fp32_model, data_loader):
+def convert_to_quantized_model(cfg, fp32_model):
     """
     Centralized function to convert fake quant model (fp32 operators) to "real"
     quantized model (int8 operators).
@@ -353,12 +352,10 @@ def convert_to_quantized_model(cfg, fp32_model, data_loader):
         int8_model = convert(fp32_model, inplace=False)
     else:
         # FX graph mode quantization
-        example_input = next(iter(data_loader))
         if hasattr(fp32_model, "custom_convert_fx"):
-            int8_model = fp32_model.custom_convert_fx(cfg, example_input)
+            int8_model = fp32_model.custom_convert_fx(cfg)
         else:
-            convert_fn = get_convert_fx_fn(cfg, (example_input,))
-            int8_model = convert_fn(fp32_model)
+            int8_model = convert_fx(fp32_model)
     return int8_model
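
End to end, the FX graph mode path is prepare, calibrate, then convert_fx. A minimal sketch on a toy model (illustrative only, not d2go's API; assumes a torch version where these signatures hold):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_input = torch.randn(1, 8)

qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example_input,))
prepared(example_input)            # calibration pass with representative data
int8_model = convert_fx(prepared)  # fake-quant modules become int8 kernels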

@@ -9,7 +9,6 @@ if TORCH_VERSION > (1, 10):
     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
 else:
     from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-from mobile_cv.common.misc.oss_utils import fb_overwritable

 QCONFIG_CREATOR_REGISTRY = Registry("QCONFIG_CREATOR_REGISTRY")
@@ -83,7 +82,6 @@ def validate_native_backend(backend):
     )


-@fb_overwritable()
 def _smart_parse_extended_backend(extended_backend):
     """
     D2Go extends the definition of quantization "backend". In addition to PyTorch's