# train.py: PyTorch segmentation training script
1
2
3
import datetime
import os
import time
4
import warnings
5

6
import presets
7
8
9
10
import torch
import torch.utils.data
import torchvision
import utils
11
12
from coco_utils import get_coco
from torch import nn
13
from torchvision.transforms import functional as F, InterpolationMode
14
15


def get_dataset(dir_path, name, image_set, transform):
    """Instantiate a segmentation dataset by name.

    Args:
        dir_path: root directory handed to the dataset constructor.
        name: one of ``"voc"``, ``"voc_aug"`` or ``"coco"``.
        image_set: split name forwarded as ``image_set`` (``main`` passes ``"train"`` / ``"val"``).
        transform: joint (image, target) transform forwarded as ``transforms``.

    Returns:
        A ``(dataset, num_classes)`` tuple.

    Raises:
        ValueError: if ``name`` is not one of the supported dataset names.
    """

    def sbd(*args, **kwargs):
        # Request SBDataset's segmentation-mask mode.
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    # name -> (dataset factory, number of classes); every factory receives dir_path as its root.
    datasets = {
        "voc": (torchvision.datasets.VOCSegmentation, 21),
        "voc_aug": (sbd, 21),
        "coco": (get_coco, 21),
    }
    if name not in datasets:
        raise ValueError(f"Unknown dataset '{name}'. Supported datasets: {sorted(datasets)}")
    ds_fn, num_classes = datasets[name]

    ds = ds_fn(dir_path, image_set=image_set, transforms=transform)
    return ds, num_classes


def get_transform(train, args):
    """Return the joint (image, target) transform for a data split.

    Args:
        train: ``True`` selects the training augmentation preset; ``False`` selects evaluation.
        args: parsed command-line arguments; ``args.weights`` and ``args.test_only`` choose the
            evaluation pipeline.

    Returns:
        A callable taking ``(img, target)`` and returning the transformed pair.
    """
    if train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
    elif args.weights and args.test_only:
        # Evaluating named weights: reuse the preprocessing attached to the weights enum.
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img, target):
            img = trans(img)
            # get_dimensions returns [channels, height, width]; keep only (height, width) so the
            # mask is resized to the transformed image's spatial size.
            size = F.get_dimensions(img)[1:]
            # Nearest-neighbour resizing keeps label ids intact (no blending between classes).
            # NOTE(review): pil_to_tensor below assumes target is a PIL image — confirm dataset output.
            target = F.resize(target, size, interpolation=InterpolationMode.NEAREST)
            return img, F.pil_to_tensor(target)

        return preprocessing
    else:
        return presets.SegmentationPresetEval(base_size=520)


def criterion(inputs, target, aux_weight=0.5):
    """Cross-entropy loss over a segmentation model's output dict.

    Args:
        inputs: mapping from head name to logits; ``"out"`` is required, ``"aux"`` is optional.
        target: ground-truth label map; pixels labelled 255 are ignored.
        aux_weight: weight of the auxiliary-head loss when an ``"aux"`` entry is present
            (default 0.5, the previously hard-coded value).

    Returns:
        Scalar loss tensor: ``loss_out`` or ``loss_out + aux_weight * loss_aux``.
    """
    losses = {name: nn.functional.cross_entropy(x, target, ignore_index=255) for name, x in inputs.items()}

    # Without an auxiliary head output (e.g. model built without --aux-loss) only "out" counts.
    if "aux" not in losses:
        return losses["out"]

    return losses["out"] + aux_weight * losses["aux"]


def evaluate(model, data_loader, device, num_classes):
    """Evaluate ``model`` on ``data_loader`` and return the accumulated confusion matrix.

    Predictions are the argmax over dim 1 of ``model(image)["out"]``. After the loop, the
    confusion matrix and the processed-sample count are combined through
    ``confmat.reduce_from_all_processes()`` and ``utils.reduce_across_processes``.

    Args:
        model: segmentation model whose output dict contains an ``"out"`` entry.
        data_loader: loader yielding ``(image, target)`` batches.
        device: device the batches are moved to.
        num_classes: number of classes of the confusion matrix.

    Returns:
        The ``utils.ConfusionMatrix`` accumulated over the loader.
    """
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]

        confmat.reduce_from_all_processes()

    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
    # Only rank 0 warns. torch.distributed.get_rank() raises when no process group has been
    # initialized, so treat single-process runs as the main process.
    is_main_process = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    return confmat


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one pass over ``data_loader``.

    ``lr_scheduler.step()`` is called after every iteration, not once per epoch.

    Args:
        model: model in training mode after ``model.train()``.
        criterion: callable ``criterion(output, target)`` returning a scalar loss.
        optimizer: optimizer updating the model parameters.
        data_loader: loader yielding ``(image, target)`` batches.
        lr_scheduler: per-iteration learning-rate scheduler.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval in iterations.
        scaler: optional gradient scaler; when given, the forward pass runs under CUDA autocast
            and the loss is scaled before ``backward``.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        # torch.autocast(device_type="cuda") is the non-deprecated spelling of
        # torch.cuda.amp.autocast; mixed precision is on exactly when a scaler is supplied.
        with torch.autocast(device_type="cuda", enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])


def _build_lr_scheduler(optimizer, iters_per_epoch, args):
    """Create the per-iteration LR schedule: optional warmup followed by polynomial (power 0.9) decay.

    Reads ``args.epochs``, ``args.lr_warmup_epochs``, ``args.lr_warmup_method`` and
    ``args.lr_warmup_decay``; ``args.lr_warmup_method`` is lower-cased in place.

    Raises:
        RuntimeError: if the warmup method is neither ``"linear"`` nor ``"constant"``.
    """
    # Iterations covered by the polynomial decay, clamped to >= 1: LambdaLR evaluates the lambda
    # at construction, so zero post-warmup epochs or an empty loader would otherwise divide by zero.
    decay_iters = max(1, iters_per_epoch * (args.epochs - args.lr_warmup_epochs))
    # Clamp the base at 0: stepping past decay_iters (e.g. resuming with a different number of
    # iterations per epoch) would otherwise raise a negative float to 0.9 and yield a complex lr.
    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: max(0.0, 1 - x / decay_iters) ** 0.9
    )

    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    warmup_iters = iters_per_epoch * args.lr_warmup_epochs
    args.lr_warmup_method = args.lr_warmup_method.lower()
    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )


def main(args):
    """Train a torchvision segmentation model, or only evaluate it with ``--test-only``.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    model = torchvision.models.segmentation.__dict__[args.model](
        weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss
    )
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        # The auxiliary head gets a 10x larger learning rate than the base lr.
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # The schedule is stepped once per training iteration (see train_one_epoch).
    lr_scheduler = _build_lr_scheduler(optimizer, len(data_loader), args)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        # Non-strict loading in test-only mode tolerates missing/unexpected keys.
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")


def get_args_parser(add_help=True):
    """Build the command-line argument parser for segmentation training and evaluation.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; pass ``False`` to reuse this parser
            as a parent of another parser.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
    parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )
    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
    parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load")

    # Mixed precision training parameters
    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

    return parser


if __name__ == "__main__":
    # Parse the command line and run training / evaluation.
    args = get_args_parser().parse_args()
    main(args)