# -*- coding: utf-8 -*-
'''Train CIFAR10 with PyTorch and Vision Transformers!
written by @kentaroy47, @arutema47
'''

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
#import csv
import time

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

from models import *
from utils import progress_bar
from randomaug import RandAugment
from models.vit import ViT
from models.convmixer import ConvMixer


def write_pid_file(pid_file_path):
    '''Write a pid file so the process can be watched later.

    Each run overwrites any existing file at the same path with the current pid.
    '''
    if os.path.exists(pid_file_path):
        os.remove(pid_file_path)
    with open(pid_file_path, "w") as file_d:
        file_d.write("%s\n" % os.getpid())


# parsers
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')  # resnets: 1e-3, ViT: 1e-4
parser.add_argument('--opt', default="adam")
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--noaug', action='store_true', help='disable randomaug')
parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
#parser.add_argument('--nowandb', action='store_true', help='disable wandb')
parser.add_argument('--mixup', action='store_true', help='add mixup augmentations')
parser.add_argument('--net', default='vit')
parser.add_argument('--bs', default=512, type=int)
parser.add_argument('--size', default=32, type=int)
parser.add_argument('--n_epochs', default=200, type=int)
parser.add_argument('--patch', default=4, type=int, help="patch size for ViT")
parser.add_argument('--dimhead', default=512, type=int)
parser.add_argument('--convkernel', default=8, type=int, help="kernel size for convmixer")
parser.add_argument("--log_dir", type=str, default="/data/flagperf/training/result/",
                    help="Log directory in container.")

args = parser.parse_args()

if dist.get_rank() == 0:
    write_pid_file(args.log_dir)

# take in args
#usewandb = ~args.nowandb
#if usewandb:
#    import wandb
#    watermark = "{}_lr{}".format(args.net, args.lr)
#    wandb.init(project="cifar10-challange",
#               name=watermark)
#    wandb.config.update(args)

bs = int(args.bs) // world_size  # per-rank batch size
imsize = int(args.size)

use_amp = not args.noamp
aug = not args.noaug

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device('cuda', local_rank)
torch.cuda.set_device(device)

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
global_steps = 0
target_acc = 84.49
final_acc = 0
num_trained_samples = 0

log_file_name = f'rank{local_rank}.out.log'
log_file = open(log_file_name, 'w')

# Data
#print('==> Preparing data..')
if args.net == "vit_timm":
    size = 384
else:
    size = imsize

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Add RandAugment with N, M(hyperparameter)
if aug:
    N = 2
    M = 14
    transform_train.transforms.insert(0, RandAugment(N, M))

# Prepare dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
# Shard by global rank so multi-node runs do not duplicate data across nodes.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    trainset, num_replicas=world_size, rank=dist.get_rank(), shuffle=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, sampler=train_sampler, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
test_sampler = torch.utils.data.distributed.DistributedSampler(testset)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, sampler=test_sampler, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model factory..
#print('==> Building model..')
# net = VGG('VGG19')
if args.net == 'res18':
    net = ResNet18()
elif args.net == 'vgg':
    net = VGG('VGG19')
elif args.net == 'res34':
    net = ResNet34()
elif args.net == 'res50':
    net = ResNet50()
elif args.net == 'res101':
    net = ResNet101()
elif args.net == "convmixer":
    # from paper, accuracy >96%. you can tune the depth and dim to scale accuracy and speed.
    net = ConvMixer(256, 16, kernel_size=args.convkernel, patch_size=1, n_classes=10)
elif args.net == "mlpmixer":
    from models.mlpmixer import MLPMixer
    net = MLPMixer(
        image_size = 32,
        channels = 3,
        patch_size = args.patch,
        dim = 512,
        depth = 6,
        num_classes = 10
    )
elif args.net == "vit_small":
    from models.vit_small import ViT
    net = ViT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 6,
        heads = 8,
        mlp_dim = 512,
        dropout = 0.1,
        emb_dropout = 0.1
    )
elif args.net == "vit_tiny":
    from models.vit_small import ViT
    net = ViT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 4,
        heads = 6,
        mlp_dim = 256,
        dropout = 0.1,
        emb_dropout = 0.1
    )
elif args.net == "simplevit":
    from models.simplevit import SimpleViT
    net = SimpleViT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 6,
        heads = 8,
        mlp_dim = 512
    )
elif args.net == "vit":
    # ViT for cifar10
    net = ViT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 6,
        heads = 8,
        mlp_dim = 512,
        dropout = 0.1,
        emb_dropout = 0.1
    )
elif args.net == "vit_timm":
    import timm
    net = timm.create_model("vit_base_patch16_384", pretrained=True)
    net.head = nn.Linear(net.head.in_features, 10)
elif args.net == "cait":
    from models.cait import CaiT
    net = CaiT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 6,      # depth of transformer for patch to patch attention only
        cls_depth = 2,  # depth of cross attention of CLS tokens to patch
        heads = 8,
        mlp_dim = 512,
        dropout = 0.1,
        emb_dropout = 0.1,
        layer_dropout = 0.05
    )
elif args.net == "cait_small":
    from models.cait import CaiT
    net = CaiT(
        image_size = size,
        patch_size = args.patch,
        num_classes = 10,
        dim = int(args.dimhead),
        depth = 6,      # depth of transformer for patch to patch attention only
        cls_depth = 2,  # depth of cross attention of CLS tokens to patch
        heads = 6,
        mlp_dim = 256,
        dropout = 0.1,
        emb_dropout = 0.1,
        layer_dropout = 0.05
    )
elif args.net == "swin":
    from models.swin import swin_t
    net = swin_t(window_size=args.patch, num_classes=10, downscaling_factors=(2, 2, 2, 1))

# For Multi-GPU
#if 'cuda' in device:
#    print(device)
#    print("using data parallel")
#    net = torch.nn.DataParallel(net)  # make parallel

net = net.to(device)
net = DistributedDataParallel(net, device_ids=[device])
cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    #print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # Load the checkpoint written by test() below (same filename and keys).
    checkpoint = torch.load('./checkpoint/{}-{}-ckpt.t7'.format(args.net, args.patch))
    net.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# Loss is CE
criterion = nn.CrossEntropyLoss()

if args.opt == "adam":
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
elif args.opt == "sgd":
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

# use cosine scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs)

##### Training
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def train(epoch):
    global num_trained_samples, global_steps
    #print('\nEpoch: %d' % epoch)
    train_sampler.set_epoch(epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Train with amp
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_trained_samples += targets.size(0)
        global_steps += 1

        learning_rate = f'{optimizer.param_groups[0]["lr"]:.9f}'
        loss_str = "%.4f" % (train_loss / (batch_idx + 1))
        acc = 100. * correct / total
        step_output = f'[PerfLog] {{"event": "STEP_END", "value": {{"epoch": {epoch+1}, "global_steps": {global_steps}, "loss": {loss_str}, "accuracy": {acc:.4f}, "num_trained_samples": {num_trained_samples}, "learning_rate": {learning_rate}}}}}'
        log_file.write(step_output + '\n')
        print(f'rank {local_rank}: ' + step_output)
        #progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        #             % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return train_loss / (batch_idx + 1)

##### Validation
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            #progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            #             % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    torch.cuda.synchronize()
    # Aggregate counts across ranks, then convert to Python ints for logging.
    t = torch.tensor([total, correct], device='cuda')
    dist.all_reduce(t)
    total = t[0].item()
    correct = t[1].item()

    # Save checkpoint.
    acc = 100. * correct / total
    if acc > best_acc:
        if dist.get_rank() == 0:
            #print('Saving..')
            # Keep 'acc' and 'epoch' in the checkpoint so --resume can restore them.
            state = {"model": net.state_dict(),
                     "optimizer": optimizer.state_dict(),
                     "scaler": scaler.state_dict(),
                     "acc": acc,
                     "epoch": epoch}
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/' + args.net + '-{}-ckpt.t7'.format(args.patch))
        best_acc = acc

    #os.makedirs("log", exist_ok=True)
    #content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}'
    #print(content)
    #with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender:
    #    appender.write(content + "\n")
    return test_loss, acc, total

#list_loss = []
#list_acc = []

#if usewandb:
#    wandb.watch(net)

training_start = time.time()
training_only = 0
net.cuda()
for epoch in range(start_epoch, args.n_epochs):
    start = time.time()
    trainloss = train(epoch)
    epoch_time = time.time() - start
    training_only += epoch_time

    start = time.time()
    val_loss, acc, total = test(epoch)
    eval_time = time.time() - start

    eval_output = f'[PerfLog] {{"event": "EVALUATE_END", "value": {{"global_steps": {global_steps}, "eval_loss": {val_loss:.4f}, "eval_mlm_accuracy": {acc:.4f}, "eval_time": {eval_time:.4f}, "epoch_time": {epoch_time:.4f}, "num_eval_samples": {total}}}}}'
    log_file.write(eval_output + '\n')
    print(f'rank {local_rank}: ' + eval_output)

    if acc >= target_acc:
        final_acc = acc
        break

    scheduler.step(epoch - 1)  # step cosine scheduling

    #list_loss.append(val_loss)
    #list_acc.append(acc)

    # Log training..
    #if usewandb:
    #    wandb.log({'epoch': epoch, 'train_loss': trainloss, 'val_loss': val_loss, "val_acc": acc,
    #               "lr": optimizer.param_groups[0]["lr"], "epoch_time": time.time()-start})

    # Write out csv..
    #with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f:
    #    writer = csv.writer(f, lineterminator='\n')
    #    writer.writerow(list_loss)
    #    writer.writerow(list_acc)
    #print(list_loss)

train_time = time.time() - training_start
samples_sec = num_trained_samples / training_only
train_output = f'[PerfLog] {{"event": "TRAIN_END", "value": {{"accuracy": {final_acc:.4f}, "train_time": {train_time:.4f}, "samples/sec": {samples_sec:.4f}, "num_trained_samples": {num_trained_samples}}}}}'
log_file.write(train_output + '\n')
print(f'rank {local_rank}: ' + train_output)

log_file.close()

# writeout wandb
#if usewandb:
#    wandb.save("wandb_{}.h5".format(args.net))
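
# The script reads LOCAL_RANK and WORLD_SIZE from the environment and uses the
# env:// rendezvous in dist.init_process_group, so it expects a distributed
# launcher. A minimal launch sketch with torchrun is shown below; the script
# filename, GPU count, and paths are illustrative assumptions, not part of this
# repository:
#
#   torchrun --nproc_per_node=8 train.py --net vit --bs 512 --lr 1e-4 \
#       --n_epochs 200 --log_dir /data/flagperf/training/result/run.pid
#
# torchrun populates LOCAL_RANK, WORLD_SIZE, RANK, MASTER_ADDR, and MASTER_PORT
# for every process. Note that --log_dir is passed directly to write_pid_file,
# so in this sketch it points at a writable file path rather than a directory.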