# train.py (13.1 KB)
1
2
3
import datetime
import os
import time
4
import warnings
5

6
import presets
7
8
9
10
import torch
import torch.utils.data
import torchvision
import utils
11
12
from coco_utils import get_coco
from torch import nn
13
from torch.optim.lr_scheduler import PolynomialLR
14
from torchvision.transforms import functional as F, InterpolationMode
15
16


17
def get_dataset(args, is_train):
    """Build one split of the dataset selected by ``args.dataset``.

    Args:
        args: parsed command-line namespace; this function reads ``dataset``,
            ``data_path`` and ``use_v2`` (``get_transform`` reads more).
        is_train: ``True`` for the ``"train"`` split, ``False`` for ``"val"``.

    Returns:
        A ``(dataset, num_classes)`` pair; every dataset listed here uses 21 classes.

    Raises:
        KeyError: if ``args.dataset`` is not ``"voc"``, ``"voc_aug"`` or ``"coco"``.
    """

    # The VOC/SBD wrappers accept and discard ``use_v2`` so every constructor
    # below can be called with exactly the same keyword arguments.
    def sbd(*ds_args, **ds_kwargs):
        del ds_kwargs["use_v2"]
        return torchvision.datasets.SBDataset(*ds_args, mode="segmentation", **ds_kwargs)

    def voc(*ds_args, **ds_kwargs):
        del ds_kwargs["use_v2"]
        return torchvision.datasets.VOCSegmentation(*ds_args, **ds_kwargs)

    root = args.data_path
    # dataset name -> (root directory, constructor, number of classes)
    factories = {
        "voc": (root, voc, 21),
        "voc_aug": (root, sbd, 21),
        "coco": (root, get_coco, 21),
    }
    ds_root, build_dataset, num_classes = factories[args.dataset]

    split = "train" if is_train else "val"
    dataset = build_dataset(
        ds_root,
        image_set=split,
        transforms=get_transform(is_train, args),
        use_v2=args.use_v2,
    )
    return dataset, num_classes


38
39
40
def get_transform(is_train, args):
    """Return the joint ``(image, target)`` transform for a dataset split.

    Args:
        is_train: select the training preset when ``True``.
        args: parsed command-line namespace; reads ``backend``, ``use_v2``,
            ``weights`` and ``test_only``.

    Returns:
        The training preset for training. For evaluation with pretrained
        ``--weights`` under ``--test-only``, a function that applies the
        weights' own preprocessing. Otherwise the evaluation preset.
    """
    if is_train:
        return presets.SegmentationPresetTrain(base_size=520, crop_size=480, backend=args.backend, use_v2=args.use_v2)

    if not (args.weights and args.test_only):
        return presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)

    # Evaluate with the exact preprocessing the pretrained weights were published with.
    image_transform = torchvision.models.get_weight(args.weights).transforms()

    def preprocessing(img, target):
        img = image_transform(img)
        # get_dimensions returns [channels, height, width]; keep (height, width).
        height_width = F.get_dimensions(img)[1:]
        # Nearest-neighbour resizing so label ids are never blended together.
        resized_target = F.resize(target, height_width, interpolation=InterpolationMode.NEAREST)
        # NOTE(review): assumes the target is still a PIL image at this point.
        return img, F.pil_to_tensor(resized_target)

    return preprocessing
54
55
56
57
58
59
60
61


def criterion(inputs, target, aux_weight=0.5, ignore_index=255):
    """Segmentation loss: cross-entropy on the main head plus a weighted auxiliary term.

    Args:
        inputs: mapping from head name to logits. ``"out"`` is required;
            ``"aux"`` is optional and added with weight ``aux_weight``.
            Any other heads are ignored.
        target: class-index map matching the logits' spatial size.
        aux_weight: multiplier for the auxiliary-head loss (default 0.5).
        ignore_index: label value excluded from the loss (default 255).

    Returns:
        A scalar loss tensor.

    Raises:
        KeyError: if ``inputs`` has no ``"out"`` entry.
    """
    # Only compute the losses that are actually used; extra heads are ignored
    # rather than causing a KeyError or wasted computation.
    loss = nn.functional.cross_entropy(inputs["out"], target, ignore_index=ignore_index)
    if "aux" in inputs:
        aux_loss = nn.functional.cross_entropy(inputs["aux"], target, ignore_index=ignore_index)
        loss = loss + aux_weight * aux_loss
    return loss
65
66
67
68
69
70


def evaluate(model, data_loader, device, num_classes):
    """Run ``model`` over ``data_loader`` and accumulate a confusion matrix.

    Args:
        model: segmentation model whose output is indexed with ``"out"``.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to before the forward pass.
        num_classes: number of classes for ``utils.ConfusionMatrix``.

    Returns:
        The ``utils.ConfusionMatrix`` after ``reduce_from_all_processes()``.

    Warns:
        UserWarning: on the main process, if the number of processed samples
            differs from ``len(data_loader.dataset)``.
    """
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            output = output["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            num_processed_samples += image.shape[0]

        confmat.reduce_from_all_processes()

    # presumably sums the per-process counts across ranks — confirm in utils
    num_processed_samples = utils.reduce_across_processes(num_processed_samples)

    # torch.distributed.get_rank() raises when no default process group has been
    # initialised (plain single-process runs), so only query it when one exists.
    is_main_process = (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or torch.distributed.get_rank() == 0
    )
    if (
        hasattr(data_loader.dataset, "__len__")
        and len(data_loader.dataset) != num_processed_samples
        and is_main_process
    ):
        # See FIXME above
        warnings.warn(
            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
            "samples were used for the validation, which might bias the results. "
            "Try adjusting the batch size and / or the world size. "
            "Setting the world size to 1 is always a safe bet."
        )

    return confmat


103
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: model to train; switched to train mode here.
        criterion: callable ``(output, target) -> loss``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable of ``(image, target)`` batches.
        lr_scheduler: scheduler stepped once per batch (not per epoch).
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval passed to ``MetricLogger.log_every``.
        scaler: a GradScaler for mixed precision, or ``None`` to train in full precision.
    """
    use_amp = scaler is not None

    model.train()
    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))

    batches = logger.log_every(data_loader, print_freq, f"Epoch: [{epoch}]")
    for batch_images, batch_targets in batches:
        batch_images = batch_images.to(device)
        batch_targets = batch_targets.to(device)

        # Autocast is only enabled when a scaler was supplied.
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = criterion(model(batch_images), batch_targets)

        optimizer.zero_grad()
        if not use_amp:
            loss.backward()
            optimizer.step()
        else:
            # Mixed-precision path: backward and optimizer step go through the scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Per-iteration schedule: main() sizes the schedulers in iterations.
        lr_scheduler.step()

        step_loss = loss.item()
        logger.update(loss=step_loss, lr=optimizer.param_groups[0]["lr"])


def _build_param_groups(model_without_ddp, args):
    """Return SGD parameter groups: backbone, classifier and (with --aux-loss) the aux head at 10x lr."""
    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    return params_to_optimize


def _build_lr_scheduler(optimizer, iters_per_epoch, args):
    """Return the per-iteration LR schedule: optional warmup followed by polynomial decay.

    Raises:
        RuntimeError: if ``args.lr_warmup_method`` is neither "linear" nor "constant"
            (only checked when ``args.lr_warmup_epochs > 0``).
    """
    # The polynomial decay covers only the iterations left after warmup.
    main_lr_scheduler = PolynomialLR(
        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
    )
    if args.lr_warmup_epochs <= 0:
        return main_lr_scheduler

    warmup_iters = iters_per_epoch * args.lr_warmup_epochs
    args.lr_warmup_method = args.lr_warmup_method.lower()
    if args.lr_warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    elif args.lr_warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
        )
    else:
        raise RuntimeError(
            f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
        )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )


def main(args):
    """Train a segmentation model, or only evaluate it with ``--test-only``.

    Each training epoch is followed by an evaluation. A checkpoint (model,
    optimizer, LR scheduler, epoch, args and, with AMP, scaler state) is then
    written to ``args.output_dir`` as ``model_{epoch}.pth`` and
    ``checkpoint.pth`` through ``utils.save_on_master``.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.

    Raises:
        ValueError: if a non-PIL backend is requested without ``--use-v2``, or
            ``--use-v2`` is used with a dataset other than coco.
    """
    if args.backend.lower() != "pil" and not args.use_v2:
        # TODO: Support tensor backend in V1?
        raise ValueError("Use --use-v2 if you want to use the datapoint or tensor backend.")
    if args.use_v2 and args.dataset != "coco":
        raise ValueError("v2 is only supported for coco dataset for now.")

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset, num_classes = get_dataset(args, is_train=True)
    dataset_test, _ = get_dataset(args, is_train=False)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    model = torchvision.models.get_model(
        args.model,
        weights=args.weights,
        weights_backbone=args.weights_backbone,
        num_classes=num_classes,
        aux_loss=args.aux_loss,
    )
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle to the unwrapped module for parameter groups and (de)serialisation.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    params_to_optimize = _build_param_groups(model_without_ddp, args)
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    iters_per_epoch = len(data_loader)
    lr_scheduler = _build_lr_scheduler(optimizer, iters_per_epoch, args)

    if args.resume:
        # NOTE(review): torch.load unpickles the file (the checkpoint stores the
        # argparse Namespace, so weights_only loading is not used) — only resume
        # from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location="cpu")
        # Non-strict loading in test-only mode tolerates missing/extra keys.
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
        if not args.test_only:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
            if args.amp:
                scaler.load_state_dict(checkpoint["scaler"])

    if args.test_only:
        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle differently on every epoch across ranks.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        if args.amp:
            checkpoint["scaler"] = scaler.state_dict()
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
265
266


267
def get_args_parser(add_help=True):
    """Build the command-line parser for segmentation training and evaluation.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``; pass ``False`` when
            the parser is used as a parent of another parser.

    Returns:
        An ``argparse.ArgumentParser`` defining every option ``main`` reads.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    # (option strings, add_argument keyword arguments), in --help order.
    option_specs = [
        (("--data-path",), dict(default="/datasets01/COCO/022719/", type=str, help="dataset path")),
        (("--dataset",), dict(default="coco", type=str, help="dataset name")),
        (("--model",), dict(default="fcn_resnet101", type=str, help="model name")),
        (("--aux-loss",), dict(action="store_true", help="auxiliary loss")),
        (("--device",), dict(default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")),
        (
            ("-b", "--batch-size"),
            dict(default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"),
        ),
        (("--epochs",), dict(default=30, type=int, metavar="N", help="number of total epochs to run")),
        (
            ("-j", "--workers"),
            dict(default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"),
        ),
        (("--lr",), dict(default=0.01, type=float, help="initial learning rate")),
        (("--momentum",), dict(default=0.9, type=float, metavar="M", help="momentum")),
        (
            ("--wd", "--weight-decay"),
            dict(
                default=1e-4,
                type=float,
                metavar="W",
                help="weight decay (default: 1e-4)",
                dest="weight_decay",
            ),
        ),
        (("--lr-warmup-epochs",), dict(default=0, type=int, help="the number of epochs to warmup (default: 0)")),
        (("--lr-warmup-method",), dict(default="linear", type=str, help="the warmup method (default: linear)")),
        (("--lr-warmup-decay",), dict(default=0.01, type=float, help="the decay for lr")),
        (("--print-freq",), dict(default=10, type=int, help="print frequency")),
        (("--output-dir",), dict(default=".", type=str, help="path to save outputs")),
        (("--resume",), dict(default="", type=str, help="path of checkpoint")),
        (("--start-epoch",), dict(default=0, type=int, metavar="N", help="start epoch")),
        (("--test-only",), dict(dest="test_only", help="Only test the model", action="store_true")),
        (
            ("--use-deterministic-algorithms",),
            dict(action="store_true", help="Forces the use of deterministic algorithms only."),
        ),
        # distributed training parameters
        (("--world-size",), dict(default=1, type=int, help="number of distributed processes")),
        (("--dist-url",), dict(default="env://", type=str, help="url used to set up distributed training")),
        (("--weights",), dict(default=None, type=str, help="the weights enum name to load")),
        (("--weights-backbone",), dict(default=None, type=str, help="the backbone weights enum name to load")),
        # Mixed precision training parameters
        (("--amp",), dict(action="store_true", help="Use torch.cuda.amp for mixed precision training")),
        # str.lower as the type also lower-cases the string default "PIL" to "pil".
        (("--backend",), dict(default="PIL", type=str.lower, help="PIL or tensor - case insensitive")),
        (("--use-v2",), dict(action="store_true", help="Use V2 transforms")),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser
325
326
327


if __name__ == "__main__":
328
    args = get_args_parser().parse_args()
329
    main(args)