# Reference script for quantization-aware training (QAT) and post-training
# quantization (PTQ) of torchvision classification models.
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn

from train import train_one_epoch, evaluate, load_data


try:
    # Prototype weight-enum models only exist in torchvision nightlies.
    from torchvision.prototype import models as PM
except ImportError:
    # Sentinel so main() can raise a clear error when --weights is requested
    # but the prototype API is unavailable.
    PM = None
def main(args):
    """Run QAT training, post-training quantization, or evaluation-only.

    Behavior is selected by flags on ``args``:
      * ``--post-training-quantize``: calibrate on a training subset, convert
        to int8, save and evaluate, then return (single-process only).
      * ``--test-only``: evaluate a pre-quantized model and return.
      * otherwise: quantization-aware training with periodic int8 conversion
        and evaluation each epoch, checkpointing to ``args.output_dir``.

    Raises:
        ImportError: ``--weights`` given but prototype models are unavailable.
        RuntimeError: PTQ requested in distributed mode, or ``args.backend``
            is not a supported quantized engine.
    """
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model;
    # quantize=True only when evaluating an already-quantized checkpoint (--test-only).
    if not args.weights:
        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    else:
        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        # QAT preparation: fuse conv/bn/relu, attach fake-quant observers.
        model.fuse_model()
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        # NOTE(review): optimizer/lr_scheduler only exist in training mode, so
        # --resume combined with --test-only or --post-training-quantize would
        # raise NameError here — resume is expected only for QAT runs.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model()
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    # Make sure fake-quant and observers are active for QAT (they may have
    # been toggled off in a resumed checkpoint's module state).
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle per epoch so each process sees a different shard order.
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                # Freeze quantization ranges once observers have converged.
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                # Freeze BN running stats late in training to stabilize QAT.
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")
            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            # Convert a CPU copy to a real int8 model and evaluate it; the
            # deepcopy keeps the fake-quant training model intact.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            print("Saving models after epoch ", epoch)
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    """Build the argument parser for quantized classification training.

    Args:
        add_help (bool): forwarded to ``argparse.ArgumentParser``; set False
            when embedding this parser as a parent of another parser.

    Returns:
        argparse.ArgumentParser: parser with data, model, quantization
        schedule, optimizer, checkpointing, distributed, and transform options.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    parser.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
                              observer calibration ",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
             It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return parser


if __name__ == "__main__":
    # Parse CLI arguments and dispatch to the training/quantization driver.
    args = get_args_parser().parse_args()
    main(args)