Commit 466992ff authored by liucong's avatar liucong
Browse files

修改paddleocr工程格式

parent eb66b288
......@@ -28,13 +28,10 @@ class NormalizeImage(object):
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
def __init__(self, **kwargs):
pass
......@@ -298,16 +295,8 @@ class det_rec_functions(object):
self.modelDet = migraphx.parse_onnx(self.det_file, map_input_dims=detInput)
# 获取模型输入/输出节点信息
print("det_inputs:")
inputs_det = self.modelDet.get_inputs()
for key,value in inputs_det.items():
print("{}:{}".format(key,value))
print("det_outputs:")
outputs_det = self.modelDet.get_outputs()
for key,value in outputs_det.items():
print("{}:{}".format(key,value))
self.inputName = self.modelDet.get_parameter_names()[0]
self.inputShape = inputs_det[self.inputName].lens()
print("DB inputName:{0} \nDB inputShape:{1}".format(self.inputName, self.inputShape))
......@@ -321,16 +310,8 @@ class det_rec_functions(object):
self.modelRec = migraphx.parse_onnx(self.rec_file, map_input_dims=recInput)
# 获取模型输入/输出节点信息
print("rec_inputs:")
inputs_rec = self.modelRec.get_inputs()
for key,value in inputs_rec.items():
print("{}:{}".format(key,value))
print("rec_outputs:")
outputs_rec = self.modelRec.get_outputs()
for key,value in outputs_rec.items():
print("{}:{}".format(key,value))
self.inputName = self.modelRec.get_parameter_names()[0]
self.inputShape = inputs_rec[self.inputName].lens()
print("SVTR inputName:{0} \nSVTR inputShape:{1}".format(self.inputName, self.inputShape))
......@@ -480,8 +461,7 @@ class det_rec_functions(object):
# migraphx推理
resultDets = self.modelDet.run({self.modelDet.get_parameter_names()[0]: img_part})
# 获取第一个输出节点的数据,migraphx.argument类型
resultDet = resultDets[0]
resultDet = resultDets[0] # 获取第一个输出节点的数据,migraphx.argument类型
outs_part = np.array(resultDet)
post_res_part = self.det_re_process_op(outs_part, shape_part_list)
......@@ -523,8 +503,7 @@ class det_rec_functions(object):
# migraphx推理
results = self.modelRec.run({self.modelRec.get_parameter_names()[0]: img})
# 获取第一个输出节点的数据,migraphx.argument类型
result = results[0]
result = results[0] # 获取第一个输出节点的数据,migraphx.argument类型
outs = np.array(result)
result = process_op(outs)
......@@ -552,12 +531,9 @@ class det_rec_functions(object):
if __name__=='__main__':
image = cv2.imread('../Resource/Images/vlpr.jpg')
start = time.time()
ocr_sys = det_rec_functions(image)
dt_boxes = ocr_sys.get_boxes()
results = ocr_sys.recognition_img(dt_boxes)
results_info = results[0][0]
print('net forward time: {:.4f}'.format(time.time() - start))
print("############# OCR Results #############")
print(results_info)
\ No newline at end of file
......@@ -25,7 +25,7 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB
{
// 读取配置文件
std::string configFilePath=InitializationParameterOfDB.configFilePath;
if(Exists(configFilePath)==false)
if(!Exists(configFilePath))
{
LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
......@@ -47,29 +47,19 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB
dbParameter.ScoreMode = (string)netNode["ScoreMode"];
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
}
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["x"]={1,3,2496,2496}; // 设置最大shape
onnx_options.map_input_dims["x"] = {1,3,2496,2496}; // 设置最大shape
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"DB_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DB_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......@@ -161,6 +151,7 @@ ErrorCode DB::Infer(const cv::Mat &img, std::vector<cv::Mat> &imgList)
// 推理
std::vector<migraphx::argument> inferenceResults = net.eval(inputData);
// 获取推理结果
migraphx::argument result = inferenceResults[0];
......
......@@ -13,16 +13,14 @@ SVTR::SVTR()
SVTR::~SVTR()
{
configurationFile.release();
}
ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameterOfSVTR)
{
// 读取配置文件
std::string configFilePath=InitializationParameterOfSVTR.configFilePath;
if(Exists(configFilePath)==false)
if(!Exists(configFilePath))
{
LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
......@@ -40,7 +38,7 @@ ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameter
std::string dictPath = (std::string)netNode["DictPath"];
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
......@@ -51,19 +49,8 @@ ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameter
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"SVTR_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DSVTR_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......
......@@ -10,9 +10,7 @@ VLPR::VLPR()
VLPR::~VLPR()
{
configurationFile.release();
}
ErrorCode VLPR::Initialize(InitializationParameterOfDB initParamOfDB, InitializationParameterOfSVTR initParamOfSVTR)
......
......@@ -20,11 +20,7 @@ int main()
// 推理
std::vector<std::string> recTexts;
std::vector<float> recTextScores;
double time1 = cv::getTickCount();
vlpr.Infer(Image, recTexts, recTextScores);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1)*1000 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
// 打印结果
for (int i = 0; i < recTexts.size(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment