# train_quantization.py
import copy
import datetime
import os
import time

import torch
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import train_one_epoch, evaluate, load_data


# torchvision's prototype namespace only ships in nightly builds; fall back to
# None so main() can raise a clear error if --prototype is requested without it.
try:
    from torchvision import prototype
except ImportError:
    prototype = None


def main(args):
    """Entry point: QAT training, post-training quantization, or evaluation.

    Depending on the parsed CLI flags this either
      * quantization-aware-trains a model (default path),
      * post-training-quantizes a pre-trained model (--post-training-quantize), or
      * only evaluates an already quantized model (--test-only).
    """
    if args.prototype and prototype is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    # PTQ calibration runs on a single process; combining it with DDP is rejected.
    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    # (quantize=True only when we merely evaluate an already-quantized checkpoint)
    if not args.prototype:
        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    else:
        model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
    model.to(device)

    if not (args.test_only or args.post_training_quantize):
        # QAT preparation: fuse modules, then attach fake-quant observers in place.
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        # NOTE(review): optimizer/lr_scheduler are only defined on the QAT path above,
        # so resuming together with --test-only or --post-training-quantize would hit
        # a NameError here — confirm the intended combination of flags.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset
        ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.backend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    # QAT training loop: observers and fake-quant are enabled from the start.
    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            # After a few epochs, stop updating quantization ranges / BN stats so
            # the weights converge against fixed quantization parameters.
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")
            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            # Convert a CPU deep-copy of the (unwrapped) model for quantized eval,
            # leaving the training model untouched.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)
            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
        print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    """Build and return the CLI parser for quantized classification training."""
    import argparse

    p = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    # Dataset, model, quantization backend and compute device.
    p.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    p.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
    p.add_argument("--backend", default="qnnpack", type=str, help="fbgemm or qnnpack")
    p.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    # Batch sizes and epoch schedule.
    p.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    p.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    p.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    p.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    p.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    p.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
                              observer calibration ",
    )

    # Data-loader workers and optimizer hyper-parameters.
    p.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    p.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    p.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    p.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    p.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    p.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    p.add_argument("--print-freq", default=10, type=int, help="print frequency")

    # Checkpointing.
    p.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    p.add_argument("--resume", default="", type=str, help="path of checkpoint")
    p.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    # Boolean switches.
    p.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        action="store_true",
        help="Cache the datasets for quicker initialization. \
             It also serializes the transforms",
    )
    p.add_argument("--sync-bn", dest="sync_bn", action="store_true", help="Use sync batch norm")
    p.add_argument("--test-only", dest="test_only", action="store_true", help="Only test the model")
    p.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        action="store_true",
        help="Post training quantize the model",
    )

    # distributed training parameters
    p.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    p.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Preprocessing sizes / interpolation and gradient clipping.
    p.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    p.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    p.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    p.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
    p.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

    # Prototype models only
    p.add_argument(
        "--prototype",
        dest="prototype",
        action="store_true",
        help="Use prototype model builders instead those from main area",
    )
    p.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    return p


if __name__ == "__main__":
    # Parse CLI flags and hand off to the training/quantization entry point.
    main(get_args_parser().parse_args())