Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
vgg16-qat_pytorch
Commits
f2acf41b
Commit
f2acf41b
authored
Mar 02, 2024
by
mashun1
Browse files
fix eval
parent
5c88a35d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
10 deletions
+40
-10
README.md
README.md
+2
-2
evaluate.py
evaluate.py
+38
-8
No files found.
README.md
View file @
f2acf41b
...
@@ -72,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
...
@@ -72,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
### 精度
### 精度
||原始模型
|QAT模型
|ONNX模型|TensorRT模型|MIGraphX模型|
||原始模型
(A800)|QAT模型(A800)
|ONNX模型
(A800)
|TensorRT模型
(A800)
|MIGraphX模型|
|:---|:---|:---|:---|:---|----|
|:---|:---|:---|:---|:---|----|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|推理时间|
5.5764
s|1
3
.7
60
3s|
4.2848s|2.9893
s|6.7672s|
|推理时间|
2.2469
s|1
0
.7
95
3s|
1.3253s|0.2368
s|6.7672s|
## 应用场景
## 应用场景
...
...
evaluate.py
View file @
f2acf41b
...
@@ -17,8 +17,37 @@ import numpy as np
...
@@ -17,8 +17,37 @@ import numpy as np
import
pycuda.driver
as
cuda
import
pycuda.driver
as
cuda
from
pytorch_quantization
import
quant_modules
from
pytorch_quantization
import
quant_modules
from
torch.utils.data
import
DataLoader
,
Dataset
class
NumpyDataLoader
:
def
__init__
(
self
,
dataloader
):
self
.
data
=
[]
for
data
,
label
in
dataloader
:
self
.
data
.
append
((
data
.
numpy
().
astype
(
np
.
float32
),
label
.
numpy
().
astype
(
np
.
float32
)))
def
__len__
(
self
):
return
len
(
self
.
data
)
def
__getitem__
(
self
,
idx
):
return
self
.
data
[
idx
]
class
CacheDataLoader
:
def
__init__
(
self
,
dataloader
):
self
.
data
=
[]
for
data
,
label
in
dataloader
:
self
.
data
.
append
((
data
,
label
))
def
__len__
(
self
):
return
len
(
self
.
data
)
def
__getitem__
(
self
,
idx
):
return
self
.
data
[
idx
]
def
eval_onnx
(
ckpt_path
,
dataloader
,
device
):
def
eval_onnx
(
ckpt_path
,
dataloader
,
device
):
sess_options
=
onnxruntime
.
SessionOptions
()
sess_options
=
onnxruntime
.
SessionOptions
()
...
@@ -42,7 +71,6 @@ def eval_onnx(ckpt_path, dataloader, device):
...
@@ -42,7 +71,6 @@ def eval_onnx(ckpt_path, dataloader, device):
desc
=
"eval onnx model"
desc
=
"eval onnx model"
for
data
,
label
in
tqdm
(
dataloader
,
desc
=
desc
,
total
=
len
(
dataloader
)):
for
data
,
label
in
tqdm
(
dataloader
,
desc
=
desc
,
total
=
len
(
dataloader
)):
data
,
label
=
data
.
numpy
().
astype
(
np
.
float32
),
label
.
numpy
().
astype
(
np
.
float32
)
output
=
session
.
run
([
output_name
],
{
input_name
:
data
})
output
=
session
.
run
([
output_name
],
{
input_name
:
data
})
predictions
=
np
.
argmax
(
output
,
axis
=-
1
)[
0
]
predictions
=
np
.
argmax
(
output
,
axis
=-
1
)[
0
]
...
@@ -73,10 +101,8 @@ def eval_trt(ckpt_path, dataloader, device):
...
@@ -73,10 +101,8 @@ def eval_trt(ckpt_path, dataloader, device):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
for
data
,
label
in
tqdm
(
dataloader
,
desc
=
desc
,
total
=
(
len
(
dataloader
))):
for
data
,
label
in
tqdm
(
dataloader
,
desc
=
desc
,
total
=
(
len
(
dataloader
))):
data
=
data
.
numpy
()
result
=
model
(
data
,
batch_size
)
result
=
model
(
data
,
batch_size
)
result
=
np
.
argmax
(
result
,
axis
=-
1
)
result
=
np
.
argmax
(
result
,
axis
=-
1
)
label
=
label
.
numpy
()
total
+=
label
.
shape
[
0
]
total
+=
label
.
shape
[
0
]
correct
+=
(
label
==
result
).
sum
()
correct
+=
(
label
==
result
).
sum
()
...
@@ -147,16 +173,20 @@ def eval_qat(ckpt_path, dataloader, num_classes, device):
...
@@ -147,16 +173,20 @@ def eval_qat(ckpt_path, dataloader, num_classes, device):
def
main
(
args
):
def
main
(
args
):
device
=
torch
.
device
(
f
"cuda:
{
args
.
device
}
"
if
args
.
device
!=
-
1
else
"cpu"
)
device
=
torch
.
device
(
f
"cuda:
{
args
.
device
}
"
if
args
.
device
!=
-
1
else
"cpu"
)
test_dataloader
,
_
=
prepare_dataloader
(
"./data/cifar10"
,
False
,
args
.
batch_size
)
test_dataloader
,
_
=
prepare_dataloader
(
"./data/cifar10"
,
False
,
1
)
numpy_dataloader
=
NumpyDataLoader
(
test_dataloader
)
cache_dataloader
=
CacheDataLoader
(
test_dataloader
)
# 测试pytorch模型
# 测试pytorch模型
acc1
,
runtime1
=
eval_original
(
"./checkpoints/pretrained/pretrained_model.pth"
,
test
_dataloader
,
args
.
num_classes
,
device
)
acc1
,
runtime1
=
eval_original
(
"./checkpoints/pretrained/pretrained_model.pth"
,
cache
_dataloader
,
args
.
num_classes
,
device
)
acc2
,
runtime2
=
eval_qat
(
"./checkpoints/
calibrated
/pretrained_model.pth"
,
test
_dataloader
,
args
.
num_classes
,
device
)
acc2
,
runtime2
=
eval_qat
(
"./checkpoints/
qat
/pretrained_model.pth"
,
cache
_dataloader
,
args
.
num_classes
,
device
)
acc_onnx
,
runtime_onnx
=
eval_onnx
(
"./checkpoints/
calibrated
/pretrained_qat.onnx"
,
test
_dataloader
,
args
.
device
)
acc_onnx
,
runtime_onnx
=
eval_onnx
(
"./checkpoints/
qat
/pretrained_qat.onnx"
,
numpy
_dataloader
,
args
.
device
)
acc_trt
,
runtime_trt
=
eval_trt
(
"./checkpoints/
calibrated
/last.trt"
,
test
_dataloader
,
args
.
device
)
acc_trt
,
runtime_trt
=
eval_trt
(
"./checkpoints/
qat
/last.trt"
,
numpy
_dataloader
,
args
.
device
)
print
(
"=============================================================="
)
print
(
"=============================================================="
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment