train.py 10.5 KB
Newer Older
flauted's avatar
flauted committed
1
2
3
4
5
6
7
r"""PyTorch Detection Training.

To run in a multi-gpu environment, use the distributed launcher::

    python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
        train.py ... --world-size $NGPU

The default hyperparameters are tuned for training on 8 GPUs with 2 images per GPU:
    --lr 0.02 --batch-size 2 --world-size 8
If you use a different number of GPUs, scale the learning rate linearly to 0.02/8*$NGPU.

On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
    --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3

Also, if you train Keypoint R-CNN, the default hyperparameters are
    --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
Because the person keypoint subset of COCO contains fewer images, the number of
epochs is increased so that training runs for the same number of iterations.
"""
20
21
22
23
import datetime
import os
import time

24
import presets
25
26
27
28
29
import torch
import torch.utils.data
import torchvision
import torchvision.models.detection
import torchvision.models.detection.mask_rcnn
30
import utils
31
32
from coco_utils import get_coco, get_coco_kp
from engine import train_one_epoch, evaluate
33
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
34
35


36
37
38
39
40
41
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None


flauted's avatar
flauted committed
42
def get_dataset(name, image_set, transform, data_path):
    """Build one split of a detection dataset and report its class count.

    Args:
        name: dataset identifier, either ``"coco"`` or ``"coco_kp"``.
        image_set: split name forwarded to the dataset builder (this script
            passes ``"train"`` and ``"val"``).
        transform: transform callable forwarded to the builder as ``transforms=``.
        data_path: root directory of the dataset.

    Returns:
        A ``(dataset, num_classes)`` tuple; ``num_classes`` is the value later
        passed to the model constructor.

    Raises:
        KeyError: if ``name`` is not one of the supported datasets.
    """
    # name -> (root path, builder function, num_classes for the model head)
    paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
    if name not in paths:
        # Same exception type as the bare dict lookup, but with an actionable message.
        raise KeyError(f"Unknown dataset '{name}'. Supported datasets: {', '.join(sorted(paths))}")
    p, ds_fn, num_classes = paths[name]

    ds = ds_fn(p, image_set=image_set, transforms=transform)
    return ds, num_classes


50
51
52
53
54
55
def get_transform(train, args):
    """Return the preprocessing transform for the training or evaluation split.

    Args:
        train: ``True`` for the training split, ``False`` for evaluation.
        args: parsed command-line namespace; reads ``data_augmentation`` when
            ``train`` is true and ``weights`` otherwise.

    Returns:
        ``presets.DetectionPresetTrain(args.data_augmentation)`` for training.
        For evaluation, ``presets.DetectionPresetEval()`` when no ``--weights``
        was given, otherwise the transforms bundled with the named prototype
        weights.

    Raises:
        ImportError: if ``--weights`` is set for evaluation but the torchvision
            prototype module could not be imported.
    """
    if train:
        return presets.DetectionPresetTrain(args.data_augmentation)
    if not args.weights:
        return presets.DetectionPresetEval()
    # Prototype weights carry their own eval preprocessing. Check PM here as well
    # so this helper fails with a clear message (rather than AttributeError on
    # None) even when called without main()'s up-front check.
    if PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    weights = PM.get_weight(args.weights)
    return weights.transforms()
58
59


60
61
def get_args_parser(add_help=True):
    """Build the command-line parser for detection training.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; pass ``False`` when
            this parser is used as a parent of another parser.

    Returns:
        An ``argparse.ArgumentParser``. Its defaults (``--lr 0.02``,
        ``--batch-size 2``) correspond to training on 8 GPUs with 2 images per GPU.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
    )
    parser.add_argument(
        "--lr",
        default=0.02,
        type=float,
        help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
    )
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
    )
    # NOTE(review): this script's main() never reads args.lr_step_size (MultiStepLR
    # uses --lr-steps); the flag is kept so existing command lines keep parsing.
    parser.add_argument(
        "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
    )
    parser.add_argument(
        "--lr-steps",
        default=[16, 22],
        nargs="+",
        type=int,
        help="epochs at which to decrease lr (multisteplr scheduler only)",
    )
    parser.add_argument(
        "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
    )
    parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    # "--start_epoch" is the historical spelling; "--start-epoch" matches every other flag.
    parser.add_argument("--start-epoch", "--start_epoch", dest="start_epoch", default=0, type=int, help="start epoch")
    parser.add_argument(
        "--aspect-ratio-group-factor",
        default=3,
        type=int,
        help="aspect ratio grouping factor for training batches; a negative value disables grouping",
    )
    parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
    parser.add_argument(
        "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
    )
    parser.add_argument(
        "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )

    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    # Prototype models only
    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser


152
def main(args):
    """Train a torchvision detection model, evaluating on the val split after every epoch.

    With ``--test-only`` the (optionally resumed) model is evaluated once and the
    function returns without training. When ``--output-dir`` is set, a checkpoint
    is written after each epoch both as ``model_{epoch}.pth`` and as
    ``checkpoint.pth`` (overwritten every epoch, so it holds the latest state).

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``. This
            function reads ``args.distributed`` and ``args.gpu``, which the parser
            does not define; presumably ``utils.init_distributed_mode`` sets them
            — confirm against ``utils``.

    Raises:
        ImportError: if ``--weights`` is given but the torchvision prototype
            module is unavailable.
        RuntimeError: if ``--lr-scheduler`` names an unsupported scheduler.
    """
    if args.weights and PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Data loading code
    print("Loading data")
    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    if args.aspect_ratio_group_factor >= 0:
        # Bucket training images by aspect ratio (see group_by_aspect_ratio) before
        # batching; a negative factor falls back to plain batching.
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    print("Creating model")
    kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
    if "rcnn" in args.model:
        # Only R-CNN model builders receive rpn_score_thresh.
        if args.rpn_score_thresh is not None:
            kwargs["rpn_score_thresh"] = args.rpn_score_thresh
    if not args.weights:
        model = torchvision.models.detection.__dict__[args.model](
            pretrained=args.pretrained, num_classes=num_classes, **kwargs
        )
    else:
        model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped model so checkpoints save/load without the DDP prefix.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )

    if args.resume:
        # torch.load unpickles the file: only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        # Checkpoints written without --amp carry no "scaler" entry; keep the
        # freshly constructed GradScaler in that case instead of raising KeyError.
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Give DistributedSampler a new seed each epoch so the shuffle changes.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
        lr_scheduler.step()
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
265
266
267


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    args = get_args_parser().parse_args()
    main(args)