Commit f2acf41b authored by mashun1's avatar mashun1
Browse files

fix eval

parent 5c88a35d
...@@ -72,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中 ...@@ -72,10 +72,10 @@ VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中
### 精度 ### 精度
||原始模型|QAT模型|ONNX模型|TensorRT模型|MIGraphX模型| ||原始模型(A800)|QAT模型(A800)|ONNX模型(A800)|TensorRT模型(A800)|MIGraphX模型|
|:---|:---|:---|:---|:---|----| |:---|:---|:---|:---|:---|----|
|Acc|0.9189|0.9185|0.9181|0.9184|0.919| |Acc|0.9189|0.9185|0.9181|0.9184|0.919|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|6.7672s| |推理时间|2.2469s|10.7953s|1.3253s|0.2368s|6.7672s|
## 应用场景 ## 应用场景
......
...@@ -17,8 +17,37 @@ import numpy as np ...@@ -17,8 +17,37 @@ import numpy as np
import pycuda.driver as cuda import pycuda.driver as cuda
from pytorch_quantization import quant_modules from pytorch_quantization import quant_modules
from torch.utils.data import DataLoader, Dataset
class NumpyDataLoader:
    """In-memory, list-backed copy of a dataloader with float32 NumPy batches.

    The wrapped ``dataloader`` is iterated exactly once, at construction
    time. Every ``(data, label)`` pair it yields is converted with
    ``.numpy().astype(np.float32)`` and stored, so later loops over this
    object only read already-converted arrays.

    NOTE(review): presumably this exists so that tensor-to-NumPy conversion
    is not counted inside timed evaluation loops -- confirm with the callers.

    Args:
        dataloader: Any iterable yielding ``(data, label)`` pairs in which
            both elements expose a ``.numpy()`` method.
    """

    def __init__(self, dataloader):
        # Labels are cast to float32 as well as the data; the stored pairs
        # are tuples regardless of the container type the source yields.
        self.data = [
            (data.numpy().astype(np.float32), label.numpy().astype(np.float32))
            for data, label in dataloader
        ]

    def __len__(self):
        """Return the number of cached ``(data, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the cached ``(data, label)`` pair stored at ``idx``."""
        return self.data[idx]

    def __iter__(self):
        """Iterate over the cached pairs in their original order.

        Defined explicitly so that ``for`` loops do not depend on the legacy
        ``__getitem__``/``IndexError`` iteration fallback.
        """
        return iter(self.data)
class CacheDataLoader:
    """In-memory, list-backed copy of a dataloader's ``(data, label)`` pairs.

    The wrapped ``dataloader`` is iterated exactly once, at construction
    time, and each yielded pair is stored unchanged (no type conversion).
    Later loops over this object read from the stored list instead of
    re-running the underlying loader.

    NOTE(review): presumably this exists so that data-loading cost is not
    counted inside timed evaluation loops -- confirm with the callers.

    Args:
        dataloader: Any iterable yielding two-element ``(data, label)`` items.
    """

    def __init__(self, dataloader):
        # Unpack and re-pack so every stored item is a tuple, whatever
        # container type the source iterable yields.
        self.data = [(data, label) for data, label in dataloader]

    def __len__(self):
        """Return the number of cached ``(data, label)`` pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the cached ``(data, label)`` pair stored at ``idx``."""
        return self.data[idx]

    def __iter__(self):
        """Iterate over the cached pairs in their original order.

        Defined explicitly so that ``for`` loops do not depend on the legacy
        ``__getitem__``/``IndexError`` iteration fallback.
        """
        return iter(self.data)
def eval_onnx(ckpt_path, dataloader, device): def eval_onnx(ckpt_path, dataloader, device):
sess_options = onnxruntime.SessionOptions() sess_options = onnxruntime.SessionOptions()
...@@ -42,7 +71,6 @@ def eval_onnx(ckpt_path, dataloader, device): ...@@ -42,7 +71,6 @@ def eval_onnx(ckpt_path, dataloader, device):
desc = "eval onnx model" desc = "eval onnx model"
for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)): for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
data, label = data.numpy().astype(np.float32), label.numpy().astype(np.float32)
output = session.run([output_name], {input_name: data}) output = session.run([output_name], {input_name: data})
predictions = np.argmax(output, axis=-1)[0] predictions = np.argmax(output, axis=-1)[0]
...@@ -73,10 +101,8 @@ def eval_trt(ckpt_path, dataloader, device): ...@@ -73,10 +101,8 @@ def eval_trt(ckpt_path, dataloader, device):
start_time = time.time() start_time = time.time()
for data, label in tqdm(dataloader, desc=desc, total=(len(dataloader))): for data, label in tqdm(dataloader, desc=desc, total=(len(dataloader))):
data = data.numpy()
result = model(data, batch_size) result = model(data, batch_size)
result = np.argmax(result, axis=-1) result = np.argmax(result, axis=-1)
label = label.numpy()
total += label.shape[0] total += label.shape[0]
correct += (label == result).sum() correct += (label == result).sum()
...@@ -147,16 +173,20 @@ def eval_qat(ckpt_path, dataloader, num_classes, device): ...@@ -147,16 +173,20 @@ def eval_qat(ckpt_path, dataloader, num_classes, device):
def main(args): def main(args):
device = torch.device(f"cuda:{args.device}" if args.device != -1 else "cpu") device = torch.device(f"cuda:{args.device}" if args.device != -1 else "cpu")
test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size) test_dataloader, _ = prepare_dataloader("./data/cifar10", False, 1)
numpy_dataloader = NumpyDataLoader(test_dataloader)
cache_dataloader = CacheDataLoader(test_dataloader)
# 测试pytorch模型 # 测试pytorch模型
acc1, runtime1 = eval_original("./checkpoints/pretrained/pretrained_model.pth", test_dataloader, args.num_classes, device) acc1, runtime1 = eval_original("./checkpoints/pretrained/pretrained_model.pth", cache_dataloader, args.num_classes, device)
acc2, runtime2 = eval_qat("./checkpoints/calibrated/pretrained_model.pth", test_dataloader, args.num_classes, device) acc2, runtime2 = eval_qat("./checkpoints/qat/pretrained_model.pth", cache_dataloader, args.num_classes, device)
acc_onnx, runtime_onnx = eval_onnx("./checkpoints/calibrated/pretrained_qat.onnx", test_dataloader, args.device) acc_onnx, runtime_onnx = eval_onnx("./checkpoints/qat/pretrained_qat.onnx", numpy_dataloader, args.device)
acc_trt, runtime_trt = eval_trt("./checkpoints/calibrated/last.trt", test_dataloader, args.device) acc_trt, runtime_trt = eval_trt("./checkpoints/qat/last.trt", numpy_dataloader, args.device)
print("==============================================================") print("==============================================================")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment