"vscode:/vscode.git/clone" did not exist on "b233e3eeceb6a10a08359c565685432f5a390009"
train.py 9.77 KB
Newer Older
1
2
3
4
import datetime
import os
import time

5
import presets
6
7
8
9
import torch
import torch.utils.data
import torchvision
import utils
10
11
from coco_utils import get_coco
from torch import nn
12
13


14
def get_dataset(dir_path, name, image_set, transform):
    """Build a segmentation dataset by short name.

    Args:
        dir_path: Root directory handed to the dataset constructor.
        name: One of ``"voc"``, ``"voc_aug"`` or ``"coco"``.
        image_set: Split name forwarded as ``image_set`` (this script uses
            ``"train"`` and ``"val"``).
        transform: Joint image/target transform forwarded as ``transforms``.

    Returns:
        A ``(dataset, num_classes)`` tuple; every supported dataset reports 21 classes.

    Raises:
        KeyError: if ``name`` is not a supported dataset name.
    """

    def sbd(*args, **kwargs):
        # SBDataset takes a ``mode`` argument; this script needs segmentation masks.
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    registry = {
        "voc": (torchvision.datasets.VOCSegmentation, 21),
        "voc_aug": (sbd, 21),
        "coco": (get_coco, 21),
    }
    builder, num_classes = registry[name]
    dataset = builder(dir_path, image_set=image_set, transforms=transform)
    return dataset, num_classes


def get_transform(train, *, base_size=520, crop_size=480):
    """Return the preset transform for training or evaluation.

    Args:
        train: If truthy, build the training preset; otherwise the evaluation preset.
        base_size: Size passed to both presets (default 520, the previously
            hard-coded value).
        crop_size: Crop size passed to the training preset only (default 480,
            the previously hard-coded value).

    Returns:
        ``presets.SegmentationPresetTrain(base_size, crop_size)`` when ``train`` is
        truthy, otherwise ``presets.SegmentationPresetEval(base_size)``.
    """
    if train:
        return presets.SegmentationPresetTrain(base_size, crop_size)
    return presets.SegmentationPresetEval(base_size)
34
35
36
37
38
39
40
41


def criterion(inputs, target, *, aux_weight=0.5, ignore_index=255):
    """Compute the segmentation loss from a model's output dict.

    Args:
        inputs: Mapping of output name to logits. Must contain ``"out"``; may
            contain ``"aux"`` (auxiliary head). Any other entries are ignored.
        target: Class-index targets passed to ``cross_entropy``.
        aux_weight: Multiplier applied to the auxiliary loss (default 0.5).
        ignore_index: Target value excluded from the loss (default 255).

    Returns:
        ``CE(out)`` when there is no ``"aux"`` output, otherwise
        ``CE(out) + aux_weight * CE(aux)``.

    Raises:
        KeyError: if ``inputs`` has no ``"out"`` entry.
    """
    loss = nn.functional.cross_entropy(inputs["out"], target, ignore_index=ignore_index)
    # Only the main and auxiliary heads contribute. Other outputs are skipped
    # instead of being scored needlessly or causing a KeyError on "aux".
    if "aux" in inputs:
        aux_loss = nn.functional.cross_entropy(inputs["aux"], target, ignore_index=ignore_index)
        loss = loss + aux_weight * aux_loss
    return loss
45
46
47
48
49
50


def evaluate(model, data_loader, device, num_classes):
    """Score ``model`` on ``data_loader`` and return a confusion matrix.

    The model is put in eval mode and run under ``torch.inference_mode()``.
    Only the ``"out"`` entry of the model's output dict is scored, with
    predictions taken as the argmax over dim 1. The matrix is reduced via
    ``reduce_from_all_processes`` before being returned.

    Args:
        model: Segmentation model whose forward returns a dict with ``"out"``.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device each batch is moved to.
        num_classes: Size passed to ``utils.ConfusionMatrix``.

    Returns:
        The populated ``utils.ConfusionMatrix``.
    """
    model.eval()
    confusion = utils.ConfusionMatrix(num_classes)
    logger = utils.MetricLogger(delimiter="  ")
    with torch.inference_mode():
        for images, targets in logger.log_every(data_loader, 100, "Test:"):
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)["out"]
            predictions = logits.argmax(1)
            confusion.update(targets.flatten(), predictions.flatten())
        confusion.reduce_from_all_processes()
    return confusion


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
    """Run one training pass of ``model`` over ``data_loader``.

    ``lr_scheduler.step()`` is called after every optimizer step, so the
    schedule advances per iteration rather than per epoch. The loss and the
    learning rate of the first param group are recorded on a
    ``utils.MetricLogger``. ``print_freq`` is forwarded to its ``log_every``.

    Args:
        model: Model to train; switched to train mode first.
        criterion: Callable ``(model_output, target) -> loss tensor``.
        optimizer: Optimizer stepping ``model``'s parameters.
        data_loader: Iterable of ``(image, target)`` batches.
        lr_scheduler: Scheduler stepped once per iteration.
        device: Device each batch is moved to.
        epoch: Epoch index, used only in the log header.
        print_freq: Logging interval passed to ``log_every``.
    """
    model.train()
    logger = utils.MetricLogger(delimiter="  ")
    logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    for images, targets in logger.log_every(data_loader, print_freq, f"Epoch: [{epoch}]"):
        images = images.to(device)
        targets = targets.to(device)
        loss = criterion(model(images), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])


def main(args):
    """Train (or only evaluate) a torchvision segmentation model.

    Steps:
    1. Create ``args.output_dir`` if it is set, call
       ``utils.init_distributed_mode(args)``, and print the arguments.
    2. Build the train/val datasets and loaders.
    3. Build the model, converted to SyncBatchNorm and wrapped in
       DistributedDataParallel when ``args.distributed`` is set.
    4. Optionally load weights from ``args.resume``.
    5. With ``--test-only``, evaluate once and return.
    6. Otherwise build SGD and a per-iteration polynomial (power 0.9) LR
       schedule, with optional linear/constant warmup.
    7. Restore optimizer/scheduler state when resuming, then train from
       ``args.start_epoch`` to ``args.epochs``. After every epoch, evaluate and
       write ``model_{epoch}.pth`` and ``checkpoint.pth`` to ``args.output_dir``.

    Args:
        args: Namespace produced by ``get_args_parser().parse_args()``.

    Raises:
        ValueError: when training and the post-warmup schedule would span no
            iterations (``epochs <= lr_warmup_epochs`` or an empty training
            loader). The poly schedule would otherwise divide by zero, or, for
            a negative span, make the learning rate grow.
        RuntimeError: if ``args.lr_warmup_method`` is not "linear" or "constant".
    """
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        drop_last=True,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    model = torchvision.models.segmentation.__dict__[args.model](
        num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained
    )
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    checkpoint = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        # Evaluation-only runs load non-strictly; training resumes require an exact key match.
        model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)

    if args.test_only:
        # Evaluation needs no optimizer or LR schedule, so return before building
        # them; their arguments cannot break an evaluation-only run.
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

    params_to_optimize = [
        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
    ]
    if args.aux_loss:
        # The auxiliary head trains with 10x the base learning rate.
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    iters_per_epoch = len(data_loader)
    # Number of iterations the poly schedule decays over (warmup epochs excluded).
    main_iters = iters_per_epoch * (args.epochs - args.lr_warmup_epochs)
    if main_iters <= 0:
        raise ValueError(
            "The poly LR schedule needs a positive number of post-warmup iterations, got "
            f"{iters_per_epoch} iterations/epoch x ({args.epochs} epochs - {args.lr_warmup_epochs} warmup epochs)."
        )
    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / main_iters) ** 0.9)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        # Warmup runs for warmup_iters steps, then the poly schedule takes over.
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reseed the distributed sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        checkpoint = {
            "model": model_without_ddp.state_dict(),
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
            "epoch": epoch,
            "args": args,
        }
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
        utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")
196
197


198
def get_args_parser(add_help=True):
    """Build the command-line parser for this training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``. Pass False when this
            parser is used as a parent of another parser.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
    parser.add_argument("--model", default="fcn_resnet101", type=str, help="model name")
    parser.add_argument("--aux-loss", action="store_true", help="auxiliary loss")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)")
    parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)")
    parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    # distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

    return parser
251
252
253


if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    cli_args = get_args_parser().parse_args()
    main(cli_args)