from pathlib import Path
import sys

# Make sibling modules (models, utils) importable when the script is run directly.
parent_dir = Path(__file__).resolve().parent
sys.path.append(str(parent_dir))

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules

from models import vgg16
from utils.data import prepare_dataloader
from utils.qat import *


def cleanup():
    dist.destroy_process_group()


def prepare_training_obj(lr: float = 1e-3, num_classes: int = 10, ckpt_root: str = '',
                         resume: bool = True, qat: bool = True):
    """Build the model, optimizer, LR scheduler, and loss for pretraining or QAT."""
    model = vgg16(num_classes=num_classes)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    if resume or qat:
        # Both resuming and QAT fine-tuning start from previously saved weights.
        model.load_state_dict(torch.load(os.path.join(ckpt_root, "pretrained_model.pth"), map_location="cpu"))
        lr_scheduler.load_state_dict(torch.load(os.path.join(ckpt_root, "scheduler.pth")))
        lr_scheduler.step()
    loss_fc = torch.nn.CrossEntropyLoss()
    return model, optimizer, lr_scheduler, loss_fc


def train_one_epoch(model, optimizer, lr_scheduler, loss_fc, dataloader, device):
    """Run one training epoch and return the loss summed across ranks (valid on rank 0)."""
    model.train()
    epoch_loss = torch.zeros(1).to(device)
    for it, (data, label) in enumerate(dataloader):
        output = model(data.to(device))
        loss = loss_fc(output, label.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Detach so the autograd graph is not kept alive across the whole epoch.
        epoch_loss += loss.detach() / label.size(0)
    lr_scheduler.step()
    # Sum the per-rank losses onto rank 0 for logging.
    dist.reduce(epoch_loss, dst=0)
    return epoch_loss


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Compute top-1 accuracy on the given dataloader."""
    correct = 0
    total = 0
    model.eval()
    for data, label in dataloader:
        output = model(data.to(device))
        _, predictions = torch.max(output, dim=-1)
        correct += torch.sum(predictions.cpu() == label).item()
        total += label.size(0)
    return correct / total


def pretrain(args):
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(
        args.lr, num_classes=args.num_classes, ckpt_root="./checkpoints/pretrained",
        resume=args.resume, qat=args.qat)
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    ddp_model = DDP(model, device_ids=[rank])
    train_dataloader, sampler = prepare_dataloader("./data/cifar10", True, args.batch_size)
    if rank == 0:
        # Evaluation and logging only happen on rank 0.
        test_dataloader, _ = prepare_dataloader("./data/cifar10", False)
    for epoch in range(args.epochs):
        # Wrap a fresh progress bar on rank 0 instead of overwriting train_dataloader,
        # otherwise tqdm wrappers would be stacked on every epoch.
        loader = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}",
                      position=0, leave=False) if rank == 0 else train_dataloader
        dist.barrier()
        sampler.set_epoch(epoch)
        loss = train_one_epoch(ddp_model, optimizer, lr_scheduler, loss_fc, loader, device)
        if rank == 0:
            avg_loss = loss.item() / dist.get_world_size()
            if (epoch + 1) % 5 == 0:
                acc = evaluate(model, test_dataloader, device)
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}, Eval Acc: {acc:.4f}")
            else:
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                # Save the model weights and LR scheduler state.
                ckpt_path = "./checkpoints/pretrained"
                if not os.path.exists(ckpt_path):
                    os.makedirs(ckpt_path)
                torch.save(model.state_dict(), os.path.join(ckpt_path, "pretrained_model.pth"))
                torch.save(lr_scheduler.state_dict(), os.path.join(ckpt_path, "scheduler.pth"))
    cleanup()


def qat(args):
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    # Monkey-patch torch.nn layers with quantized counterparts before the model is built.
    quant_modules.initialize()
    if args.resume:
        # Resume QAT fine-tuning from a QAT checkpoint.
        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(
            args.lr, num_classes=args.num_classes, ckpt_root="./checkpoints/qat",
            resume=args.resume, qat=args.qat)
    else:
        # Start QAT from the FP32 pretrained checkpoint.
        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(
            args.lr, num_classes=args.num_classes, ckpt_root="./checkpoints/pretrained",
            resume=args.resume, qat=args.qat)
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    train_dataloader, sampler = prepare_dataloader("./data/cifar10", True, args.batch_size)
    ddp_model = DDP(model, device_ids=[rank])

    # Calibrate the quantizers: collect activation statistics on a few batches,
    # then compute amax with the percentile method.
    with torch.no_grad():
        collect_stats(ddp_model, train_dataloader, num_batches=2, device=device)
        compute_amax(ddp_model, device=device, method="percentile", percentile=99.99)

    if rank == 0:
        test_dataloader, _ = prepare_dataloader("./data/cifar10", False)
    for epoch in range(args.epochs):
        loader = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}",
                      position=0, leave=False) if rank == 0 else train_dataloader
        dist.barrier()
        sampler.set_epoch(epoch)
        loss = train_one_epoch(ddp_model, optimizer, lr_scheduler, loss_fc, loader, device)
        if rank == 0:
            avg_loss = loss.item() / dist.get_world_size()
            if (epoch + 1) % 5 == 0:
                acc = evaluate(model, test_dataloader, device)
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}, Eval Acc: {acc:.4f}")
            else:
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                # Save the model weights and LR scheduler state.
                ckpt_path = "./checkpoints/qat"
                if not os.path.exists(ckpt_path):
                    os.makedirs(ckpt_path)
                torch.save(model.state_dict(), os.path.join(ckpt_path, "pretrained_model.pth"))
                torch.save(lr_scheduler.state_dict(), os.path.join(ckpt_path, "scheduler.pth"))

    if rank == 0:
        # Export the fake-quantized model to ONNX with explicit Q/DQ nodes (rank 0 only).
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        model.eval()
        with torch.no_grad():
            jit_model = torch.jit.trace(model, torch.randn((16, 3, 32, 32)).to(device))
            # torch.jit.save(jit_model, "./checkpoints/qat/pretrained_model.jit")
            jit_model.eval()
            # Per-channel fake-quant ops require ONNX opset >= 13 to export.
            torch.onnx.export(jit_model.to(device),
                              torch.randn((16, 3, 32, 32)).to(device),
                              "checkpoints/qat/pretrained_qat.onnx",
                              opset_version=13)
    cleanup()


def main(args):
    if args.qat:
        qat(args)
    else:
        pretrain(args)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--num_classes", type=int, default=10)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--qat", action="store_true")
    args = parser.parse_args()
    main(args)
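
# Example launch, a sketch assuming a single node with 2 GPUs and that this file is
# saved as main.py (the GPU count and the filename are assumptions, not taken from
# the original repository):
#
#   torchrun --nproc_per_node=2 main.py                   # FP32 pretraining
#   torchrun --nproc_per_node=2 main.py --qat             # QAT fine-tuning from ./checkpoints/pretrained
#   torchrun --nproc_per_node=2 main.py --qat --resume    # resume QAT from ./checkpoints/qat
#
# torchrun provides the env:// rendezvous variables (RANK, WORLD_SIZE, MASTER_ADDR, ...)
# that dist.init_process_group('nccl') reads by default.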