Commit 0c8e6668 authored by chenxj's avatar chenxj
Browse files

add eval onnx support

parent 291d1779
......@@ -58,6 +58,15 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/eval.py -c configs/
```
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/eval.py -c configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml -o Global.pretrained_model=./output/v3_en_mobile/best_accuracy.pdparams
```
### 测试(ort)
检测模型
```
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/eval.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./ch_PP-OCRv3_det_infer/ch_PP-OCRv3_det.onnx --use_onnx=true
```
识别模型
```
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml -o Global.pretrained_model=./ch_PP-OCRv3_rec_infer/ch_PP-OCRv3_rec.onnx --use_onnx=true
```
### 推理
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/" --det_model_dir="./ch_PP-OCRv3_det_infer/" --rec_model_dir="./ch_PP-OCRv3_rec_infer/" --use_angle_cls=false --rec_image_shape=3,48,320 --warmup=1
......@@ -77,6 +86,16 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/" --det_model_dir=
| Model | Acc |
| :------: | :------: |
| rec | 0.6490 |
检测模型测试(ort)
| Model | Precision | Recall |
| :------: | :------: |:------: |
| det | 0.5097 | 0.4068 |
识别模型测试(ort)
| Model | Acc |
| :------: | :------: |
| rec | 0.6076 |
## 源码仓库及问题反馈
https://developer.hpccube.com/codes/modelzoo/paddleocr
## 参考
......
......@@ -15,6 +15,7 @@ from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
import tools.program as program
from onnxruntime import InferenceSession
def main():
......@@ -58,6 +59,19 @@ def main():
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
if args.use_onnx:
pretrained_model = global_config.get('pretrained_model')
print("pretrained_model:", pretrained_model)
model = InferenceSession(pretrained_model, providers=[('ROCMExecutionProvider', {'device_id': '4'}),'CPUExecutionProvider'])
# build metric
eval_class = build_metric(config['Metric'])
# start eval
metric = program.eval_onnx(model, valid_dataloader, post_process_class,
eval_class)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
else:
model = build_model(config['Architecture'])
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
......@@ -90,5 +104,5 @@ def main():
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
config, device, logger, vdl_writer, args = program.preprocess()
main()
......@@ -34,6 +34,11 @@ from ppocr.utils.logging import get_logger
from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
from ppocr.utils import profiler
from ppocr.data import build_dataloader
import numpy as np
def str2bool(v):
return v.lower() in ("true", "t", "1")
class ArgsParser(ArgumentParser):
......@@ -51,6 +56,7 @@ class ArgsParser(ArgumentParser):
help='The option of profiler, which should be in format ' \
'\"key1=value1;key2=value2;key3=value3\".'
)
self.add_argument("--use_onnx", type=str2bool, default=False)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
......@@ -482,6 +488,53 @@ def eval(model,
return metric
def eval_onnx(model,
valid_dataloader,
post_process_class,
eval_class):
total_frame = 0.0
total_time = 0.0
pbar = tqdm(
total=len(valid_dataloader),
desc='eval model:',
position=0,
leave=True)
max_iter = len(valid_dataloader) - 1 if platform.system(
) == "Windows" else len(valid_dataloader)
input_name = model.get_inputs()[0].name
for idx, batch in enumerate(valid_dataloader):
if idx >= max_iter:
break
images = batch[0]
start = time.time()
images = np.array(images)
input = {input_name:images}
preds = model.run(None, input_feed=input)
batch_numpy = []
for item in batch:
batch_numpy.append(item.numpy())
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
preds = preds[0]
if eval_class.main_indicator == 'hmean':
onnx_preds = {'maps':preds}
else:
onnx_preds = preds
post_result = post_process_class(onnx_preds, batch_numpy[1])
eval_class(post_result, batch_numpy)
pbar.update(1)
total_frame += len(images)
# Get final metric,eg. acc or hmean
metric = eval_class.get_metric()
pbar.close()
metric['fps'] = total_frame / total_time
return metric
def update_center(char_center, post_result, preds):
result, label = post_result
feats, logits = preds
......@@ -607,4 +660,4 @@ def preprocess(is_train=False):
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, device, logger, log_writer
return config, device, logger, log_writer, FLAGS
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment