Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
vgg16-qat_pytorch
Commits
5c88a35d
"src/nni_manager/vscode:/vscode.git/clone" did not exist on "77dac12baee6c3243445d71cd1eb812d7f73c7a7"
Commit
5c88a35d
authored
Mar 01, 2024
by
liucong
Browse files
提交migraphx推理方法
parent
229a0d76
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
10 deletions
+81
-10
README.md
README.md
+12
-10
evaluate_migraphx.py
evaluate_migraphx.py
+69
-0
No files found.
README.md
View file @
5c88a35d
...
...
@@ -31,20 +31,19 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
cuda 11
pip install -r requirements.txt
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization
2、TensorRT
wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip
unzip [下载的压缩包] -d [解压路径]
pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl
ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec
注意:若需要
`cu12`
则将
`requirements.txt`
中的相关注释关闭,并安装。
## 数据集
...
...
@@ -60,9 +59,12 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
## 推理
# N卡推理
trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8
python eval.py --device=0
# DCU卡推理
python evaluate_migraphx.py --device=0
## result
...
...
@@ -70,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
### 精度
||原始模型|QAT模型|ONNX模型|TensorRT模型|
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|
||原始模型|QAT模型|ONNX模型|TensorRT模型|MIGraphX模型|
|:---|:---|:---|:---|:---|:----|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|6.7672s|
## 应用场景
...
...
evaluate_migraphx.py
0 → 100644
View file @
5c88a35d
import
argparse
import
numpy
as
np
import
migraphx
import
torch
import
time
from
tqdm
import
tqdm
from
utils.data
import
prepare_dataloader
def eval_migraphx(onnx_path, dataloader, device):
    """Evaluate an ONNX model with MIGraphX on a GPU/DCU device.

    Performs two passes over *dataloader*: the first warms up the
    compiled model, the second is timed and scored.

    Args:
        onnx_path: Path to the ONNX model file to load.
        dataloader: Iterable yielding ``(data, label)`` torch tensor
            batches (converted to float32 numpy arrays before inference).
        device: Integer device id passed to ``model.compile``.

    Returns:
        Tuple ``(accuracy, elapsed_seconds)`` measured on the second
        (post-warmup) pass.
    """
    # Parse and compile the model for the GPU target.
    model = migraphx.parse_onnx(onnx_path)
    # First parameter name is assumed to be the model's data input
    # (single-input model) — TODO confirm for multi-input models.
    input_name = model.get_parameter_names()[0]
    model.compile(t=migraphx.get_target("gpu"), device_id=device)

    correct, total = 0, 0
    for it in range(2):
        if it == 1:
            # Timed pass: reset counters so the reported accuracy and
            # time reflect only the measured run.  (The original code
            # accumulated warmup statistics too; the ratio is unchanged
            # since both passes see the same dataset, but the counts
            # were doubled.)
            correct, total = 0, 0
            start_time = time.time()
            desc = "eval onnx model"
        else:
            desc = "warmup"
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            data = data.numpy().astype(np.float32)
            label = label.numpy().astype(np.float32)
            results = model.run({input_name: data})
            # results[0] holds the class logits; argmax over the last
            # axis yields per-sample predictions.
            predictions = np.argmax(results[0], axis=-1)
            correct += (label == predictions).sum()
            total += len(label)
        if it == 1:
            end_time = time.time()
    return correct / total, end_time - start_time
def main(args):
    """Load the CIFAR-10 test split and report MIGraphX accuracy/time.

    Args:
        args: Parsed CLI namespace with ``device`` (int device id, -1
            for CPU) and ``batch_size`` attributes.
    """
    # NOTE(review): the original built a torch.device here but never
    # used it — eval_migraphx only needs the raw integer device id, so
    # the dead local has been removed.
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size)
    # Evaluate the quantization-aware-trained ONNX model via MIGraphX.
    acc_onnx, runtime_onnx = eval_migraphx(
        "./checkpoints/calibrated/pretrained_qat.onnx",
        test_dataloader,
        args.device,
    )
    print("==============================================================")
    print(f"MIGraphX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print("==============================================================")
if __name__ == '__main__':
    # argparse is already imported at module top; the duplicate
    # function-local import has been removed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    # -1 selects CPU; any other value is a GPU/DCU device id.
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment