Commit 5c88a35d authored by liucong's avatar liucong
Browse files

提交migraphx推理方法

parent 229a0d76
......@@ -44,7 +44,6 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec
注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。
## 数据集
......@@ -60,20 +59,23 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
## 推理
# N卡推理
trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8
python eval.py --device=0
# DCU卡推理
python evaluate_migraphx.py --device=0
## result
![alt text](readme_imgs/image-3.png)
### 精度
||原始模型|QAT模型|ONNX模型|TensorRT模型|
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|
||原始模型|QAT模型|ONNX模型|TensorRT模型|MIGraphX模型|
|:---|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|6.7672s|
## 应用场景
......
import argparse
import numpy as np
import migraphx
import torch
import time
from tqdm import tqdm
from utils.data import prepare_dataloader
def eval_migraphx(onnx_path, dataloader, device):
    """Compile an ONNX model with MIGraphX and evaluate it on a dataloader.

    Performs one untimed warmup pass over the full dataset, then one timed
    pass during which accuracy and wall-clock inference time are recorded.

    Args:
        onnx_path: Path to the ONNX model file to parse.
        dataloader: Iterable yielding ``(data, label)`` torch tensor batches.
        device: GPU device id passed to the MIGraphX compiler
            (``device_id``); callers should pass a valid non-negative id.

    Returns:
        Tuple ``(accuracy, elapsed_seconds)`` measured on the timed pass.
    """
    # Parse the ONNX graph and compile it for the GPU target.
    model = migraphx.parse_onnx(onnx_path)
    # First parameter name is the model's input tensor name.
    input_name = model.get_parameter_names()[0]
    model.compile(t=migraphx.get_target("gpu"), device_id=device)

    correct, total = 0, 0
    start_time = end_time = 0.0
    for it in range(2):
        desc = "warmup"
        if it == 1:
            start_time = time.time()
            desc = "eval migraphx model"
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            data = data.numpy().astype(np.float32)
            label = label.numpy().astype(np.float32)
            results = model.run({input_name: data})
            # results[0] holds the class logits; argmax over the last axis
            # yields the predicted class index per sample.
            predictions = np.argmax(results[0], axis=-1)
            # Only count statistics on the timed pass so the warmup
            # iteration does not double-count every sample.
            if it == 1:
                correct += (label == predictions).sum()
                total += len(label)
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time
def main(args):
    """Evaluate the calibrated QAT ONNX model with MIGraphX on CIFAR-10.

    Args:
        args: Parsed CLI namespace providing ``batch_size`` (loader batch
            size) and ``device`` (GPU id forwarded to MIGraphX).
    """
    # NOTE(review): the original built an unused torch.device here; removed
    # as dead code — eval_migraphx receives the raw integer id directly.
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size)
    # Evaluate the quantization-aware-trained ONNX model via MIGraphX.
    acc_onnx, runtime_onnx = eval_migraphx(
        "./checkpoints/calibrated/pretrained_qat.onnx", test_dataloader, args.device
    )
    print("==============================================================")
    print(f"MIGraphX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print("==============================================================")
if __name__ == '__main__':
    # argparse is already imported at module level; the redundant local
    # re-import has been removed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment