"""Quantization-aware training (QAT) driver for a YOLOv5 model on COCO.

The script loads a pretrained YOLOv5 checkpoint, swaps its layers for
quantization-aware modules via the project's ``quantization.quantize``
helpers, calibrates the quantizers on the training loader, fine-tunes for a
few epochs and stores the resulting model under ``checkpoints/qat/``.
"""
import argparse
import os
import sys
import time
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from tqdm import tqdm

# The project root has to be importable *before* the project-local imports
# below are executed, otherwise they only resolve when the script happens to
# be started from the repository root.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

import quantization.quantize as quantize
# NOTE: the two dataloader factories imported here are shadowed by the local
# definitions further down; the import is kept so the name set is unchanged.
from scripts.qat import export_onnx, create_coco_train_dataloader, create_coco_val_dataloader, evaluate_coco
from models.common import Conv
from models.yolo import Model
from utils.loss import ComputeLoss
from utils.dataloaders import create_dataloader
from utils.torch_utils import (
    smart_optimizer,
)
from utils.general import (
    LOGGER,
    check_dataset,
    check_yaml,
    colorstr,
    labels_to_class_weights,
)


# ================== Dataset construction ==================================
def create_coco_train_dataloader(cocodir, batch_size=10):
    """Build the augmented COCO training dataloader.

    Args:
        cocodir: Directory that contains ``train2017.txt``.
        batch_size: Batch size forwarded to ``create_dataloader``.

    Returns:
        The ``(loader, dataset)`` pair returned by ``create_dataloader``.
    """
    # Augmentation hyperparameters always come from the low-augmentation
    # preset, resolved relative to the current working directory.
    with open("data/hyps/hyp.scratch-low.yaml") as f:
        hyp = yaml.safe_load(f)  # load hyps

    loader, dataset = create_dataloader(
        f"{cocodir}/train2017.txt",
        imgsz=640,
        batch_size=batch_size,
        augment=True,
        hyp=hyp,
        rect=False,
        cache=False,
        stride=32,
        pad=0,
        image_weights=False,
    )
    return loader, dataset


def create_coco_val_dataloader(cocodir, batch_size=10, keep_images=None):
    """Build the rectangular, non-augmented COCO validation dataloader.

    Args:
        cocodir: Directory that contains ``val2017.txt``.
        batch_size: Batch size forwarded to ``create_dataloader``. The
            previous implementation ignored this argument and always used 32.
        keep_images: Optional length to report from the patched ``__len__``
            hook (see the note in the body).

    Returns:
        The ``(loader, dataset)`` pair returned by ``create_dataloader``.
    """
    loader, dataset = create_dataloader(
        f"{cocodir}/val2017.txt",
        imgsz=640,
        batch_size=batch_size,
        augment=False,
        hyp=None,
        rect=True,
        cache=False,
        stride=32,
        pad=0.5,
        image_weights=False,
    )

    def subclass_len(self):
        if keep_images is not None:
            return keep_images
        return len(self.img_files)

    # NOTE(review): Python resolves ``len(obj)`` through ``type(obj).__len__``,
    # so this instance-level assignment does not change ``len(dataset)``.
    # It is kept unchanged to preserve the existing behaviour.
    loader.dataset.__len__ = subclass_len
    return loader, dataset


def load_yolov5_model(weight, device) -> Model:
    """Create a fused, eval-mode YOLOv5 model initialised from a checkpoint.

    The architecture yaml is chosen from the variant name (``yolov5l``,
    ``yolov5m``, ``yolov5n``, ``yolov5s`` or ``yolov5x``) found in ``weight``.

    Args:
        weight: Path of the checkpoint; it must contain the variant name.
        device: Device the model and the checkpoint tensors are mapped to.

    Returns:
        The model in float precision and eval mode, after ``model.fuse()``.

    Raises:
        NotImplementedError: If no supported variant name occurs in ``weight``.
    """
    # Same lookup order as the original if/elif chain: l, m, n, s, x.
    for variant in "lmnsx":
        if f"yolov5{variant}" in weight:
            cfg = f"models/yolov5{variant}.yaml"
            break
    else:
        raise NotImplementedError("Only support yolov5[l, m, n, s, x]")

    model = Model(cfg=cfg).to(device)

    # SECURITY: torch.load unpickles a complete model object here; only use
    # checkpoints from a trusted source.
    state_dict = torch.load(weight, map_location=device)["model"].state_dict()

    # strict=False tolerates key mismatches; report them instead of hiding
    # a partially initialised model.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        LOGGER.warning(
            f"load_yolov5_model: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected keys while loading {weight}"
        )

    for m in model.modules():
        if type(m) is nn.Upsample:
            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
        elif type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

    model.float()
    model.eval()
    with torch.no_grad():
        model.fuse()
    return model


def _save_checkpoint(model, save_path="checkpoints/qat/yolov5s_qat.pt"):
    """Replace the checkpoint at ``save_path`` with the pickled ``model``."""
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)
    try:
        os.remove(save_path)
    except FileNotFoundError:
        pass  # nothing to replace
    except OSError as err:
        # Removing the stale file is best effort; torch.save overwrites it.
        LOGGER.warning(f"could not remove existing checkpoint {save_path}: {err}")
    torch.save(model, save_path)


def qat(hyp, opt, device):
    """Run calibration and quantization-aware fine-tuning.

    Args:
        hyp: Path of a hyperparameter yaml file, or an already loaded dict.
            A dict is copied, so the caller's object is not modified.
        opt: Parsed options; reads ``data``, ``cocodir``, ``batch_size``,
            ``weights``, ``imgsz``, ``all_node_with_qdq`` and ``epochs`` and
            stores the unscaled hyperparameters in ``opt.hyp``.
        device: Device used for the model and the training batches.
    """
    # Load hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors="ignore") as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    else:
        hyp = dict(hyp)  # the gains are rescaled in place below
    LOGGER.info(colorstr("hyperparameters: ") + ", ".join(f"{k}={v}" for k, v in hyp.items()))
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

    # ========================== Datasets ==============================
    data_dict = check_dataset(opt.data)
    nc = int(data_dict["nc"])
    names = data_dict["names"]
    train_dataloader, dataset = create_coco_train_dataloader(opt.cocodir, opt.batch_size)
    test_dataloader, _ = create_coco_val_dataloader(opt.cocodir, opt.batch_size)

    # =========================== Base model and attributes ==================
    model = load_yolov5_model(opt.weights, device)
    detect_layer = model.model[-1]
    nl = detect_layer.nl
    hyp["box"] *= 3 / nl  # scale to layers
    hyp["cls"] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp["obj"] *= (opt.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # =========================== QAT model ===============
    quantize.replace_bottleneck_forward(model)
    quantize.replace_to_quantization_module(model, ignore_policy=None, all_node_with_qdq=opt.all_node_with_qdq)
    if not opt.all_node_with_qdq:
        quantize.apply_custom_rules_to_quantizer(model, export_onnx)
    quantize.calibrate_model(model, train_dataloader, device)

    # ========================== Training ====================
    compute_loss = ComputeLoss(model)  # init loss class
    optimizer = smart_optimizer(model, 'SGD', 1e-3, 0.9, 1e-5)

    for epoch in tqdm(range(1, opt.epochs + 1)):
        model.train()
        for imgs, targets, _paths, _ in train_dataloader:
            # NOTE(review): assumes the YOLOv5 loader yields uint8 images in
            # 0-255 (as upstream train.py/val.py do) - confirm. They are
            # converted to float in [0, 1] before the forward pass.
            imgs = imgs.to(device, non_blocking=True).float() / 255
            pred = model(imgs)
            loss, loss_items = compute_loss(pred, targets.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # NOTE(review): the original indentation of the evaluation and
        # checkpoint steps was lost; they are run once per epoch here -
        # confirm this matches the intended schedule.
        # The last module (index 24 in the stock yolov5 configs) is evaluated
        # with its quantizers disabled.
        with quantize.disable_quantization(detect_layer):
            ap = evaluate_coco(model, test_dataloader, True)
        LOGGER.info(f"epoch {epoch}/{opt.epochs}: evaluate_coco returned {ap}")

        _save_checkpoint(model)


def parse_opt(known=False):
    """Parse the command-line arguments of the QAT script.

    Args:
        known: When true, unknown arguments are ignored instead of raising.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, default=ROOT / "yolov5s.pt", help="initial weights path")
    parser.add_argument("--cfg", type=str, default="", help="model.yaml path")
    parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="dataset.yaml path")
    parser.add_argument("--cocodir", type=str, default="/home/temp/coco2017")
    parser.add_argument("--hyp", type=str, default=ROOT / "data/hyps/hyp.scratch-low.yaml", help="hyperparameters path")
    parser.add_argument("--epochs", type=int, default=5, help="total training epochs")
    parser.add_argument("--batch-size", type=int, default=128, help="total batch size for all GPUs, -1 for autobatch")
    parser.add_argument("--imgsz", "--img", "--img-size", type=int, default=640, help="train, val image size (pixels)")
    parser.add_argument("--noautoanchor", action="store_true", help="disable AutoAnchor")
    parser.add_argument("--all_node_with_qdq", action="store_true")
    parser.add_argument("--anchors", type=int, default=3)
    parser.add_argument("--nc", type=int, default=80)
    # Previously hard-coded in the entry point; the default keeps that value.
    parser.add_argument("--device", type=str, default="cuda:2", help="torch device string, e.g. cuda:0 or cpu")
    return parser.parse_known_args()[0] if known else parser.parse_args()


if __name__ == "__main__":
    opt = parse_opt()
    opt.cfg, opt.hyp, opt.weights = (
        check_yaml(opt.cfg),
        check_yaml(opt.hyp),
        str(opt.weights),
    )  # checks
    qat(opt.hyp, opt, device=torch.device(opt.device))