Commit 56695530 authored by liucong8560's avatar liucong8560
Browse files

Merge branch 'develop' into 'master'

Develop

See merge request !2
parents 1d125612 eb88cab4
......@@ -3,10 +3,8 @@
#ifndef __FILE_SYSTEM_H__
#define __FILE_SYSTEM_H__
#include <vector>
#include <string>
using namespace std;
#include <vector>
namespace migraphxSamples
{
......@@ -21,7 +19,7 @@ bool IsDirectory(const std::string &path);
bool IsPathSeparator(char c);
// 路径拼接
string JoinPath(const std::string &base, const std::string &path);
std::string JoinPath(const std::string &base, const std::string &path);
// 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的
bool CreateDirectories(const std::string &directoryPath);
......@@ -49,14 +47,13 @@ void Remove(const std::string &directory, const std::string &extension="");
/** 获取路径的文件名和扩展名
*
* 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/
*/
string GetFileName(const std::string &path); // 1.txt
string GetFileName_NoExtension(const std::string &path); // 1
string GetExtension(const std::string &path);// .txt
string GetParentPath(const std::string &path);// D:/1/
std::string GetFileName(const std::string &path);
std::string GetFileName_NoExtension(const std::string &path);
std::string GetExtension(const std::string &path);
std::string GetParentPath(const std::string &path);
// 拷贝文件:CopyFile("D:/1.txt","D:/2.txt");将1.txt拷贝为2.txt
// 拷贝文件
bool CopyFile(const std::string srcPath,const std::string dstPath);
/** 拷贝目录
......
......@@ -19,7 +19,7 @@ using namespace std;
/** 简易日志
*
* 轻量级日志系统,不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
* 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
*
* 示例1:
// 初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败)
......
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <Sample.h>
void MIGraphXSamplesUsage(char* programName)
{
printf("Usage : %s <index> \n", programName);
printf("index:\n");
printf("\t 0) Bert sample.\n");
}
#include <Bert.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h>
int main(int argc, char *argv[])
{
if (argc < 2 || argc > 2)
{
MIGraphXSamplesUsage(argv[0]);
return -1;
}
if (!strncmp(argv[1], "-h", 2))
// 加载Bert模型
migraphxSamples::Bert bert;
migraphxSamples::ErrorCode errorCode = bert.Initialize();
if (errorCode != migraphxSamples::SUCCESS)
{
MIGraphXSamplesUsage(argv[0]);
return 0;
LOG_ERROR(stdout, "fail to initialize Bert!\n");
exit(-1);
}
switch (*argv[1])
LOG_INFO(stdout, "succeed to initialize Bert\n");
int max_seq_length = 256; // 滑动窗口的长度
int max_query_length = 64; // 问题的最大长度
int batch_size = 1; // batch_size值
int n_best_size = 20; // 索引数量
int max_answer_length = 30; // 答案的最大长度
// 上下文文本数据
const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." };
char question[100];
std::vector<std::vector<long unsigned int>> input_ids;
std::vector<std::vector<long unsigned int>> input_masks;
std::vector<std::vector<long unsigned int>> segment_ids;
std::vector<float> start_position;
std::vector<float> end_position;
std::string answer = {};
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
while (true)
{
case '0':
{
Sample_Bert();
break;
}
default :
{
MIGraphXSamplesUsage(argv[0]);
break;
}
// 数据前处理
std::cout << "question: ";
cin.getline(question, 100);
bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids);
// 推理
bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
// 数据后处理
bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer);
// 打印输出预测结果
std::cout << "answer: " << answer << std::endl;
// 清除数据
input_ids.clear();
input_masks.clear();
segment_ids.clear();
start_position.clear();
end_position.clear();
answer = {};
}
return 0;
}
\ No newline at end of file
./3rdParty/opencv-3.4.11_mini.tar.gz
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment