"""Image-classification training script built on torchvision models.

Trains (or only evaluates, with ``--test-only``) a model selected by name
from ``torchvision.models`` on an ``ImageFolder`` dataset laid out as
``<data-path>/train`` and ``<data-path>/val``.  Distributed training is
enabled when the ``WORLD_SIZE`` environment variable is greater than 1.
"""
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torchvision
from torchvision import transforms

import utils


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: network to train; switched to training mode here.
        criterion: callable computing the loss from ``(output, target)``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable yielding ``(image, target)`` batches.
        device: device the batches are moved to before the forward pass.
        epoch: epoch index, used only in the log header.
        print_freq: logging interval forwarded to ``MetricLogger.log_every``.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        batch_size = image.shape[0]
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        # Weight the accuracy meters by the batch size.
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)


def evaluate(model, criterion, data_loader, device, print_freq=100):
    """Evaluate ``model`` on ``data_loader`` and print top-1/top-5 accuracy.

    Args:
        model: network to evaluate; switched to eval mode here.
        criterion: callable computing the loss from ``(output, target)``.
        data_loader: iterable yielding ``(image, target)`` batches.
        device: device the batches are moved to before the forward pass.
        print_freq: logging interval forwarded to ``MetricLogger.log_every``
            (defaults to 100, the previously hard-coded value).

    Returns:
        The ``global_avg`` of the ``acc1`` meter after the meters have been
        synchronized between processes.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    with torch.no_grad():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
    return metric_logger.acc1.global_avg


def main(args):
    """Build the data pipeline, model and optimizer, then train or evaluate.

    Args:
        args: namespace carrying the options defined by the argument parser
            in the ``__main__`` section, plus a boolean ``distributed``
            attribute.  ``gpu`` is set on it here, and ``rank`` and
            ``dist_backend`` are added in distributed mode.
    """
    args.gpu = args.local_rank

    if args.distributed:
        args.rank = int(os.environ["RANK"])
        torch.cuda.set_device(args.gpu)
        args.dist_backend = 'nccl'
        dist_url = 'env://'
        print('| distributed init (rank {}): {}'.format(
            args.rank, dist_url), flush=True)
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=dist_url)
        utils.setup_for_distributed(args.rank == 0)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True

    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Loading training data")
    st = time.time()
    # mobilenet_v2 uses a narrower random-crop scale range than the default.
    scale = (0.2, 1.0) if args.model == 'mobilenet_v2' else (0.08, 1.0)
    dataset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224, scale=scale),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    print("Took", time.time() - st)

    print("Loading validation data")
    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)

    print("Creating model")
    model = torchvision.models.__dict__[args.model]()
    model.to(device)
    if args.distributed:
        # The documented entry point is the SyncBatchNorm classmethod;
        # the previous torch.nn.utils.convert_sync_batchnorm spelling is
        # not part of the documented PyTorch API.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # if using mobilenet, step_size=2 and gamma=0.94
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    start_epoch = 0
    if args.resume:
        # NOTE: torch.load unpickles the file, so only resume from
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # Continue after the last completed epoch instead of restarting at
        # 0 with an already-decayed schedule.  Checkpoints written before
        # the 'epoch' entry existed fall back to the scheduler's own step
        # counter, which was advanced once per epoch.
        if 'epoch' in checkpoint:
            start_epoch = checkpoint['epoch'] + 1
        else:
            start_epoch = lr_scheduler.last_epoch

    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
        # Step the scheduler only after the optimizer has been stepped for
        # this epoch; stepping it first skips the first schedule value.
        lr_scheduler.step()
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args},
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')

    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
    parser.add_argument('--model', default='resnet18', help='model')
    parser.add_argument('--device', default='cuda', help='device')
    parser.add_argument('-b', '--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument(
        "--test-only",
        dest="test_only",
        help="Only test the model",
        action="store_true",
    )
    # The help text previously read 'print frequency' (copy-paste error).
    parser.add_argument('--local_rank', default=0, type=int,
                        help='local rank of this process, also used as its GPU index')

    args = parser.parse_args()

    print(args)

    if args.output_dir:
        utils.mkdir(args.output_dir)

    # More than one process in the job means distributed training.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = num_gpus > 1

    main(args)