from pathlib import Path
import sys

# Make sibling modules (models, utils) importable when the script is run directly.
parent_dir = Path(__file__).resolve().parent
sys.path.append(str(parent_dir))

import time

import numpy as np
import onnxruntime
import pycuda.driver as cuda
import torch
from pytorch_quantization import quant_modules
from tqdm import tqdm

from models import vgg16
from utils.data import prepare_dataloader
from utils.trt import TrtModel


class NumpyDataLoader:
    """Caches all batches of a PyTorch dataloader as NumPy arrays for the ONNX/TensorRT backends."""

    def __init__(self, dataloader):
        self.data = []
        for data, label in dataloader:
            self.data.append((data.numpy().astype(np.float32),
                              label.numpy().astype(np.float32)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class CacheDataLoader:
    """Caches all batches of a PyTorch dataloader in memory so the timed pass skips data loading."""

    def __init__(self, dataloader):
        self.data = [(data, label) for data, label in dataloader]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def eval_onnx(ckpt_path, dataloader, device):
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    # Only request the CUDA provider when this onnxruntime build actually has GPU support.
    if onnxruntime.get_device() == "GPU":
        providers = ["CUDAExecutionProvider"]
        provider_options = [{"device_id": device}]
    else:
        providers = ["CPUExecutionProvider"]
        provider_options = None
    session = onnxruntime.InferenceSession(
        ckpt_path, sess_options, providers=providers, provider_options=provider_options
    )
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    correct, total = 0, 0
    for it in range(2):  # pass 0 warms up; pass 1 is timed
        desc = "warmup"
        if it == 1:
            start_time = time.time()
            desc = "eval onnx model"
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            output = session.run([output_name], {input_name: data})
            predictions = np.argmax(output[0], axis=-1)
            correct += (label == predictions).sum()
            total += len(label)
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time


def eval_trt(ckpt_path, dataloader, device):
    cuda.init()
    cuda.Device(device)  # validate the requested GPU; TrtModel manages its own context
    batch_size = 16  # must match the batch size the TensorRT engine was built for
    model = TrtModel(ckpt_path)

    correct, total = 0, 0
    desc = "warmup"
    for it in range(2):  # pass 0 warms up; pass 1 is timed
        if it == 1:
            desc = "eval trt model"
            start_time = time.time()
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            result = model(data, batch_size)
            result = np.argmax(result, axis=-1)
            total += label.shape[0]
            correct += (label == result).sum()
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time


def eval_original(ckpt_path, dataloader, num_classes, device):
    model = vgg16(num_classes=num_classes)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for it in range(2):  # pass 0 warms up; pass 1 is timed
            desc = "warmup"
            if it == 1:
                start_time = time.time()
                desc = "eval original pytorch model"
            for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
                output = model(data.to(device))
                _, predictions = torch.max(output, dim=-1)
                correct += torch.sum(predictions == label.to(device)).item()
                total += label.size(0)
            if it == 1:
                end_time = time.time()
    return correct / total, end_time - start_time


def eval_qat(ckpt_path, dataloader, num_classes, device):
    # Swap in quantized module implementations so the QAT checkpoint
    # (which contains quantizer state) loads cleanly into vgg16.
    quant_modules.initialize()
    model = vgg16(num_classes=num_classes)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for it in range(2):  # pass 0 warms up; pass 1 is timed
            desc = "warmup"
            if it == 1:
                start_time = time.time()
                desc = "eval qat pytorch model"
            for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
                output = model(data.to(device))
                _, predictions = torch.max(output, dim=-1)
                correct += torch.sum(predictions == label.to(device)).item()
                total += label.size(0)
            if it == 1:
                end_time = time.time()
    return correct / total, end_time - start_time


def main(args):
    device = torch.device(f"cuda:{args.device}" if args.device != -1 else "cpu")
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, 1)
    numpy_dataloader = NumpyDataLoader(test_dataloader)
    cache_dataloader = CacheDataLoader(test_dataloader)

    # Evaluate the PyTorch models (FP32 baseline and QAT), then the exported ONNX and TensorRT models.
    acc1, runtime1 = eval_original("./checkpoints/pretrained/pretrained_model.pth",
                                   cache_dataloader, args.num_classes, device)
    acc2, runtime2 = eval_qat("./checkpoints/qat/pretrained_model.pth",
                              cache_dataloader, args.num_classes, device)
    acc_onnx, runtime_onnx = eval_onnx("./checkpoints/qat/pretrained_qat.onnx",
                                       numpy_dataloader, args.device)
    acc_trt, runtime_trt = eval_trt("./checkpoints/qat/last.trt",
                                    numpy_dataloader, args.device)

    print("==============================================================")
    print(f"Original Model Acc: {acc1}, Inference Time: {runtime1:.4f}s")
    print(f"QAT Model Acc: {acc2}, Inference Time: {runtime2:.4f}s")
    print(f"ONNX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print(f"TRT Model Acc: {acc_trt}, Inference Time: {runtime_trt:.4f}s")
    print("==============================================================")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)
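# A minimal usage sketch. The script filename below is hypothetical (use whatever
# this file is saved as); the checkpoint and dataset paths are the hard-coded
# defaults above:
#
#   python evaluate.py --device 0 --num_classes 10
#
# Note that --device -1 only moves the PyTorch evaluations to CPU; eval_trt still
# calls cuda.Device(-1), so the ONNX/TensorRT comparisons assume a GPU host.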