"vscode:/vscode.git/clone" did not exist on "d76af4d4c12bc0fe121b4530102f8f15964e98cc"
Commit 2c275bc7 authored by lijian6's avatar lijian6
Browse files

Update


Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent 99304bcc
...@@ -5,7 +5,7 @@ CC = g++ ...@@ -5,7 +5,7 @@ CC = g++
CFLAGS = -std=c++17 CFLAGS = -std=c++17
INC_P = -I/opt/dtk/include -I/usr/local/include INC_P = -I/opt/dtk/include -I/usr/local/include
LIB_P = -L/opt/dtk/lib -L/usr/local/lib LIB_P = -L/opt/dtk/lib -L/usr/local/lib
LDLIBS = -lopencv_core -lopencv_imgcodecs -lopencv_dnn -lmigraphx -lmigraphx_device -lmigraphx_gpu -lmigraphx_onnx LDLIBS = -lopencv_core -lopencv_imgcodecs -lopencv_dnn -lopencv_imgproc -lmigraphx -lmigraphx_device -lmigraphx_gpu -lmigraphx_onnx
SRC_F = src/main.cpp SRC_F = src/main.cpp
EXEC = ViT_MIGraphX EXEC = ViT_MIGraphX
......
...@@ -73,6 +73,8 @@ make ...@@ -73,6 +73,8 @@ make
``` ```
运行ViT模型,对daisy图片进行分类 运行ViT模型,对daisy图片进行分类
## result
![img](./docs/result.jpg)
## 精度 ## 精度
......
...@@ -112,6 +112,34 @@ void postprocess(migraphx::argument result, int *n, string inputdir) ...@@ -112,6 +112,34 @@ void postprocess(migraphx::argument result, int *n, string inputdir)
} }
} }
} }
void postprocess_single(migraphx::argument result, int *n, cv::Mat &srcImage)
{
const char* labels[] = {"daisy", "dandelion", "roses", "sunflowers", "tulips"};
migraphx::shape outputShape = result.get_shape();
float *logits = (float *)result.data();
std::vector<float> logit;
for(int j=0; j<outputShape.elements(); ++j)
{
logit.push_back(logits[j]);
}
std::vector<float> probs = ComputeSoftmax(logit);
for (int j = 0; j < outputShape.elements(); ++j)
{
if (probs[j] >= 0.5)
{
char text[20];
char text1[20];
fprintf(stdout, "labels: %s, confidence: %.3f\n", labels[j], probs[j]);
snprintf(text, sizeof(text), "labels: %s", labels[j]);
snprintf(text1, sizeof(text1), "confidence: %.3f", probs[j]);
cv::putText(srcImage, text, cv::Point(8, 15), cv::FONT_HERSHEY_PLAIN, 1.0, Scalar(0, 0, 255), 1);
cv::putText(srcImage, text1, cv::Point(8, 25), cv::FONT_HERSHEY_PLAIN, 1.0, Scalar(0, 0, 255), 1);
cv::imwrite("result.jpg", srcImage);
}
}
}
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
if (argc < 3 || argc > 3) if (argc < 3 || argc > 3)
...@@ -173,7 +201,7 @@ int main(int argc, char *argv[]) ...@@ -173,7 +201,7 @@ int main(int argc, char *argv[])
std::vector<migraphx::argument> results = classifier.net.eval(inputData); std::vector<migraphx::argument> results = classifier.net.eval(inputData);
postprocess(results[0], &n, inputs); postprocess_single(results[0], &n, srcImage);
} }
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment