Commit c08438d8 authored by benjaminwan's avatar benjaminwan
Browse files

支持GPU(cuda)

parent 1587a9f5
...@@ -2,7 +2,24 @@ ...@@ -2,7 +2,24 @@
#include "OcrUtils.h" #include "OcrUtils.h"
#include <numeric> #include <numeric>
AngleNet::AngleNet() {} void AngleNet::setGpuIndex(int gpuIndex) {
#ifdef __CUDA__
if (gpuIndex >= 0) {
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id = gpuIndex;
cuda_options.arena_extend_strategy = 0;
cuda_options.gpu_mem_limit = 2 * 1024 * 1024 * 1024;
cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
cuda_options.do_copy_in_default_stream = 1;
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
printf("cls try to use GPU%d\n", gpuIndex);
}
else {
printf("cls use CPU\n");
}
#endif
}
AngleNet::~AngleNet() { AngleNet::~AngleNet() {
delete session; delete session;
......
...@@ -3,7 +3,24 @@ ...@@ -3,7 +3,24 @@
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
CrnnNet::CrnnNet() {} void CrnnNet::setGpuIndex(int gpuIndex) {
#ifdef __CUDA__
if (gpuIndex >= 0) {
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id = gpuIndex;
cuda_options.arena_extend_strategy = 0;
cuda_options.gpu_mem_limit = 2 * 1024 * 1024 * 1024;
cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
cuda_options.do_copy_in_default_stream = 1;
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
printf("rec try to use GPU%d\n", gpuIndex);
}
else {
printf("rec use CPU\n");
}
#endif
}
CrnnNet::~CrnnNet() { CrnnNet::~CrnnNet() {
delete session; delete session;
......
#include "DbNet.h" #include "DbNet.h"
#include "OcrUtils.h" #include "OcrUtils.h"
DbNet::DbNet() {} void DbNet::setGpuIndex(int gpuIndex) {
#ifdef __CUDA__
if (gpuIndex >= 0) {
OrtCUDAProviderOptions cuda_options;
cuda_options.device_id = gpuIndex;
cuda_options.arena_extend_strategy = 0;
cuda_options.gpu_mem_limit = 2 * 1024 * 1024 * 1024;
cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive;
cuda_options.do_copy_in_default_stream = 1;
sessionOptions.AppendExecutionProvider_CUDA(cuda_options);
printf("det try to use GPU%d\n", gpuIndex);
}
else {
printf("det use CPU\n");
}
#endif
}
DbNet::~DbNet() { DbNet::~DbNet() {
delete session; delete session;
......
...@@ -29,6 +29,12 @@ void OcrLite::enableResultTxt(const char *path, const char *imgName) { ...@@ -29,6 +29,12 @@ void OcrLite::enableResultTxt(const char *path, const char *imgName) {
resultTxt = fopen(resultTxtPath.c_str(), "w"); resultTxt = fopen(resultTxtPath.c_str(), "w");
} }
void OcrLite::setGpuIndex(int gpuIndex) {
dbNet.setGpuIndex(gpuIndex);
angleNet.setGpuIndex(-1);
crnnNet.setGpuIndex(gpuIndex);
}
bool OcrLite::initModels(const std::string &detPath, const std::string &clsPath, bool OcrLite::initModels(const std::string &detPath, const std::string &clsPath,
const std::string &recPath, const std::string &keysPath) { const std::string &recPath, const std::string &keysPath) {
Logger("=====Init Models=====\n"); Logger("=====Init Models=====\n");
......
...@@ -43,10 +43,11 @@ int main(int argc, char **argv) { ...@@ -43,10 +43,11 @@ int main(int argc, char **argv) {
int flagDoAngle = 1; int flagDoAngle = 1;
bool mostAngle = true; bool mostAngle = true;
int flagMostAngle = 1; int flagMostAngle = 1;
int flagGpu = -1;
int opt; int opt;
int optionIndex = 0; int optionIndex = 0;
while ((opt = getopt_long(argc, argv, "d:1:2:3:4:i:t:p:s:b:o:u:a:A:v:h", long_options, &optionIndex)) != -1) { while ((opt = getopt_long(argc, argv, "d:1:2:3:4:i:t:p:s:b:o:u:a:A:G:v:h", long_options, &optionIndex)) != -1) {
//printf("option(-%c)=%s\n", opt, optarg); //printf("option(-%c)=%s\n", opt, optarg);
switch (opt) { switch (opt) {
case 'd': case 'd':
...@@ -123,6 +124,9 @@ int main(int argc, char **argv) { ...@@ -123,6 +124,9 @@ int main(int argc, char **argv) {
case 'h': case 'h':
printHelp(stdout, argv[0]); printHelp(stdout, argv[0]);
return 0; return 0;
case 'G':
flagGpu = (int) strtol(optarg, NULL, 10);
break;
default: default:
printf("other option %c :%s\n", opt, optarg); printf("other option %c :%s\n", opt, optarg);
} }
...@@ -160,10 +164,12 @@ int main(int argc, char **argv) { ...@@ -160,10 +164,12 @@ int main(int argc, char **argv) {
true);//isOutputResultImg true);//isOutputResultImg
ocrLite.enableResultTxt(imgDir.c_str(), imgName.c_str()); ocrLite.enableResultTxt(imgDir.c_str(), imgName.c_str());
ocrLite.setGpuIndex(flagGpu);
ocrLite.Logger("=====Input Params=====\n"); ocrLite.Logger("=====Input Params=====\n");
ocrLite.Logger( ocrLite.Logger(
"numThread(%d),padding(%d),maxSideLen(%d),boxScoreThresh(%f),boxThresh(%f),unClipRatio(%f),doAngle(%d),mostAngle(%d)\n", "numThread(%d),padding(%d),maxSideLen(%d),boxScoreThresh(%f),boxThresh(%f),unClipRatio(%f),doAngle(%d),mostAngle(%d),GPU(%d)\n",
numThread, padding, maxSideLen, boxScoreThresh, boxThresh, unClipRatio, doAngle, mostAngle); numThread, padding, maxSideLen, boxScoreThresh, boxThresh, unClipRatio, doAngle, mostAngle,
flagGpu);
ocrLite.initModels(modelDetPath, modelClsPath, modelRecPath, keysPath); ocrLite.initModels(modelDetPath, modelClsPath, modelRecPath, keysPath);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment