Commit 38e4792a authored by mashun1's avatar mashun1
Browse files

vgg16-qat

parents
Pipeline #808 canceled with stages
*pyc*
data/
checkpoints*
nohup*
# VGG16-QAT
本项目旨在对VGG16模型执行量化感知训练,将其转换为onnx模型,并在TensorRT上运行。
## 论文
**Very Deep Convolutional Networks for Large-Scale Image Recognition**
* https://arxiv.org/abs/1409.1556
## 模型结构
VGG网络由小的卷积滤波器组成,VGG16有三个全连接层和13个卷积层,此外,可在模型中添加`BatchNorm`以及`Dropout`层。
![Alt text](readme_imgs/image-1.png)
## 算法原理
VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中过度拟合的倾向。3×3的滤波器是最佳大小,因为较小的大小无法捕捉左右和上下的信息。
![alt text](readme_imgs/image-2.png)
## 环境配置
### Anaconda (方法一)
1、本项目目前仅支持在N卡环境运行
python 3.9.18
torch 2.0.1
cuda 11
pip install -r requirements.txt
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization
2、TensorRT
wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip
unzip [下载的压缩包] -d [解压路径]
pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl
ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec
注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。
## 数据集
本项目使用CIFAR-10数据集,可直接运行`main.py`后自动下载并处理。
## 训练
# --epochs表示训练或校准回合数
# --resume表示继续训练
# --calibrate表示校准(在训练基础模型时不能使用此参数)
CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --calibrate --batch_size=N --lr=X --num_classes=10
## 推理
trtexec --onnx=/path/to/onnx --saveEngine=/path/to/save --int8
python eval.py --device=0
## result
![alt text](readme_imgs/image-3.png)
### 精度
||原始模型|QAT校准模型|ONNX模型|TensorRT模型|
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|
## 应用场景
### 算法类别
`图像分类`
### 热点应用行业
`制造,交通,网安`
## 源码仓库及问题反馈
* https://developer.hpccube.com/codes/modelzoo/vgg16-qat_pytorch
## 参考资料
* https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html
from pathlib import Path
import sys
parent_dir = Path(__file__).resolve().parent
sys.path.append(str(parent_dir))
from models import vgg16
from tqdm import tqdm
from utils.data import prepare_dataloader
from utils.trt import TrtModel
import time
import torch
import onnxruntime
import numpy as np
import pycuda.driver as cuda
from pytorch_quantization import quant_modules
def eval_onnx(ckpt_path, dataloader, device):
    """Evaluate the exported QAT ONNX model with onnxruntime.

    Runs one untimed warmup pass over `dataloader`, then a timed pass.
    `correct`/`total` accumulate over both passes; since inference is
    deterministic the accuracy ratio is unchanged.

    Args:
        ckpt_path: path to the .onnx file.
        dataloader: yields (data, label) CPU tensor batches.
        device: CUDA device index for the CUDA execution provider.

    Returns:
        (accuracy, elapsed_seconds) measured on the second pass.
    """
    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    if onnxruntime.get_device() == "GPU":
        providers = ['CUDAExecutionProvider']
        provider_options = [{"device_id": device}]
    else:
        # Fix: the original requested CUDAExecutionProvider even on CPU-only
        # builds and handed a device_id option to every provider, which can
        # make session creation fail. Fall back to the CPU provider only.
        providers = ['CPUExecutionProvider']
        provider_options = [{}]
    session = onnxruntime.InferenceSession(ckpt_path, sess_options, providers=providers, provider_options=provider_options)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    correct, total = 0, 0
    for it in range(2):
        desc = "warmup"
        if it == 1:
            start_time = time.time()
            desc = "eval onnx model"
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            data, label = data.numpy().astype(np.float32), label.numpy().astype(np.float32)
            # session.run returns a single-element list holding the
            # (batch, num_classes) logits array.
            output = session.run([output_name], {input_name: data})
            predictions = np.argmax(output, axis=-1)[0]
            correct += (label == predictions).sum()
            total += len(label)
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time
def eval_trt(ckpt_path, dataloader, device):
    """Time and score the serialized TensorRT engine.

    Pass 0 over `dataloader` is an untimed warmup; pass 1 is timed.
    Correct/total counts accumulate over both passes, which leaves the
    accuracy ratio unchanged.

    Returns:
        (accuracy, elapsed_seconds_of_timed_pass)
    """
    cuda.init()
    device = cuda.Device(device)
    batch_size = 16  # must match the batch the engine was built/exported with
    model = TrtModel(ckpt_path)
    correct, total = 0, 0
    desc = "warmup"
    for pass_idx in range(2):
        if pass_idx == 1:
            desc = "eval trt model"
            start_time = time.time()
        for batch, targets in tqdm(dataloader, desc=desc, total=(len(dataloader))):
            logits = model(batch.numpy(), batch_size)
            predicted = np.argmax(logits, axis=-1)
            targets = targets.numpy()
            total += targets.shape[0]
            correct += (targets == predicted).sum()
        if pass_idx == 1:
            end_time = time.time()
    return correct / total, end_time - start_time
def eval_original(ckpt_path, dataloader, num_classes, device):
    """Evaluate the FP32 PyTorch checkpoint (warmup pass + timed pass).

    Args:
        ckpt_path: path to the state_dict checkpoint.
        dataloader: yields (data, label) CPU tensor batches.
        num_classes: number of output classes for vgg16.
        device: torch.device to run inference on.

    Returns:
        (accuracy, elapsed_seconds) measured on the second (timed) pass.
        Counts accumulate over both passes; the ratio is unaffected.
    """
    model = vgg16(num_classes=num_classes)
    model.load_state_dict(torch.load(ckpt_path))
    model.to(device)
    model.eval()
    total, correct = 0, 0
    # Fix: run inference under no_grad — the original built autograd graphs
    # for every batch, wasting memory and time during pure evaluation.
    with torch.no_grad():
        for it in range(2):
            desc = "warmup"
            if it == 1:
                start_time = time.time()
                desc = 'eval original pytorch model'
            for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
                output = model(data.to(device))
                _, predictions = torch.max(output, dim=-1)
                correct += torch.sum(predictions==label.to(device)).item()
                total += label.size(0)
            if it == 1:
                end_time = time.time()
    return correct / total, end_time - start_time
def eval_qat(ckpt_path, dataloader, num_classes, device):
    """Evaluate the QAT (fake-quantized) PyTorch checkpoint.

    quant_modules.initialize() monkey-patches torch.nn layers with their
    quantized counterparts, so the checkpoint's quantizer state loads.

    Args:
        ckpt_path: path to the calibrated state_dict checkpoint.
        dataloader: yields (data, label) CPU tensor batches.
        num_classes: number of output classes for vgg16.
        device: torch.device to run inference on.

    Returns:
        (accuracy, elapsed_seconds) measured on the second (timed) pass.
        Counts accumulate over both passes; the ratio is unaffected.
    """
    quant_modules.initialize()
    model = vgg16(num_classes=num_classes)
    model.load_state_dict(torch.load(ckpt_path))
    model.to(device)
    model.eval()
    total, correct = 0, 0
    # Fix: run inference under no_grad — the original built autograd graphs
    # for every batch, wasting memory and time during pure evaluation.
    with torch.no_grad():
        for it in range(2):
            desc = "warmup"
            if it == 1:
                start_time = time.time()
                desc = 'eval qat pytorch model'
            for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
                output = model(data.to(device))
                _, predictions = torch.max(output, dim=-1)
                correct += torch.sum(predictions==label.to(device)).item()
                total += label.size(0)
            if it == 1:
                end_time = time.time()
    return correct / total, end_time - start_time
def main(args):
    """Evaluate the pretrained, QAT, ONNX and TensorRT models on CIFAR-10 and print a summary."""
    device = torch.device("cpu" if args.device == -1 else f"cuda:{args.device}")
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size)
    # PyTorch checkpoints first, then the exported formats.
    results = [
        ("Original", eval_original("./checkpoints/pretrained/pretrained_model.pth", test_dataloader, args.num_classes, device)),
        ("Qat", eval_qat("./checkpoints/calibrated/pretrained_model.pth", test_dataloader, args.num_classes, device)),
        ("Onnx", eval_onnx("./checkpoints/calibrated/pretrained_qat.onnx", test_dataloader, args.device)),
        ("Trt", eval_trt("./checkpoints/calibrated/last.trt", test_dataloader, args.device)),
    ]
    bar = "=============================================================="
    print(bar)
    for name, (acc, runtime) in results:
        print(f"{name} Model Acc: {acc}, Inference Time: {runtime:.4f}s")
    print(bar)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # Test-dataloader batch size. NOTE(review): eval_trt hard-codes batch 16
    # for the TensorRT engine — confirm before passing a different value.
    parser.add_argument("--batch_size", type=int, default=16)
    # CUDA device index; -1 selects CPU for the PyTorch evaluations.
    parser.add_argument("--device", type=int, default=-1)
    # Number of classifier outputs (10 for CIFAR-10).
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)
from pathlib import Path
import sys
parent_dir = Path(__file__).resolve().parent
sys.path.append(str(parent_dir))
from models import vgg16
import os
import torch
import torch.distributed as dist
from tqdm import tqdm
from utils.data import prepare_dataloader
from utils.calibrate import *
from torch.nn.parallel import DistributedDataParallel as DDP
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
def cleanup():
    """Destroy the default distributed process group created in pretrain()/calibrate()."""
    dist.destroy_process_group()
def prepare_training_obj(lr: float = 1e-3,
                         num_classes=10,
                         ckpt_root: str = '',
                         resume: bool = True,
                         calibrate: bool = True):
    """Build the model, SGD optimizer, cosine LR scheduler and loss.

    When `resume` or `calibrate` is set, the model weights are restored from
    "<ckpt_root>/pretrained_model.pth" and the scheduler state from
    "<ckpt_root>/scheduler.pth" (followed by one scheduler step).

    Returns:
        (model, optimizer, lr_scheduler, loss_fc)
    """
    model = vgg16(num_classes=num_classes)
    restore = resume or calibrate
    if restore:
        state = torch.load(os.path.join(ckpt_root, "pretrained_model.pth"), map_location="cpu")
        model.load_state_dict(state)
    # Optimizer/scheduler construction is identical on both paths, so build
    # them once; load_state_dict above mutates parameters in place, keeping
    # the optimizer's parameter references valid.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
    if restore:
        lr_scheduler.load_state_dict(torch.load(os.path.join(ckpt_root, "scheduler.pth")))
        lr_scheduler.step()
    loss_fc = torch.nn.CrossEntropyLoss()
    return model, optimizer, lr_scheduler, loss_fc
def train_one_epoch(model,
                    optimizer,
                    lr_scheduler,
                    loss_fc,
                    dataloader,
                    device):
    """Run one DDP training epoch and return the rank-summed loss tensor.

    Returns a 1-element tensor on `device`. After dist.reduce, rank 0 holds
    the sum of all ranks' accumulated losses; other ranks' values should not
    be relied upon.
    """
    model.train()
    epoch_loss = torch.zeros(1).to(device)
    for it, (data, label) in enumerate(dataloader):
        output = model(data.to(device))
        loss = loss_fc(output, label.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): CrossEntropyLoss already averages over the batch by
        # default, so dividing by label.size(0) again double-normalizes the
        # logged value — confirm this scaling is intentional.
        epoch_loss += (loss / label.size(0))
    # One scheduler step per epoch (T_max=20 in prepare_training_obj).
    lr_scheduler.step()
    dist.reduce(epoch_loss, dst=0)
    return epoch_loss
def evaluate(model,
             dataloader,
             device):
    """Compute top-1 accuracy of `model` over `dataloader`.

    Args:
        model: module mapping a batch on `device` to (batch, classes) logits.
        dataloader: yields (data, label) CPU tensor batches.
        device: device inference runs on; predictions are moved back to CPU
            for comparison against the CPU labels.

    Returns:
        correct / total as a 0-dim tensor (correct is a tensor count).
    """
    correct = 0
    total = 0
    model.eval()
    # Fix: evaluation ran with autograd enabled, building graphs for every
    # batch. no_grad skips that bookkeeping during pure inference.
    with torch.no_grad():
        for data, label in dataloader:
            output = model(data.to(device))
            _, predictions = torch.max(output, dim=-1)
            correct += torch.sum(predictions.cpu()==label)
            total += label.size(0)
    return correct / total
def pretrain(args):
    """Train the FP32 baseline VGG16 with DistributedDataParallel.

    Intended to be launched via torchrun (NCCL backend, one GPU per rank).
    Rank 0 owns the progress bar, periodic evaluation and checkpointing
    (every 5 epochs to ./checkpoints/pretrained).
    """
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    # Fix: forward args.num_classes — it was parsed but never passed, so a
    # non-default value was silently ignored (default stays 10).
    model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, num_classes=args.num_classes, ckpt_root="./checkpoints/pretrained", resume=args.resume, calibrate=args.calibrate)
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    ddp_model = DDP(model, device_ids=[rank])
    train_dataloader, sampler = prepare_dataloader("./data/cifar10", True, args.batch_size)
    if rank == 0:
        test_dataloader, _ = prepare_dataloader("./data/cifar10", False)
    for epoch in range(args.epochs):
        # Fix: the original rebound train_dataloader = tqdm(train_dataloader)
        # on rank 0 every epoch, wrapping the already-wrapped iterator ever
        # deeper. Use a per-epoch iterator instead.
        if rank == 0:
            epoch_iter = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}", position=0, leave=False)
        else:
            epoch_iter = train_dataloader
        dist.barrier()
        sampler.set_epoch(epoch)  # reshuffle shards deterministically per epoch
        loss = train_one_epoch(ddp_model, optimizer, lr_scheduler, loss_fc, epoch_iter, device)
        if dist.get_rank() == 0:
            avg_loss = loss.item() / dist.get_world_size()
            if (epoch + 1) % 5 == 0:
                acc = evaluate(model, test_dataloader, device)
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}, Eval Acc: {acc}")
            else:
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                # save checkpoints and lr.
                ckpt_path = "./checkpoints/pretrained"
                os.makedirs(ckpt_path, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(ckpt_path, "pretrained_model.pth"))
                torch.save(lr_scheduler.state_dict(), os.path.join(ckpt_path, "scheduler.pth"))
    cleanup()
def calibrate(args):
    """Calibrate and QAT-finetune VGG16, then export the result to ONNX.

    Steps: patch torch.nn with quantized layers, load weights (calibrated
    checkpoint when resuming, FP32 pretrained otherwise), collect activation
    statistics, compute per-quantizer amax, finetune with DDP, and on rank 0
    export a traced fake-quant model to ONNX.
    """
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    quant_modules.initialize()  # must run before vgg16() so layers are quantized
    ckpt_root = "./checkpoints/calibrated" if args.resume else "./checkpoints/pretrained"
    # Fix: forward args.num_classes — it was parsed but never passed.
    model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, num_classes=args.num_classes, ckpt_root=ckpt_root, resume=args.resume, calibrate=args.calibrate)
    device = torch.device(f"cuda:{rank}")
    model.to(device)
    train_dataloader, sampler = prepare_dataloader("./data/cifar10", True, args.batch_size)
    ddp_model = DDP(model, device_ids=[rank])
    # Collect activation ranges, then load them into the quantizers.
    with torch.no_grad():
        collect_stats(ddp_model, train_dataloader, num_batches=2, device=device)
        compute_amax(ddp_model, device=device, method="percentile", percentile=99.99)
    if rank == 0:
        test_dataloader, _ = prepare_dataloader("./data/cifar10", False)
    for epoch in range(args.epochs):
        # Fix: the original rebound train_dataloader = tqdm(train_dataloader)
        # every epoch on rank 0, nesting tqdm wrappers ever deeper.
        if rank == 0:
            epoch_iter = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.epochs}", position=0, leave=False)
        else:
            epoch_iter = train_dataloader
        dist.barrier()
        sampler.set_epoch(epoch)
        loss = train_one_epoch(ddp_model, optimizer, lr_scheduler, loss_fc, epoch_iter, device)
        if dist.get_rank() == 0:
            avg_loss = loss.item() / dist.get_world_size()
            if (epoch + 1) % 5 == 0:
                acc = evaluate(model, test_dataloader, device)
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}, Eval Acc: {acc}")
            else:
                tqdm.write(f"Epoch: {epoch+1}, Avg Train Loss: {avg_loss:.4f}")
            if (epoch + 1) % 5 == 0:
                # save checkpoints and lr.
                ckpt_path = "./checkpoints/calibrated"
                os.makedirs(ckpt_path, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(ckpt_path, "pretrained_model.pth"))
                torch.save(lr_scheduler.state_dict(), os.path.join(ckpt_path, "scheduler.pth"))
    if rank == 0:
        # Switch quantizers to ONNX-exportable fake-quant ops before tracing.
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        model.eval()
        with torch.no_grad():
            jit_model = torch.jit.trace(model, torch.randn((16, 3, 32, 32)).to(device))
            # torch.jit.save(jit_model, "./checkpoints/calibrated/pretrained_model.jit")
            jit_model.eval()
            torch.onnx.export(jit_model.to(device), torch.randn((16, 3, 32, 32)).to(device), "checkpoints/calibrated/pretrained_qat.onnx")
    cleanup()
def main(args):
    """Dispatch to QAT calibration/finetuning or FP32 pretraining."""
    task = calibrate if args.calibrate else pretrain
    task(args)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # Number of training (or QAT finetuning) epochs.
    parser.add_argument("--epochs", type=int, default=100)
    # Initial SGD learning rate (cosine-annealed with T_max=20).
    parser.add_argument("--lr", type=float, default=1e-3)
    # Per-rank training batch size.
    parser.add_argument("--batch_size", type=int, default=512)
    # NOTE(review): num_classes is parsed but never forwarded to
    # prepare_training_obj, which defaults to 10 — confirm intent.
    parser.add_argument("--num_classes", type=int, default=10)
    # Resume from an existing checkpoint instead of starting fresh.
    parser.add_argument("--resume", action="store_true")
    # Run QAT calibration/finetuning instead of FP32 pretraining.
    parser.add_argument("--calibrate", action="store_true")
    args = parser.parse_args()
    main(args)
\ No newline at end of file
# 模型唯一标识
modelCode = 552
# 模型名称
modelName = vgg16-qat_pytorch
# 模型描述
modelDescription = vgg16-qat
# 应用场景
appScenario = 训练,推理,图像分类,制造,交通,网安
# 框架类型
frameType = pytorch
"""
# Reference
- [Very Deep Convolutional Networks for Large-Scale Image Recognition](
https://arxiv.org/abs/1409.1556) (ICLR 2015)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
class VGG(nn.Module):
    """VGG-style classifier: conv/BN/ReLU stages from `layer_spec`, then a 3-layer MLP head.

    `layer_spec` is a sequence of ints (conv output channels, 3x3 kernels)
    and the string "pool" (2x2 max-pool). The classifier head expects the
    final conv stage to have 512 channels, global-average-pooled to 1x1.
    """

    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        stages = []
        channels = 3
        for spec in layer_spec:
            if spec == "pool":
                stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stages.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(spec))
            stages.append(nn.ReLU())
            channels = spec
        self.features = nn.Sequential(*stages)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal for convs, constant for BatchNorm, small-normal for linears."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return (batch, num_classes) logits for a (batch, 3, H, W) input."""
        feats = self.features(x)
        pooled = self.avgpool(feats)
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
def vgg16(num_classes=1000, init_weights=False):
    """Build VGG16 (13 conv layers in 5 stages of 2/2/3/3/3, each followed by a pool).

    Args:
        num_classes: size of the final classifier output.
        init_weights: run the explicit weight-initialization pass when True.

    Returns:
        A VGG instance configured as VGG16.
    """
    cfg = []
    for out_channels, conv_count in ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)):
        cfg.extend([out_channels] * conv_count)
        cfg.append("pool")
    return VGG(cfg, num_classes, init_weights)
\ No newline at end of file
absl-py==2.0.0
aiohttp==3.9.3
aiosignal==1.3.1
anyio==4.2.0
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
cmake==3.28.1
coloredlogs==15.0.1
comm==0.2.1
contourpy==1.2.0
cycler==0.12.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
exceptiongroup==1.2.0
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.13.1
flatbuffers==23.5.26
fonttools==4.47.2
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.41
hub-sdk==0.0.3
humanfriendly==10.0
idna==3.6
importlib-metadata==7.0.1
importlib-resources==6.1.1
ipykernel==6.29.0
ipython==8.18.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.2
jupyter_client==8.6.0
jupyter_core==5.7.1
jupyter_server==2.12.5
jupyter_server_terminals==0.5.1
jupyterlab==4.0.11
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
kiwisolver==1.4.5
lit==17.0.6
Mako==1.3.2
MarkupSafe==2.1.3
matplotlib==3.8.2
matplotlib-inline==0.1.6
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.5
nbclient==0.9.0
nbconvert==7.14.2
nbformat==5.9.2
nest-asyncio==1.6.0
networkx==3.2.1
notebook==7.0.7
notebook_shim==0.2.3
numpy==1.23.2
nvidia-cublas-cu11==11.10.3.66
# nvidia-cublas-cu12==12.3.4.1
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
# nvidia-cuda-nvrtc-cu12==12.3.107
nvidia-cuda-runtime-cu11==11.7.99
# nvidia-cuda-runtime-cu12==12.3.101
nvidia-cudnn-cu11==8.5.0.96
# nvidia-cudnn-cu12==8.9.7.29
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
onnx==1.15.0
onnx-graphsurgeon==0.3.27
onnxoptimizer==0.3.2
onnxruntime-gpu==1.17.1
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
overrides==7.6.0
packaging==23.2
pandas==2.1.4
pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.1.0
prettytable==3.9.0
prometheus-client==0.19.0
prompt-toolkit==3.0.43
protobuf==4.25.2
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
pycocotools==2.0.7
pycparser==2.21
pycuda==2020.1
Pygments==2.17.2
pyparsing==3.1.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pytools==2023.1.1
pytorch-lightning==2.2.0.post0
pytorch-quantization==2.1.3
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.2
referencing==0.32.1
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.17.1
scipy==1.11.4
seaborn==0.13.1
Send2Trash==1.8.2
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
sphinx-glpi-theme==0.5
stack-data==0.6.3
sympy==1.12
# tensorrt @ file:///home/qat/TensorRT-8.5.3.1/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl#sha256=ee25152809c09fd22057681ff24dfcf2dc2a1e7aad50dcb2658e0a1c0b32e768
terminado==0.18.0
thop==0.1.1.post2209072238
tinycss2==1.2.1
tomli==2.0.1
torch==2.0.1
torchaudio==2.0.2
torchmetrics==1.3.1
torchvision==0.15.2
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
triton==2.0.0
types-python-dateutil==2.8.19.20240106
typing_extensions==4.9.0
tzdata==2023.4
ultralytics==8.1.1
uri-template==1.3.0
urllib3==2.1.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
yarl==1.9.4
zipp==3.17.0
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization import calib
from tqdm import tqdm
def collect_stats(model, data_loader, num_batches, device):
    """Run `num_batches` forward passes to feed the quantizer calibrators.

    Every TensorQuantizer with a calibrator is switched to calibration mode
    (quantization disabled) for the duration, and restored afterwards;
    quantizers without a calibrator are disabled entirely.
    """
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.to(device))
        # Fix: the original broke on `i >= num_batches`, which (with a
        # zero-based index) ran num_batches + 1 batches instead of num_batches.
        if i + 1 >= num_batches:
            break
    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
def compute_amax(model, device, **kwargs):
    """Load calibration results into every quantizer, then move `model` to `device`.

    Extra keyword arguments (e.g. method="percentile", percentile=99.99) are
    forwarded to load_calib_amax for histogram-style calibrators; a plain
    MaxCalibrator is loaded without arguments.
    """
    # Load calib result
    for name, module in model.named_modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        calibrator = module._calibrator
        if calibrator is None:
            continue
        if isinstance(calibrator, calib.MaxCalibrator):
            module.load_calib_amax()
        else:
            module.load_calib_amax(**kwargs)
    model.to(device)
\ No newline at end of file
import torchvision.transforms.transforms as T
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
def prepare_dataloader(data_root,
                       train=True,
                       batch_size = 512):
    """Build a CIFAR-10 dataloader (downloads the dataset on first use).

    Args:
        data_root: directory the dataset is stored in / downloaded to.
        train: True -> training loader with flip/affine augmentation and a
            DistributedSampler (requires an initialized process group);
            False -> plain test loader.
        batch_size: batch size — applied to the TRAINING loader only.

    Returns:
        (dataloader, sampler) for training, or (dataloader, None) for test.
    """
    if train:
        train_transform = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAffine(degrees=15, translate=(0.1,0.1)),
            T.ToTensor(),
            # Map pixel values to roughly [-1, 1].
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = CIFAR10(data_root, train=True, transform=train_transform, download=True)
        sampler = DistributedSampler(train_dataset)
        # Sampler handles shuffling/sharding, so shuffle must stay False.
        train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size, sampler=sampler)
        return train_dataloader, sampler
    else:
        test_transform = T.Compose([
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        test_dataset = CIFAR10(data_root, train=False, transform=test_transform, download=True)
        # NOTE(review): the batch_size parameter is IGNORED here — 16 is
        # hard-coded, and the TensorRT eval path appears to rely on batch 16.
        # Confirm before honoring the parameter on this branch.
        test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=16)
        return test_dataloader, None
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import numpy as np
class HostDeviceMem(object):
    """Couples a host-side buffer with its matching device allocation."""

    def __init__(self, host_mem, device_mem):
        # Keep both halves of a host<->device transfer pair together.
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}"

    def __repr__(self):
        return self.__str__()
class TrtModel:
    """Minimal TensorRT inference wrapper: deserializes an engine, allocates
    pinned host + device buffers per binding, and runs async execution.

    NOTE(review): uses the pre-TensorRT-8.5 binding API (get_binding_shape,
    binding_is_input, execute_async) — confirm against the installed TensorRT
    version before upgrading.
    """
    def __init__(self,
                 engine_path,
                 max_batch_size=1,
                 dtype=np.float32):
        # engine_path: serialized .trt/.engine file on disk.
        # max_batch_size: multiplier for buffer sizing (see allocate_buffers).
        # dtype: numpy dtype used for all host buffers.
        self.engine_path = engine_path
        self.dtype = dtype
        self.logger = trt.Logger(trt.Logger.ERROR)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine(self.runtime, self.engine_path)
        self.max_batch_size = max_batch_size
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
        self.context = self.engine.create_execution_context()
    @staticmethod
    def load_engine(trt_runtime, engine_path):
        """Deserialize a CUDA engine from `engine_path` (plugins initialized first)."""
        trt.init_libnvinfer_plugins(None, "")
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine
    def allocate_buffers(self):
        """Allocate one pinned host buffer + device buffer per engine binding.

        Returns (inputs, outputs, bindings, stream) where inputs/outputs are
        HostDeviceMem pairs and bindings holds the raw device pointers in
        engine-binding order.
        """
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in self.engine:
            # Buffer element count = binding volume * max_batch_size.
            # NOTE(review): for an engine built with an explicit batch the
            # binding shape already includes the batch dimension — confirm
            # max_batch_size=1 is intended in that case.
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.max_batch_size
            host_mem = cuda.pagelocked_empty(size, self.dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(device_mem))
            if self.engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream
    def __call__(self,
                 x,
                 batch_size=2):
        """Run inference on numpy input `x`.

        Copies x into the (single) input buffer, executes asynchronously on
        the internal stream, and returns a list of output arrays, each
        reshaped to (batch_size, -1). `x` must fit the input buffer exactly.
        """
        x = x.astype(self.dtype)
        np.copyto(self.inputs[0].host, x.ravel())
        for inp in self.inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
        self.context.execute_async(batch_size=batch_size, bindings=self.bindings, stream_handle=self.stream.handle)
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        # Block until all async copies/execution on the stream finish.
        self.stream.synchronize()
        return [out.host.reshape(batch_size,-1) for out in self.outputs]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment