Commit 474298d4 authored by liucong's avatar liucong
Browse files

修改retinaface工程格式

parent 58b5daa5
......@@ -61,16 +61,8 @@ if __name__ == '__main__':
model = migraphx.parse_onnx("./FaceDetector.onnx")
# 获取模型输入/输出节点信息
print("inputs:")
inputs = model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
inputName=model.get_parameter_names()[0]
inputShape=inputs[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName,inputShape))
......@@ -100,9 +92,7 @@ if __name__ == '__main__':
img = img.to(device)
scale = scale.to(device)
tic = time.time()
loc, conf, landms = migraphx_run(model,args.cpu,img) # forward pass
print('net forward time: {:.4f}'.format(time.time() - tic))
priorbox = PriorBox(cfg, image_size=(im_height, im_width))
priors = priorbox.forward()
......
......@@ -24,6 +24,7 @@ namespace migraphxSamples
DetectorRetinaFace::DetectorRetinaFace()
{
}
DetectorRetinaFace::~DetectorRetinaFace()
......@@ -38,7 +39,7 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
{
// 读取配置文件
std::string configFilePath=initializationParameterOfDetector.configFilePath;
if(Exists(configFilePath)==false)
if(!Exists(configFilePath))
{
LOG_ERROR(stdout, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
......@@ -63,7 +64,7 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
useFP16=(bool)(int)netNode["UseFP16"];
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
......@@ -72,18 +73,8 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
int N=inputShape.lens()[0];
......@@ -105,6 +96,7 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
{
srcImages.push_back(srcImage);
}
cv::Mat inputBlob;
cv::dnn::blobFromImages(srcImages,
inputBlob,
......@@ -113,6 +105,7 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
meanValue,
swapRB,
false);
std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data};
std::vector<std::unordered_map<std::string, migraphx::argument>> calibrationData = {inputData};
......@@ -134,7 +127,7 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
// warm up
std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName]=migraphx::argument{inputShape};
inputData[inputName] = migraphx::argument{inputShape};
net.eval(inputData);
// log
......@@ -151,7 +144,6 @@ ErrorCode DetectorRetinaFace::Initialize(InitializationParameterOfDetector initi
GetSSDParameter();
return SUCCESS;
}
ErrorCode DetectorRetinaFace::Detect(const cv::Mat &srcImage,std::vector<ResultOfDetection> &resultsOfDetection)
......@@ -219,7 +211,6 @@ ErrorCode DetectorRetinaFace::Detect(const cv::Mat &srcImage,std::vector<ResultO
sort(resultsOfDetection.begin(), resultsOfDetection.end(),CompareConfidence);
return SUCCESS;
}
void DetectorRetinaFace::GetSSDParameter()
......@@ -1129,6 +1120,4 @@ void DetectorRetinaFace::CreateDetectionResults(std::vector<ResultOfDetection> &
}
}
}
......@@ -24,11 +24,7 @@ int main()
// 推理
std::vector<migraphxSamples::ResultOfDetection> predictions;
double time1 = cv::getTickCount();
detector.Detect(srcImage,predictions);
double time2 = cv::getTickCount();
double elapsedTime = (time2 - time1)*1000 / cv::getTickFrequency();
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
// 获取推理结果
LOG_INFO(stdout,"========== Detection Results ==========\n");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment