"src/nni_manager/vscode:/vscode.git/clone" did not exist on "77dac12baee6c3243445d71cd1eb812d7f73c7a7"
Commit 5c88a35d authored by liucong's avatar liucong
Browse files

提交migraphx推理方法

parent 229a0d76
......@@ -31,20 +31,19 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
cuda 11
pip install -r requirements.txt
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization
2、TensorRT
wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip
unzip [下载的压缩包] -d [解压路径]
pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl
ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec
注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。
## 数据集
......@@ -60,9 +59,12 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
## 推理
# N卡推理
trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8
python eval.py --device=0
# DCU卡推理
python evaluate_migraphx.py --device=0
## result
......@@ -70,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
### 精度
||原始模型|QAT模型|ONNX模型|TensorRT模型|
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|
||原始模型|QAT模型|ONNX模型|TensorRT模型|MIGraphX模型|
|:---|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|6.7672s|
## 应用场景
......
import argparse
import numpy as np
import migraphx
import torch
import time
from tqdm import tqdm
from utils.data import prepare_dataloader
def eval_migraphx(onnx_path, dataloader, device):
    """Evaluate an ONNX model with MIGraphX on a GPU/DCU device.

    Runs two full passes over ``dataloader``: the first warms up the
    device, the second is timed and scored.

    Args:
        onnx_path: Path to the ONNX model file.
        dataloader: Iterable yielding ``(data, label)`` torch tensor batches.
        device: Integer device id passed to the MIGraphX compiler.

    Returns:
        Tuple ``(accuracy, elapsed_seconds)`` measured on the timed pass.
    """
    # Parse the ONNX model into a MIGraphX program.
    model = migraphx.parse_onnx(onnx_path)
    # First parameter name is assumed to be the model's input tensor
    # — TODO confirm for multi-input models.
    input_name = model.get_parameter_names()[0]
    # Compile for the GPU target on the requested device.
    model.compile(t=migraphx.get_target("gpu"), device_id=device)
    correct, total = 0, 0
    for it in range(2):
        if it == 1:
            # Reset counters so accuracy reflects only the timed pass
            # (previously warmup samples were double-counted).
            correct, total = 0, 0
            start_time = time.time()
            desc = "eval onnx model"
        else:
            desc = "warmup"
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            data = data.numpy().astype(np.float32)
            label = label.numpy().astype(np.float32)
            results = model.run({input_name: data})
            # results[0] is the first output buffer; argmax over class dim.
            predictions = np.argmax(results[0], axis=-1)
            correct += (label == predictions).sum()
            total += len(label)
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time
def main(args):
    """Evaluate the QAT-calibrated ONNX model with MIGraphX on CIFAR-10.

    Args:
        args: Parsed CLI namespace with ``batch_size`` and ``device``
            (integer device id, ``-1`` for CPU-style default).
    """
    # Only the test split is used (second positional arg False = not train).
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size)
    # Evaluate the calibrated QAT ONNX model; MIGraphX takes the raw
    # integer device id directly (the old torch.device object was unused).
    acc_onnx, runtime_onnx = eval_migraphx("./checkpoints/calibrated/pretrained_qat.onnx", test_dataloader, args.device)
    print("==============================================================")
    print(f"MIGraphX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print("==============================================================")
if __name__ == '__main__':
    # argparse is already imported at the top of the file; the duplicate
    # local import was removed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    # Device id forwarded to the MIGraphX compiler.
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment