# train_quantization.py
import copy
import datetime
import os
import time

import torch
import torch.ao.nn.intrinsic.qat
import torch.ao.quantization
import torch.utils.data
import torchvision
import utils
from torch import nn
from train import evaluate, load_data, train_one_epoch


def main(args):
    """Run one of the three quantization workflows selected by ``args``.

    * ``args.post_training_quantize``: fuse the fp32 model, calibrate observers on a
      prefix of the training set, convert to a quantized model, optionally save its
      state dict to ``args.output_dir``, evaluate it and return.
    * ``args.test_only``: build the already-quantized model (``quantize=True``),
      evaluate it and return.
    * otherwise: quantization-aware training (QAT) from ``args.start_epoch`` to
      ``args.epochs``. After every epoch the fake-quantized model is evaluated, a
      deep copy is converted to a real quantized model on CPU and evaluated, and a
      checkpoint is written when ``args.output_dir`` is set.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``; it is also
            forwarded to ``utils.init_distributed_mode``, ``load_data`` and
            ``train_one_epoch``.

    Raises:
        RuntimeError: if post-training quantization is requested in distributed
            mode, or if ``args.qbackend`` is not in
            ``torch.backends.quantized.supported_engines``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.qbackend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.qbackend))
    torch.backends.quantized.engine = args.qbackend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model
    prefix = "quantized_"
    model_name = args.model
    if not model_name.startswith(prefix):
        model_name = prefix + model_name
    # Only --test-only asks for an already-converted (quantized) model.
    model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
    model.to(device)

    # QAT is the default mode; the optimizer and LR scheduler exist only in this mode.
    is_qat = not (args.test_only or args.post_training_quantize)
    if is_qat:
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(args.qbackend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )

        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if is_qat:
            # optimizer / lr_scheduler are only defined in QAT mode; touching them in
            # --test-only or --post-training-quantize mode used to raise NameError.
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # perform calibration on a subset of the training dataset
        # for that, create a subset of the training dataset. Clamp the size so that
        # requesting more calibration batches than the dataset holds does not make
        # Subset index past the end of the dataset.
        num_calibration_samples = min(len(dataset), args.batch_size * args.num_calibration_batches)
        ds = torch.utils.data.Subset(dataset, indices=list(range(num_calibration_samples)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model(is_qat=False)
        model.qconfig = torch.ao.quantization.get_default_qconfig(args.qbackend)
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibrate first: running data through the prepared model feeds the observers.
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.ao.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    model.apply(torch.ao.quantization.enable_observer)
    model.apply(torch.ao.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
        lr_scheduler.step()
        with torch.inference_mode():
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.ao.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                # torch.ao.nn.intrinsic is the non-deprecated home of the fused QAT modules.
                model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
            # Convert a CPU copy so the training model keeps its fake-quant modules.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.ao.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
            print("Saving models after epoch ", epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
158
    import argparse
159
160
161

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

162
163
    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
    parser.add_argument("--model", default="mobilenet_v2", type=str, help="model name")
Nicolas Hug's avatar
Nicolas Hug committed
164
    parser.add_argument("--qbackend", default="qnnpack", type=str, help="Quantized backend: fbgemm or qnnpack")
165
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
166

167
168
169
    parser.add_argument(
        "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for \
                              observer calibration ",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
212
213
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
214
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
215
216
217
218
219
220
221
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. \
             It also serializes the transforms",
        action="store_true",
    )
222
223
224
225
226
227
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
228
229
230
231
232
233
234
235
236
237
238
239
240
241
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
242
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
243
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
244

245
246
247
248
249
250
251
252
253
254
255
256
    parser.add_argument(
        "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
    )
    parser.add_argument(
        "--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
    )
    parser.add_argument(
        "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
    )
    parser.add_argument(
        "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
    )
257
    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
258
259
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

Nicolas Hug's avatar
Nicolas Hug committed
260
261
262
    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

263
    return parser
264
265
266


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    # --backend used to name the quantized engine; it now selects the transforms
    # backend (PIL/tensor) and the engine moved to --qbackend. Catch any quantized
    # engine name passed to the old flag; --backend is lower-cased by its argparse
    # type, so this comparison is case-insensitive.
    if args.backend in ("fbgemm", "qnnpack", "x86", "onednn"):
        raise ValueError(
            "The --backend parameter has been re-purposed to specify the backend of the transforms (PIL or Tensor) "
            "instead of the quantized backend. Please use the --qbackend parameter to specify the quantized backend."
        )
    main(args)