from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, __dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model import tools.program as program from onnxruntime import InferenceSession def main(): global_config = config['Global'] # build dataloader valid_dataloader = build_dataloader(config, 'Eval', device, logger) # build post process post_process_class = build_post_process(config['PostProcess'], global_config) # build model # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: if config['Architecture']['Models'][key]['Head'][ 'name'] == 'MultiHead': # for multi head out_channels_list = {} if config['PostProcess'][ 'name'] == 'DistillationSARLabelDecode': char_num = char_num - 2 out_channels_list['CTCLabelDecode'] = char_num out_channels_list['SARLabelDecode'] = char_num + 2 config['Architecture']['Models'][key]['Head'][ 'out_channels_list'] = out_channels_list else: config['Architecture']["Models"][key]["Head"][ 'out_channels'] = char_num elif config['Architecture']['Head'][ 'name'] == 'MultiHead': # for multi head out_channels_list = {} if config['PostProcess']['name'] == 'SARLabelDecode': char_num = char_num - 2 out_channels_list['CTCLabelDecode'] = char_num out_channels_list['SARLabelDecode'] = char_num + 2 config['Architecture']['Head'][ 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num if args.use_onnx: pretrained_model = global_config.get('pretrained_model') print("pretrained_model:", pretrained_model) model = InferenceSession(pretrained_model, providers=[('ROCMExecutionProvider', {'device_id': '4'}),'CPUExecutionProvider']) # build metric eval_class = build_metric(config['Metric']) # start eval metric = program.eval_onnx(model, valid_dataloader, post_process_class, eval_class) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) else: model = build_model(config['Architecture']) extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': for key in config['Architecture']["Models"]: extra_input = extra_input or config['Architecture']['Models'][key][ 'algorithm'] in extra_input_models else: extra_input = config['Architecture']['algorithm'] in extra_input_models if "model_type" in config['Architecture'].keys(): model_type = config['Architecture']['model_type'] else: model_type = None best_model_dict = load_model( config, model, model_type=config['Architecture']["model_type"]) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): logger.info('{}:{}'.format(k, v)) # build metric eval_class = build_metric(config['Metric']) # start eval metric = program.eval(model, valid_dataloader, post_process_class, eval_class, model_type, extra_input) logger.info('metric eval ***************') for k, v in metric.items(): logger.info('{}:{}'.format(k, v)) if __name__ == '__main__': config, device, logger, vdl_writer, args = program.preprocess() main()