Commit 466992ff authored by liucong's avatar liucong
Browse files

修改paddleocr工程格式

parent eb66b288
...@@ -28,13 +28,10 @@ class NormalizeImage(object): ...@@ -28,13 +28,10 @@ class NormalizeImage(object):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
img = np.array(img) img = np.array(img)
assert isinstance(img, assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
np.ndarray), "invalid input 'img' in NormalizeImage" data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data return data
class ToCHWImage(object): class ToCHWImage(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass pass
...@@ -298,16 +295,8 @@ class det_rec_functions(object): ...@@ -298,16 +295,8 @@ class det_rec_functions(object):
self.modelDet = migraphx.parse_onnx(self.det_file, map_input_dims=detInput) self.modelDet = migraphx.parse_onnx(self.det_file, map_input_dims=detInput)
# 获取模型输入/输出节点信息 # 获取模型输入/输出节点信息
print("det_inputs:")
inputs_det = self.modelDet.get_inputs() inputs_det = self.modelDet.get_inputs()
for key,value in inputs_det.items():
print("{}:{}".format(key,value))
print("det_outputs:")
outputs_det = self.modelDet.get_outputs() outputs_det = self.modelDet.get_outputs()
for key,value in outputs_det.items():
print("{}:{}".format(key,value))
self.inputName = self.modelDet.get_parameter_names()[0] self.inputName = self.modelDet.get_parameter_names()[0]
self.inputShape = inputs_det[self.inputName].lens() self.inputShape = inputs_det[self.inputName].lens()
print("DB inputName:{0} \nDB inputShape:{1}".format(self.inputName, self.inputShape)) print("DB inputName:{0} \nDB inputShape:{1}".format(self.inputName, self.inputShape))
...@@ -321,16 +310,8 @@ class det_rec_functions(object): ...@@ -321,16 +310,8 @@ class det_rec_functions(object):
self.modelRec = migraphx.parse_onnx(self.rec_file, map_input_dims=recInput) self.modelRec = migraphx.parse_onnx(self.rec_file, map_input_dims=recInput)
# 获取模型输入/输出节点信息 # 获取模型输入/输出节点信息
print("rec_inputs:")
inputs_rec = self.modelRec.get_inputs() inputs_rec = self.modelRec.get_inputs()
for key,value in inputs_rec.items():
print("{}:{}".format(key,value))
print("rec_outputs:")
outputs_rec = self.modelRec.get_outputs() outputs_rec = self.modelRec.get_outputs()
for key,value in outputs_rec.items():
print("{}:{}".format(key,value))
self.inputName = self.modelRec.get_parameter_names()[0] self.inputName = self.modelRec.get_parameter_names()[0]
self.inputShape = inputs_rec[self.inputName].lens() self.inputShape = inputs_rec[self.inputName].lens()
print("SVTR inputName:{0} \nSVTR inputShape:{1}".format(self.inputName, self.inputShape)) print("SVTR inputName:{0} \nSVTR inputShape:{1}".format(self.inputName, self.inputShape))
...@@ -480,8 +461,7 @@ class det_rec_functions(object): ...@@ -480,8 +461,7 @@ class det_rec_functions(object):
# migraphx推理 # migraphx推理
resultDets = self.modelDet.run({self.modelDet.get_parameter_names()[0]: img_part}) resultDets = self.modelDet.run({self.modelDet.get_parameter_names()[0]: img_part})
# 获取第一个输出节点的数据,migraphx.argument类型 resultDet = resultDets[0] # 获取第一个输出节点的数据,migraphx.argument类型
resultDet = resultDets[0]
outs_part = np.array(resultDet) outs_part = np.array(resultDet)
post_res_part = self.det_re_process_op(outs_part, shape_part_list) post_res_part = self.det_re_process_op(outs_part, shape_part_list)
...@@ -523,8 +503,7 @@ class det_rec_functions(object): ...@@ -523,8 +503,7 @@ class det_rec_functions(object):
# migraphx推理 # migraphx推理
results = self.modelRec.run({self.modelRec.get_parameter_names()[0]: img}) results = self.modelRec.run({self.modelRec.get_parameter_names()[0]: img})
# 获取第一个输出节点的数据,migraphx.argument类型 result = results[0] # 获取第一个输出节点的数据,migraphx.argument类型
result = results[0]
outs = np.array(result) outs = np.array(result)
result = process_op(outs) result = process_op(outs)
...@@ -552,12 +531,9 @@ class det_rec_functions(object): ...@@ -552,12 +531,9 @@ class det_rec_functions(object):
if __name__=='__main__': if __name__=='__main__':
image = cv2.imread('../Resource/Images/vlpr.jpg') image = cv2.imread('../Resource/Images/vlpr.jpg')
start = time.time()
ocr_sys = det_rec_functions(image) ocr_sys = det_rec_functions(image)
dt_boxes = ocr_sys.get_boxes() dt_boxes = ocr_sys.get_boxes()
results = ocr_sys.recognition_img(dt_boxes) results = ocr_sys.recognition_img(dt_boxes)
results_info = results[0][0] results_info = results[0][0]
print('net forward time: {:.4f}'.format(time.time() - start))
print("############# OCR Results #############") print("############# OCR Results #############")
print(results_info) print(results_info)
\ No newline at end of file
...@@ -25,7 +25,7 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB ...@@ -25,7 +25,7 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB
{ {
// 读取配置文件 // 读取配置文件
std::string configFilePath=InitializationParameterOfDB.configFilePath; std::string configFilePath=InitializationParameterOfDB.configFilePath;
if(Exists(configFilePath)==false) if(!Exists(configFilePath))
{ {
LOG_ERROR(stdout, "no configuration file!\n"); LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST; return CONFIG_FILE_NOT_EXIST;
...@@ -47,29 +47,19 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB ...@@ -47,29 +47,19 @@ ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB
dbParameter.ScoreMode = (string)netNode["ScoreMode"]; dbParameter.ScoreMode = (string)netNode["ScoreMode"];
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["x"]={1,3,2496,2496}; // 设置最大shape onnx_options.map_input_dims["x"] = {1,3,2496,2496}; // 设置最大shape
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::cout<<"DB_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DB_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first; inputName=inputs.begin()->first;
inputShape=inputs.begin()->second; inputShape=inputs.begin()->second;
int N=inputShape.lens()[0]; int N=inputShape.lens()[0];
...@@ -161,6 +151,7 @@ ErrorCode DB::Infer(const cv::Mat &img, std::vector<cv::Mat> &imgList) ...@@ -161,6 +151,7 @@ ErrorCode DB::Infer(const cv::Mat &img, std::vector<cv::Mat> &imgList)
// 推理 // 推理
std::vector<migraphx::argument> inferenceResults = net.eval(inputData); std::vector<migraphx::argument> inferenceResults = net.eval(inputData);
// 获取推理结果 // 获取推理结果
migraphx::argument result = inferenceResults[0]; migraphx::argument result = inferenceResults[0];
......
...@@ -13,16 +13,14 @@ SVTR::SVTR() ...@@ -13,16 +13,14 @@ SVTR::SVTR()
SVTR::~SVTR() SVTR::~SVTR()
{ {
configurationFile.release(); configurationFile.release();
} }
ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameterOfSVTR) ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameterOfSVTR)
{ {
// 读取配置文件 // 读取配置文件
std::string configFilePath=InitializationParameterOfSVTR.configFilePath; std::string configFilePath=InitializationParameterOfSVTR.configFilePath;
if(Exists(configFilePath)==false) if(!Exists(configFilePath))
{ {
LOG_ERROR(stdout, "no configuration file!\n"); LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST; return CONFIG_FILE_NOT_EXIST;
...@@ -40,30 +38,19 @@ ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameter ...@@ -40,30 +38,19 @@ ErrorCode SVTR::Initialize(InitializationParameterOfSVTR InitializationParameter
std::string dictPath = (std::string)netNode["DictPath"]; std::string dictPath = (std::string)netNode["DictPath"];
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["x"]={1,3,48,320}; // 设置最大shape onnx_options.map_input_dims["x"]={1,3,48,320}; // 设置最大shape
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::cout<<"SVTR_inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"DSVTR_outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first; inputName=inputs.begin()->first;
inputShape=inputs.begin()->second; inputShape=inputs.begin()->second;
int N=inputShape.lens()[0]; int N=inputShape.lens()[0];
......
...@@ -10,9 +10,7 @@ VLPR::VLPR() ...@@ -10,9 +10,7 @@ VLPR::VLPR()
VLPR::~VLPR() VLPR::~VLPR()
{ {
configurationFile.release(); configurationFile.release();
} }
ErrorCode VLPR::Initialize(InitializationParameterOfDB initParamOfDB, InitializationParameterOfSVTR initParamOfSVTR) ErrorCode VLPR::Initialize(InitializationParameterOfDB initParamOfDB, InitializationParameterOfSVTR initParamOfSVTR)
......
...@@ -20,11 +20,7 @@ int main() ...@@ -20,11 +20,7 @@ int main()
// 推理 // 推理
std::vector<std::string> recTexts; std::vector<std::string> recTexts;
std::vector<float> recTextScores; std::vector<float> recTextScores;
double time1 = cv::getTickCount();
vlpr.Infer(Image, recTexts, recTextScores); vlpr.Infer(Image, recTexts, recTextScores);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1)*1000 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
// 打印结果 // 打印结果
for (int i = 0; i < recTexts.size(); i++) for (int i = 0; i < recTexts.size(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment