import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
import utils

try:
    from apex import amp
except ImportError:
    # apex is optional; main() raises if --apex is requested without it.
    amp = None


def train_one_epoch(
    model,
    criterion,
    optimizer,
    data_loader,
    device,
    epoch,
    print_freq,
    apex=False,
    model_ema=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Module to train; switched to training mode first.
        criterion: Callable computing the loss from ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in the log header.
        print_freq: Logging frequency forwarded to ``MetricLogger.log_every``.
        apex: When true, the backward pass goes through ``amp.scale_loss``.
        model_ema: Optional averaged model; when truthy its
            ``update_parameters(model)`` is called after the pass.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
    metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

    header = f"Epoch: [{epoch}]"
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        start_time = time.time()
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        if apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
        metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
        metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

    # NOTE(review): the EMA weights are refreshed once per epoch here, not
    # once per batch -- confirm this is the intended update cadence.
    if model_ema:
        model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
    """Evaluate ``model`` on ``data_loader`` and return the global top-1 average.

    Args:
        model: Module to evaluate; switched to eval mode first.
        criterion: Callable computing the loss from ``(output, target)``.
        data_loader: Iterable yielding ``(image, target)`` batches.
        device: Device the batches are moved to.
        print_freq: Logging frequency forwarded to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``Test:`` log header.

    Returns:
        ``metric_logger.acc1.global_avg`` after synchronizing the meters
        between processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = f"Test: {log_suffix}"

    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(
        f" * Acc@1 {metric_logger.acc1.global_avg:.3f}"
        f" Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
    """Return the dataset cache file path derived from ``filepath``.

    The file name is the first 10 hex characters of the SHA-1 of ``filepath``
    plus ``.pt``, placed under ``~/.torch/vision/datasets/imagefolder``.
    """
    import hashlib

    digest = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join(
        "~", ".torch", "vision", "datasets", "imagefolder", digest[:10] + ".pt"
    )
    return os.path.expanduser(cache_path)


def load_data(traindir, valdir, args):
    """Build the train/validation datasets and their samplers.

    Args:
        traindir: Root directory of the training ``ImageFolder``.
        valdir: Root directory of the validation ``ImageFolder``.
        args: Parsed options; reads ``model``, ``cache_dataset`` and
            ``distributed``, plus optional ``auto_augment`` / ``random_erase``.

    Returns:
        Tuple ``(dataset, dataset_test, train_sampler, test_sampler)``.

    Raises:
        KeyError: If ``args.model`` starts with ``efficientnet_`` but the
            variant suffix is not one of ``b0`` .. ``b7``.
    """
    # Data loading code
    print("Loading data")
    resize_size, crop_size = 256, 224
    interpolation = InterpolationMode.BILINEAR
    if args.model == "inception_v3":
        resize_size, crop_size = 342, 299
    elif args.model.startswith("efficientnet_"):
        # (resize_size, crop_size) per EfficientNet variant.
        sizes = {
            "b0": (256, 224),
            "b1": (256, 240),
            "b2": (288, 288),
            "b3": (320, 300),
            "b4": (384, 380),
            "b5": (456, 456),
            "b6": (528, 528),
            "b7": (600, 600),
        }
        e_type = args.model.replace("efficientnet_", "")
        resize_size, crop_size = sizes[e_type]
        interpolation = InterpolationMode.BICUBIC

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_train from {cache_path}")
        # SECURITY: torch.load unpickles the file; only use trusted caches.
        dataset, _ = torch.load(cache_path)
    else:
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.ImageFolder(
            traindir,
            presets.ClassificationPresetTrain(
                crop_size=crop_size,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)
    print("Took", time.time() - st)

    print("Loading validation data")
    cache_path = _get_cache_path(valdir)
    if args.cache_dataset and os.path.exists(cache_path):
        # Attention, as the transforms are also cached!
        print(f"Loading dataset_test from {cache_path}")
        # SECURITY: torch.load unpickles the file; only use trusted caches.
        dataset_test, _ = torch.load(cache_path)
    else:
        dataset_test = torchvision.datasets.ImageFolder(
            valdir,
            presets.ClassificationPresetEval(
                crop_size=crop_size,
                resize_size=resize_size,
                interpolation=interpolation,
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler


def main(args):
    """Run training (or evaluation only) as configured by ``args``.

    Builds the data loaders, model, loss, optimizer and LR scheduler,
    optionally resumes from ``args.resume``, then either evaluates once
    (``args.test_only``) or trains from ``args.start_epoch`` up to
    ``args.epochs``, evaluating and checkpointing after every epoch.

    Raises:
        RuntimeError: If ``args.apex`` is set but apex could not be imported,
            or if ``args.opt`` is neither ``sgd`` nor ``rmsprop``.
    """
    if args.apex and amp is None:
        raise RuntimeError(
            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
            "to enable mixed-precision training."
        )

    if args.output_dir:
        utils.mkdir(args.output_dir)

    # NOTE(review): ``args.distributed`` and ``args.gpu`` are read below but
    # are not defined by the argument parser; presumably this call sets them.
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    train_dir = os.path.join(args.data_path, "train")
    val_dir = os.path.join(args.data_path, "val")
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")
    model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    opt_name = args.opt.lower()
    if opt_name == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif opt_name == "rmsprop":
        optimizer = torch.optim.RMSprop(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            eps=0.0316,
            alpha=0.9,
        )
    else:
        raise RuntimeError(
            "Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)
        )

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma
    )

    # Keep a handle on the unwrapped module so checkpoints and the EMA are
    # taken from it rather than from the DistributedDataParallel wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    model_ema = None
    if args.model_ema:
        model_ema = utils.ExponentialMovingAverage(
            model_without_ddp, device=device, decay=args.model_ema_decay
        )

    if args.resume:
        # SECURITY: torch.load unpickles the file; only resume from trusted
        # checkpoints.
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if model_ema:
            model_ema.load_state_dict(checkpoint["model_ema"])

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            model,
            criterion,
            optimizer,
            data_loader,
            device,
            epoch,
            args.print_freq,
            args.apex,
            model_ema,
        )
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if model_ema:
            evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if model_ema:
                checkpoint["model_ema"] = model_ema.state_dict()
            # One file per epoch plus a rolling "latest" checkpoint.
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))
            )
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))


def get_args_parser(add_help=True):
    """Build the command-line parser for the classification training script.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser(add_help=...)``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset"
    )
    parser.add_argument("--model", default="resnet18", help="model")
    parser.add_argument("--device", default="cuda", help="device")
    parser.add_argument("-b", "--batch-size", default=32, type=int)
    parser.add_argument(
        "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--label-smoothing",
        default=0.0,
        type=float,
        help="label smoothing (default: 0.0)",
        dest="label_smoothing",
    )
    parser.add_argument(
        "--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs"
    )
    parser.add_argument(
        "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma"
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument("--output-dir", default=".", help="path where to save")
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument(
        "--cache-dataset",
        dest="cache_dataset",
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
        action="store_true",
    )
    parser.add_argument(
        "--sync-bn",
        dest="sync_bn",
        help="Use sync batch norm",
        action="store_true",
    )
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    parser.add_argument(
        "--pretrained",
        dest="pretrained",
        help="Use pre-trained models from the modelzoo",
        action="store_true",
    )
    parser.add_argument(
        "--auto-augment", default=None, help="auto augment policy (default: None)"
    )
    parser.add_argument(
        "--random-erase",
        default=0.0,
        type=float,
        help="random erasing probability (default: 0.0)",
    )

    # Mixed precision training parameters
    parser.add_argument(
        "--apex", action="store_true", help="Use apex for mixed precision training"
    )
    # The help fragments are adjacent literals, so each one carries its own
    # trailing separator; without it the words run together in --help output.
    parser.add_argument(
        "--apex-opt-level",
        default="O1",
        type=str,
        help="For apex mixed precision training: "
        "O0 for FP32 training, O1 for mixed precision training. "
        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
    )

    # distributed training parameters
    parser.add_argument(
        "--world-size", default=1, type=int, help="number of distributed processes"
    )
    parser.add_argument(
        "--dist-url", default="env://", help="url used to set up distributed training"
    )

    parser.add_argument(
        "--model-ema",
        action="store_true",
        help="enable tracking Exponential Moving Average of model parameters",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.9,
        help="decay factor for Exponential Moving Average of model parameters (default: 0.9)",
    )

    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)