Unverified Commit 9a34c0c9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support to Quantized ResNet (#4827)

* Adding multi-weight support to Quantized ResNet.

* Update references script to support testing quantized models with the new API.

* Handle quantized models correctly in ref script.

* Fixing references for quantization.
parent 6a60b9bc
......@@ -151,9 +151,9 @@ torchrun --nproc_per_node=8 train.py\
## Quantized
### Parameters used for generating quantized models:
### Post training quantized models
For all post training quantized models (All quantized models except mobilenet-v2), the settings are:
For all post training quantized models, the settings are:
1. num_calibration_batches: 32
2. num_workers: 16
......@@ -162,8 +162,11 @@ For all post training quantized models (All quantized models except mobilenet-v2
5. backend: 'fbgemm'
```
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
```
Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d` and `shufflenet_v2_x1_0`.
### QAT MobileNetV2
For Mobilenet-v2, the model was trained with quantization-aware training; the settings used are:
1. num_workers: 16
......@@ -185,6 +188,8 @@ torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v2'
Training converges at about 10 epochs.
### QAT MobileNetV3
For Mobilenet-v3 Large, the model was trained with quantization-aware training; the settings used are:
1. num_workers: 16
2. batch_size: 32
......
......@@ -153,7 +153,7 @@ def load_data(traindir, valdir, args):
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
fn = PM.__dict__[args.model]
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms()
......
......@@ -12,6 +12,12 @@ from torch import nn
from train import train_one_epoch, evaluate, load_data
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)
......@@ -46,7 +52,12 @@ def main(args):
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
if not args.weights:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device)
if not (args.test_only or args.post_training_quantize):
......@@ -251,6 +262,9 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
......
......@@ -30,10 +30,15 @@ def get_models_with_module_names(module):
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
def test_get_weight():
    """Check that `models._api.get_weight` resolves a weight name for `models.resnet50`.

    The string "ImageNet1K_RefV2" must resolve to the
    `models.ResNet50Weights.ImageNet1K_RefV2` entry.
    """
    expected = models.ResNet50Weights.ImageNet1K_RefV2
    assert models._api.get_weight(models.resnet50, "ImageNet1K_RefV2") == expected
@pytest.mark.parametrize(
    "model_fn, weight",
    [
        (models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
        (models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
    ],
)
def test_get_weight(model_fn, weight):
    """Check that `models._api.get_weight` maps a weight's name back to the same weight entry.

    Each parametrized case pairs a model builder with one of its weight
    entries; the lookup is done by `weight.name` and must compare equal to
    `weight` itself. One case uses a regular builder and one uses a builder
    from `models.quantization`.
    """
    assert models._api.get_weight(model_fn, weight.name) == weight
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
......@@ -43,6 +48,12 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_quantized_classification_model(model_fn):
    """Run `TM.test_quantized_classification_model` on each builder found in `models.quantization`.

    The test is skipped when the PYTORCH_TEST_WITH_PROTOTYPE environment
    variable is unset or equal to "0".
    """
    TM.test_quantized_classification_model(model_fn)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
......@@ -60,6 +71,7 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
)
......@@ -70,6 +82,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
"models": {
"input_shape": (1, 3, 224, 224),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
},
......
......@@ -12,10 +12,18 @@ from ....models.quantization.resnet import (
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..resnet import ResNet50Weights
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
__all__ = [
"QuantizableResNet",
"QuantizedResNet18Weights",
"QuantizedResNet50Weights",
"QuantizedResNeXt101_32x8dWeights",
"resnet18",
"resnet50",
"resnext101_32x8d",
]
def _resnet(
......@@ -47,22 +55,67 @@ _common_meta = {
"size": (224, 224),
"categories": _IMAGENET_CATEGORIES,
"backend": "fbgemm",
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}
class QuantizedResNet18Weights(Weights):
    """Weight entries for the quantized ResNet-18 builder in this module."""

    # Single entry; its metadata is the shared `_common_meta` plus the two
    # accuracy figures below.
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
        # Preprocessing preset with the crop size fixed to 224.
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_common_meta,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
            "acc@1": 69.494,
            "acc@5": 88.882,
        },
    )
class QuantizedResNet50Weights(Weights):
    """Weight entries for the quantized ResNet-50 builder in this module."""

    # Single entry; its metadata starts from the shared `_common_meta`.
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
        # Preprocessing preset with the crop size fixed to 224.
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_common_meta,
            # This key comes after the `_common_meta` spread, so it replaces
            # any "recipe" value that `_common_meta` supplies.
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
            # NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
            "acc@1": 75.920,
            "acc@5": 92.814,
        },
    )
class QuantizedResNeXt101_32x8dWeights(Weights):
    """Weight entries for the quantized ResNeXt-101 32x8d builder in this module."""

    # Single entry; its metadata is the shared `_common_meta` plus the two
    # accuracy figures below.
    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
        # Preprocessing preset with the crop size fixed to 224.
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            **_common_meta,
            # NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
            "acc@1": 78.986,
            "acc@5": 94.480,
        },
    )
def resnet18(
    weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a quantizable ResNet-18 through `_resnet`.

    Args:
        weights: Weight entry to load, or ``None``. It is checked with
            ``QuantizedResNet18Weights.verify`` when ``quantize`` is true and
            with ``ResNet18Weights.verify`` otherwise.
        progress: Forwarded unchanged to ``_resnet``.
        quantize: Selects which weights enum is used, and is forwarded to
            ``_resnet``.
        **kwargs: Forwarded to ``_resnet``. The deprecated ``pretrained`` key
            is consumed here: a truthy value selects the default entry
            (``ImageNet1K_FBGEMM_RefV1`` when quantizing, else
            ``ImageNet1K_RefV1``), a falsy value resets ``weights`` to
            ``None``. Either way it replaces the ``weights`` argument.

    Returns:
        The model produced by ``_resnet`` for ``QuantizableBasicBlock`` with
        layer counts ``[2, 2, 2, 2]``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        legacy_pretrained = kwargs.pop("pretrained")
        if not legacy_pretrained:
            weights = None
        elif quantize:
            weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1
        else:
            weights = ResNet18Weights.ImageNet1K_RefV1

    # The enum that validates the weights must match the requested mode.
    weights_enum = QuantizedResNet18Weights if quantize else ResNet18Weights
    weights = weights_enum.verify(weights)

    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
def resnet50(
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
progress: bool = True,
......@@ -82,3 +135,30 @@ def resnet50(
weights = ResNet50Weights.verify(weights)
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
def resnext101_32x8d(
    weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    """Build a quantizable ResNeXt-101 32x8d through `_resnet`.

    Args:
        weights: Weight entry to load, or ``None``. It is checked with
            ``QuantizedResNeXt101_32x8dWeights.verify`` when ``quantize`` is
            true and with ``ResNeXt101_32x8dWeights.verify`` otherwise.
        progress: Forwarded unchanged to ``_resnet``.
        quantize: Selects which weights enum is used, and is forwarded to
            ``_resnet``.
        **kwargs: Forwarded to ``_resnet``. The deprecated ``pretrained`` key
            is consumed here: a truthy value selects the default entry
            (``ImageNet1K_FBGEMM_RefV1`` when quantizing, else
            ``ImageNet1K_RefV1``), a falsy value resets ``weights`` to
            ``None``. ``groups`` and ``width_per_group`` are always set to 32
            and 8, replacing any caller-supplied values.

    Returns:
        The model produced by ``_resnet`` for ``QuantizableBottleneck`` with
        layer counts ``[3, 4, 23, 3]``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        legacy_pretrained = kwargs.pop("pretrained")
        if not legacy_pretrained:
            weights = None
        elif quantize:
            weights = QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
        else:
            weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1

    # The enum that validates the weights must match the requested mode.
    weights_enum = QuantizedResNeXt101_32x8dWeights if quantize else ResNeXt101_32x8dWeights
    weights = weights_enum.verify(weights)

    # Architecture-defining settings are fixed after validation.
    kwargs.update(groups=32, width_per_group=8)
    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment