# train_quantization.py
import copy
2
3
4
5
6
import datetime
import os
import time

import torch
7
import torch.quantization
8
9
10
import torch.utils.data
import torchvision
import utils
11
from torch import nn
12
13
14
15
16
17
18
19
20
21
22
from train import train_one_epoch, evaluate, load_data


def main(args):
    """Run quantization-aware training (QAT), post-training quantization (PTQ) or evaluation.

    The mode is selected from ``args``:

    * ``--post-training-quantize``: fuse the fp32 model, calibrate observers on a prefix of
      the training set, convert to a quantized model, optionally save its state dict,
      evaluate it and return.
    * ``--test-only``: build the model with ``quantize=True`` and evaluate it once.
    * otherwise (QAT): fuse + ``prepare_qat`` the model and train for
      ``args.start_epoch .. args.epochs - 1``. After every epoch it evaluates the
      fake-quantized model and a converted CPU copy, then checkpoints to ``args.output_dir``.

    Args:
        args: namespace produced by ``get_args_parser``. ``args.distributed`` and ``args.gpu``
            are read after ``utils.init_distributed_mode(args)``; presumably that helper
            sets them (they are not defined by the parser) -- confirm in ``utils``.

    Raises:
        RuntimeError: if PTQ is requested in distributed mode, or if ``args.backend`` is not
            in ``torch.backends.quantized.supported_engines``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    if args.post_training_quantize and args.distributed:
        raise RuntimeError("Post training quantization example should not be performed on distributed mode")

    # Set backend engine to ensure that quantized model runs on the correct kernels
    if args.backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported: " + str(args.backend))
    torch.backends.quantized.engine = args.backend

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")

    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
    )

    print("Creating model", args.model)
    # when training quantized models, we always start from a pre-trained fp32 reference model;
    # with --test-only the quantized variant is requested directly (quantize=True).
    model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
    model.to(device)

    # Only the QAT path trains, so only it creates an optimizer and an LR scheduler.
    train_qat = not (args.test_only or args.post_training_quantize)
    optimizer = None
    lr_scheduler = None
    if train_qat:
        # Fuse modules, then insert observers / fake-quant modules for QAT.
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
        torch.quantization.prepare_qat(model, inplace=True)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        optimizer = torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        # torch.load unpickles the file: only resume from checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        # Optimizer/scheduler state and the epoch counter only matter when training continues.
        # In PTQ / --test-only mode these objects were never created, so loading them used to
        # raise NameError.
        if train_qat:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1

    if args.post_training_quantize:
        # Perform calibration on a prefix subset of the training dataset. Clamp the subset size
        # so a dataset smaller than batch_size * num_calibration_batches is not indexed past
        # its end.
        num_calibration_samples = min(len(dataset), args.batch_size * args.num_calibration_batches)
        ds = torch.utils.data.Subset(dataset, indices=list(range(num_calibration_samples)))
        data_loader_calibration = torch.utils.data.DataLoader(
            ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True
        )
        model.eval()
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qconfig(args.backend)
        torch.quantization.prepare(model, inplace=True)
        # Calibrate first: running inference lets the inserted observers record value ranges.
        print("Calibrating")
        evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1)
        torch.quantization.convert(model, inplace=True)
        if args.output_dir:
            print("Saving quantized model")
            if utils.is_main_process():
                torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth"))
        # NOTE(review): the converted model targets quantized CPU kernels; evaluating it on a
        # non-CPU --device presumably fails -- confirm and run PTQ with --device cpu.
        print("Evaluating post-training quantized model")
        evaluate(model, criterion, data_loader_test, device=device)
        return

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    # QAT: start with observers collecting statistics and fake quantization active.
    model.apply(torch.quantization.enable_observer)
    model.apply(torch.quantization.enable_fake_quant)
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        print("Starting training for epoch", epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
        lr_scheduler.step()
        with torch.no_grad():
            # Once the configured number of epochs has passed, stop updating the quantization
            # observers and freeze the batch-norm statistics for the remaining epochs.
            if epoch >= args.num_observer_update_epochs:
                print("Disabling observer for subseq epochs, epoch = ", epoch)
                model.apply(torch.quantization.disable_observer)
            if epoch >= args.num_batch_norm_update_epochs:
                print("Freezing BN for subseq epochs, epoch = ", epoch)
                model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            print("Evaluate QAT model")

            evaluate(model, criterion, data_loader_test, device=device)
            # Convert a deep CPU copy so the training model keeps its fake-quant modules.
            quantized_eval_model = copy.deepcopy(model_without_ddp)
            quantized_eval_model.eval()
            quantized_eval_model.to(torch.device("cpu"))
            torch.quantization.convert(quantized_eval_model, inplace=True)

            print("Evaluate Quantized model")
            evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu"))

        model.train()

        if args.output_dir:
            # Only announce a save when one actually happens.
            print("Saving models after epoch ", epoch)
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "eval_model": quantized_eval_model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch)))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def get_args_parser(add_help=True):
    """Build the command-line parser for the quantized classification training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; pass ``False`` to omit the
            automatic ``-h/--help`` option (e.g. when reusing this parser as a parent).

    Returns:
        argparse.ArgumentParser: the configured parser; ``main`` consumes its namespace.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset")
    parser.add_argument("--model", default="mobilenet_v2", help="model")
    parser.add_argument("--backend", default="qnnpack", help="fbgemm or qnnpack")
    parser.add_argument("--device", default="cuda", help="device")

    parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size for calibration/training")
    parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation")
    parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "--num-observer-update-epochs",
        default=4,
        type=int,
        metavar="N",
        help="number of total epochs to update observers",
    )
    parser.add_argument(
        "--num-batch-norm-update-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to update batch norm stats",
    )
    # A backslash continuation inside the string literal used to embed ~30 spaces in the help.
    parser.add_argument(
        "--num-calibration-batches",
        default=32,
        type=int,
        metavar="N",
        help="number of batches of training set for observer calibration",
    )

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--post-training-quantize",
        dest="post_training_quantize",
        help="Post training quantize the model",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the command line and run training / quantization / evaluation.
    args = get_args_parser().parse_args()
    main(args)