from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import json

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

os.environ["FLAGS_allocator_strategy"] = "auto_growth"

import paddle
import paddle.profiler as profiler  # performance profiler module

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program


def main():
    global_config = config["Global"]

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"]
                    == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True

    model = build_model(config["Architecture"])

    load_model(config, model)

    # create data ops
    transforms = []
    for op in config["Eval"]["dataset"]["transforms"]:
        op_name = list(op)[0]
        if "Label" in op_name:
            continue
        elif op_name in ["RecResizeImg"]:
            op[op_name]["infer_mode"] = True
        elif op_name == "KeepKeys":
            if config["Architecture"]["algorithm"] == "SRN":
                op[op_name]["keep_keys"] = [
                    "image",
                    "encoder_word_pos",
                    "gsrm_word_pos",
                    "gsrm_slf_attn_bias1",
                    "gsrm_slf_attn_bias2",
                ]
            elif config["Architecture"]["algorithm"] == "SAR":
                op[op_name]["keep_keys"] = ["image", "valid_ratio"]
            elif config["Architecture"]["algorithm"] == "RobustScanner":
                op[op_name]["keep_keys"] = ["image", "valid_ratio", "word_positons"]
            else:
                op[op_name]["keep_keys"] = ["image"]
        transforms.append(op)
    global_config["infer_mode"] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config["Global"].get(
        "save_res_path", "./output/rec/predicts_rec.txt"
    )
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    # Profiler-related code: export a Chrome trace and dump the summary
    # tables once tracing finishes.
    def my_on_trace_ready(prof):
        callback = profiler.export_chrome_tracing("./profiler_demo")
        callback(prof)
        # Save the profiler summary tables (overview, operator detail, etc.)
        # to a file. Depending on the Paddle version, summary() may print its
        # tables to stdout and return None, hence the fallback below.
        with open("./profiler_summary.txt", "w") as f:
            f.write("Profiler Summary:\n")
            summary = prof.summary(
                sorted_by=profiler.SortedKeys.GPUTotal,
                op_detail=True,
                thread_sep=True,
                time_unit="ms",
            )
            if summary is not None:
                f.write(summary)
            else:
                f.write("No summary string returned; see stdout for the tables.\n")
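    # Note (added): profiling every step with timer_only=False can produce a
    # very large trace for long image lists. Assuming a Paddle version that
    # provides paddle.profiler.make_scheduler (2.3 or later), collection
    # could instead be limited to a window of steps by passing a scheduler
    # when the Profiler is constructed below, e.g.:
    #
    #     p = profiler.Profiler(
    #         scheduler=profiler.make_scheduler(closed=1, ready=1, record=4),
    #         on_trace_ready=my_on_trace_ready,
    #         timer_only=False,
    #     )
    #
    # The call below keeps the simpler behaviour and profiles every step.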
    # Create the Profiler with timer_only=False so that detailed host/device
    # event information is collected rather than timing alone.
    p = profiler.Profiler(on_trace_ready=my_on_trace_ready, timer_only=False)
    p.start()

    infer_imgs = config["Global"]["infer_img"]
    infer_list = config["Global"].get("infer_list", None)
    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(infer_imgs, infer_list=infer_list):
            logger.info("infer_img: {}".format(file))
            with open(file, "rb") as f:
                img = f.read()
            if config["Architecture"]["algorithm"] in [
                "UniMERNet",
                "PP-FormulaNet-S",
                "PP-FormulaNet-L",
            ]:
                data = {"image": img, "filename": file}
            else:
                data = {"image": img}
            batch = transform(data, ops)
            if config["Architecture"]["algorithm"] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list),
                ]
            if config["Architecture"]["algorithm"] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
            if config["Architecture"]["algorithm"] == "RobustScanner":
                valid_ratio = np.expand_dims(batch[1], axis=0)
                word_positons = np.expand_dims(batch[2], axis=0)
                img_metas = [
                    paddle.to_tensor(valid_ratio),
                    paddle.to_tensor(word_positons),
                ]
            if config["Architecture"]["algorithm"] == "CAN":
                image_mask = paddle.ones(
                    np.expand_dims(batch[0], axis=0).shape, dtype="float32"
                )
                label = paddle.ones((1, 36), dtype="int64")
            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config["Architecture"]["algorithm"] == "SRN":
                preds = model(images, others)
            elif config["Architecture"]["algorithm"] in ["SAR", "RobustScanner"]:
                preds = model(images, img_metas)
            elif config["Architecture"]["algorithm"] == "CAN":
                preds = model([images, image_mask, label])
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info, ensure_ascii=False)
            elif isinstance(post_result, list) and isinstance(post_result[0], int):
                # for RFLearning CNT branch
                info = str(post_result[0])
            elif config["Architecture"]["algorithm"] in [
                "LaTeXOCR",
                "UniMERNet",
                "PP-FormulaNet-S",
                "PP-FormulaNet-L",
            ]:
                info = str(post_result[0])
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])
            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(file + "\t" + info + "\n")
            p.step()  # advance the profiler by one step after each inference
    p.stop()  # stop the profiler; this also triggers my_on_trace_ready
    logger.info("success!")


if __name__ == "__main__":
    config, device, logger, vdl_writer = program.preprocess()
    main()
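# Example invocation (added; the config, weights, and image paths are
# illustrative, so substitute your own):
#
#   python tools/infer_rec.py \
#       -c configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml \
#       -o Global.pretrained_model=./output/rec/best_accuracy \
#          Global.infer_img=./doc/imgs_words/en/word_1.png
#
# The Chrome trace is written under ./profiler_demo (viewable at
# chrome://tracing or https://ui.perfetto.dev) and the summary tables are
# written to ./profiler_summary.txt.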