Unverified Commit 9a34c0c9 authored by Vasilis Vryniotis, committed by GitHub

Adding multiweight support to Quantized ResNet (#4827)

* Adding multi-weight support to Quantized ResNet.

* Update references script to support testing quantized models with the new API.

* Handle quantized models correctly in ref script.

* Fixing references for quantization.
parent 6a60b9bc
@@ -151,9 +151,9 @@ torchrun --nproc_per_node=8 train.py\
 ## Quantized
-### Parameters used for generating quantized models:
+### Post training quantized models
-For all post training quantized models (All quantized models except mobilenet-v2), the settings are:
+For all post training quantized models, the settings are:
 1. num_calibration_batches: 32
 2. num_workers: 16
@@ -162,8 +162,11 @@ For all post training quantized models (All quantized models except mobilenet-v2
 5. backend: 'fbgemm'
 ```
-python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
 ```
+Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d` and `shufflenet_v2_x1_0`.
+### QAT MobileNetV2
 For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:
 1. num_workers: 16
@@ -185,6 +188,8 @@ torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v2'
 Training converges at about 10 epochs.
+### QAT MobileNetV3
 For Mobilenet-v3 Large, the model was trained with quantization aware training, the settings used are:
 1. num_workers: 16
 2. batch_size: 32
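The `recipe` links added to the weight metadata later in this diff point back at these README sections. For orientation, a minimal sketch of how weights produced by these recipes are consumed through the new multi-weight API (assumes a torchvision nightly where the prototype module exists; names are taken from this diff):

```python
from torchvision.prototype import models as PM

# Pick a quantized weight entry and build the matching INT8 model.
weights = PM.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1
model = PM.quantization.resnet50(weights=weights, quantize=True)
model.eval()

# Each weight entry bundles its own evaluation preprocessing.
preprocess = weights.transforms()
```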
...
@@ -153,7 +153,7 @@ def load_data(traindir, valdir, args):
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )
     else:
-        fn = PM.__dict__[args.model]
+        fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
         weights = PM._api.get_weight(fn, args.weights)
         preprocessing = weights.transforms()
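The ternary keys off `hasattr(args, "backend")`: of the reference scripts that share `load_data`, only the quantization one defines a `backend` argument, so its presence selects the quantized namespace. A sketch of the dispatch with a hypothetical stand-in for the parsed arguments:

```python
import types

from torchvision.prototype import models as PM

# Hypothetical stand-in for train_quantization.py's parsed args;
# the plain classification script has no `backend` attribute.
args = types.SimpleNamespace(model="resnet50", backend="fbgemm", weights="ImageNet1K_FBGEMM_RefV1")

fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)  # resolve the entry from its name
preprocessing = weights.transforms()
```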
...
@@ -12,6 +12,12 @@ from torch import nn
 from train import train_one_epoch, evaluate, load_data

+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None

 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -46,7 +52,12 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    if not args.weights:
+        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)

     if not (args.test_only or args.post_training_quantize):
@@ -251,6 +262,9 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )

+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

     return parser
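Note that `args.weights` is passed through as a plain string, so the builders must accept the enum name as well as the enum value. Roughly what the `--test-only --weights ...` path amounts to outside the script (a sketch; the string form relies on the builders resolving names, as the call above implies):

```python
from torchvision.prototype import models as PM

# Example values; --weights takes the weight enum name as a string.
model_name = "resnet50"
weight_name = "ImageNet1K_FBGEMM_RefV1"

model = PM.quantization.__dict__[model_name](weights=weight_name, quantize=True)
model.eval()  # with --test-only the already-quantized checkpoint is evaluated
```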
...
@@ -30,10 +30,15 @@ def get_models_with_module_names(module):
     return [(fn, module_name) for fn in TM.get_models_from_module(module)]

-def test_get_weight():
-    fn = models.resnet50
-    weight_name = "ImageNet1K_RefV2"
-    assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
+@pytest.mark.parametrize(
+    "model_fn, weight",
+    [
+        (models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
+        (models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
+    ],
+)
+def test_get_weight(model_fn, weight):
+    assert models._api.get_weight(model_fn, weight.name) == weight

 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@@ -43,6 +48,12 @@ def test_classification_model(model_fn, dev):
     TM.test_classification_model(model_fn, dev)

+@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
+@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
+def test_quantized_classification_model(model_fn):
+    TM.test_quantized_classification_model(model_fn)

 @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
@@ -60,6 +71,7 @@ def test_video_model(model_fn, dev):
 @pytest.mark.parametrize(
     "model_fn, module_name",
     get_models_with_module_names(models)
+    + get_models_with_module_names(models.quantization)
     + get_models_with_module_names(models.segmentation)
     + get_models_with_module_names(models.video),
 )
@@ -70,6 +82,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
     "models": {
         "input_shape": (1, 3, 224, 224),
     },
+    "quantization": {
+        "input_shape": (1, 3, 224, 224),
+    },
     "segmentation": {
         "input_shape": (1, 3, 520, 520),
     },
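These tests only run when `PYTORCH_TEST_WITH_PROTOTYPE=1` is set. Spelled out, the new `test_get_weight` parametrization asserts that a weight entry round-trips through its name; the same check as a standalone sketch:

```python
from torchvision.prototype import models

# get_weight resolves the enum attached to a builder from the entry's name.
w = models._api.get_weight(models.quantization.resnet50, "ImageNet1K_FBGEMM_RefV1")
assert w == models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1
```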
...
@@ -12,10 +12,18 @@ from ....models.quantization.resnet import (
 from ...transforms.presets import ImageNetEval
 from .._api import Weights, WeightEntry
 from .._meta import _IMAGENET_CATEGORIES
-from ..resnet import ResNet50Weights
+from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights

-__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
+__all__ = [
+    "QuantizableResNet",
+    "QuantizedResNet18Weights",
+    "QuantizedResNet50Weights",
+    "QuantizedResNeXt101_32x8dWeights",
+    "resnet18",
+    "resnet50",
+    "resnext101_32x8d",
+]

 def _resnet(
@@ -47,22 +55,67 @@ _common_meta = {
     "size": (224, 224),
     "categories": _IMAGENET_CATEGORIES,
     "backend": "fbgemm",
+    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
 }

+class QuantizedResNet18Weights(Weights):
+    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "acc@1": 69.494,
+            "acc@5": 88.882,
+        },
+    )

 class QuantizedResNet50Weights(Weights):
     ImageNet1K_FBGEMM_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
         transforms=partial(ImageNetEval, crop_size=224),
         meta={
             **_common_meta,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
             "acc@1": 75.920,
             "acc@5": 92.814,
         },
     )

+class QuantizedResNeXt101_32x8dWeights(Weights):
+    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "acc@1": 78.986,
+            "acc@5": 94.480,
+        },
+    )

+def resnet18(
+    weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
+        else:
+            weights = None
+    if quantize:
+        weights = QuantizedResNet18Weights.verify(weights)
+    else:
+        weights = ResNet18Weights.verify(weights)
+
+    return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)

 def resnet50(
     weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
     progress: bool = True,
@@ -82,3 +135,30 @@ def resnet50(
         weights = ResNet50Weights.verify(weights)

     return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)

+def resnext101_32x8d(
+    weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
+                if quantize
+                else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
+            )
+        else:
+            weights = None
+    if quantize:
+        weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
+    else:
+        weights = ResNeXt101_32x8dWeights.verify(weights)
+
+    kwargs["groups"] = 32
+    kwargs["width_per_group"] = 8
+
+    return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)
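Two behaviors of the builders above are worth spelling out: metadata travels with each `WeightEntry`, and the deprecated `pretrained=True` flag resolves (with a warning) to the default FBGEMM entry when `quantize=True`. A sketch, assuming the prototype module from this PR is importable:

```python
from torchvision.prototype import models as PM

# Accuracy and backend can be read off the entry without building the model.
entry = PM.quantization.QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
print(entry.meta["acc@1"], entry.meta["backend"])  # 78.986 fbgemm

# Equivalent constructions: the legacy flag warns, then picks the same entry.
m1 = PM.quantization.resnext101_32x8d(pretrained=True, quantize=True)
m2 = PM.quantization.resnext101_32x8d(weights=entry, quantize=True)
```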