Commit aec6280a authored by liucong

Format the C++ code

parent ab78f8ec
This diff is collapsed.
#ifndef __BERT_H__
#define __BERT_H__

#include <tokenization.h>

#include <cstdint>
#include <migraphx/program.hpp>
#include <string>

namespace migraphxSamples {

typedef enum _ErrorCode
{
    SUCCESS = 0,
    MODEL_NOT_EXIST,
    CONFIG_FILE_NOT_EXIST,
    FAIL_TO_LOAD_MODEL,
    FAIL_TO_OPEN_CONFIG_FILE,
} ErrorCode;

typedef struct _Sort_st
{
    int index;
    float value;
} Sort_st;

typedef struct _ResultOfPredictions
{
    int start_index;
    int end_index;
    float start_predictionvalue;
    float end_predictionvalue;
} ResultOfPredictions;

class Bert
{
    public:
    Bert();
    ~Bert();
    ErrorCode Initialize();
    ErrorCode Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
                        const std::vector<std::vector<long unsigned int>>& input_masks,
                        const std::vector<std::vector<long unsigned int>>& segment_ids,
                        std::vector<float>& start_position,
                        std::vector<float>& end_position);
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                            int batch_size,
                            int max_seq_length,
                            const char* text,
                            char* question,
                            std::vector<std::vector<long unsigned int>>& input_ids,
                            std::vector<std::vector<long unsigned int>>& input_masks,
                            std::vector<std::vector<long unsigned int>>& segment_ids);
    ErrorCode Postprocessing(int n_best_size,
                             int max_answer_length,
                             const std::vector<float>& start_position,
                             const std::vector<float>& end_position,
                             std::string& answer);

    private:
    std::vector<std::string> tokens_text;
    std::vector<std::string> tokens_question;
@@ -74,9 +74,8 @@ private:
    migraphx::shape inputShape2;
    migraphx::shape inputShape3;
    migraphx::shape inputShape4;
};

} // namespace migraphxSamples

#endif
\ No newline at end of file
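Every entry point on Bert reports one of the ErrorCode values above. As a side note, a small helper (hypothetical, not part of this commit) makes those codes printable for the log macros used elsewhere in this change:

#include <Bert.h>

// Hypothetical helper: maps migraphxSamples::ErrorCode to a printable name.
inline const char* ErrorCodeToString(migraphxSamples::ErrorCode code)
{
    switch(code)
    {
    case migraphxSamples::SUCCESS: return "SUCCESS";
    case migraphxSamples::MODEL_NOT_EXIST: return "MODEL_NOT_EXIST";
    case migraphxSamples::CONFIG_FILE_NOT_EXIST: return "CONFIG_FILE_NOT_EXIST";
    case migraphxSamples::FAIL_TO_LOAD_MODEL: return "FAIL_TO_LOAD_MODEL";
    case migraphxSamples::FAIL_TO_OPEN_CONFIG_FILE: return "FAIL_TO_OPEN_CONFIG_FILE";
    default: return "UNKNOWN_ERROR";
    }
}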
This diff is collapsed.
@@ -5,66 +5,74 @@
#include <string>
#include <vector>

namespace migraphxSamples {

// Whether the path exists
bool Exists(const std::string& path);

// Whether the path is a directory
bool IsDirectory(const std::string& path);

// Whether the character is a path separator (Linux: '/', Windows: '\\')
bool IsPathSeparator(char c);

// Join two paths
std::string JoinPath(const std::string& base, const std::string& path);

// Create a multi-level directory. Note: when creating a multi-level directory, the target directory must not already contain files.
bool CreateDirectories(const std::string& directoryPath);

/** Builds a list of file names that match the given pattern (recursive traversal supported)
 *
 * pattern: a pattern such as "*.jpg", "*.png" or "*.jpg,*.png"
 * addPath: whether to include the parent path
 * Notes:
    1. Separate multiple patterns with ",", e.g. "*.jpg,*.png"
    2. The wildcards '*' and '?' are supported, e.g. all file names whose first character is 7: "7*.*",
       all jpg file names ending in 512: "*512.jpg"
    3. Use "*.jpg", not ".jpg"
    4. An empty string returns all results
    5. Subdirectory names are not returned
 *
 */
void GetFileNameList(const std::string& directory,
                     const std::string& pattern,
                     std::vector<std::string>& result,
                     bool recursive,
                     bool addPath);

// Differs from GetFileNameList in that, when addPath is true, subdirectory paths are returned as well (directory names end with "/")
void GetFileNameList2(const std::string& directory,
                      const std::string& pattern,
                      std::vector<std::string>& result,
                      bool recursive,
                      bool addPath);

// Delete a file or directory; recursive deletion is supported
void Remove(const std::string& directory, const std::string& extension = "");

/** Gets the file name and extension of a path
 *
 * Example: for the path D:/1/1.txt, GetFileName() is 1.txt, GetFileName_NoExtension() is 1, GetExtension() is .txt, and GetParentPath() is D:/1/
 */
std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string& path);

// Copy a file
bool CopyFile(const std::string srcPath, const std::string dstPath);

/** Copies a directory
 *
 * Example: CopyDirectories("D:/0/1/2/","E:/3/"); copies the directory D:/0/1/2/ into the directory E:/3/ (after the copy the layout is E:/3/2/)
 * Notes:
    1. The first argument must not end with "/"
    2. Hidden files are not copied
 */
bool CopyDirectories(std::string srcPath, const std::string dstPath);

} // namespace migraphxSamples
#endif
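The implementations are not part of this diff; the following is a sketch only, assuming the declared API behaves as its comments describe (the ./Resource/ directory and patterns are illustrative):

#include <Filesystem.h>

#include <iostream>
#include <string>
#include <vector>

int main()
{
    using namespace migraphxSamples;

    const std::string dir = "./Resource/"; // hypothetical directory
    if(!Exists(dir) || !IsDirectory(dir))
    {
        // Creates the full chain of missing directories
        CreateDirectories(JoinPath(dir, "images"));
    }

    // All jpg/png files under dir, searched recursively, with the parent path included
    std::vector<std::string> images;
    GetFileNameList(dir, "*.jpg,*.png", images, true, true);

    for(const std::string& path : images)
    {
        std::cout << GetFileName_NoExtension(path) << " (" << GetExtension(path) << ")"
                  << std::endl;
    }
    return 0;
}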
@@ -4,11 +4,13 @@
#define __SIMPLE_LOG_H__

#include <time.h>

#include <map>
#include <mutex>
#include <string>
#include <thread>

#if(defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>
@@ -16,29 +18,31 @@
using namespace std;

/** Simple logging
 *
 * Depends on no third-party library; including this single header is all that is needed.
 * Four log levels are provided: INFO, DEBUG, WARN and ERROR.
 *
 * Example 1:
    // Initialize the logs, creating two log files log1.log and log2.log under ./Log/
    // (note: the directory ./Log/ must already exist, otherwise creating the log fails)
    LogManager::GetInstance()->Initialize("./Log/","log1");
    LogManager::GetInstance()->Initialize("./Log/","log2");

    // Write to the logs
    string log = "Hello World";
    LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", log.c_str()); // written to log1.log
    LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n", log.c_str()); // written to log2.log

    // Close the logs
    LogManager::GetInstance()->Close("log1");
    LogManager::GetInstance()->Close("log2");

 * Example 2:
    // Write the log to the console
    string log = "Hello World";
    LOG_INFO(stdout, "%s\n", log.c_str());

 * Notes:
    1. Requires C++11
@@ -50,44 +54,43 @@ using namespace std;
class LogManager
{
    private:
    LogManager() {}

    public:
    ~LogManager() {}

    inline void Initialize(const string& parentPath, const string& logName)
    {
        // An empty log name means output goes to the console
        if(logName.size() == 0)
            return;

        // Look up the log file; create it if it does not exist yet
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            string pathOfLog = parentPath + logName + ".log";
            FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w: overwrite the file, a: append
            if(logFile != NULL)
            {
                logMap.insert(std::make_pair(logName, logFile));
            }
        }
    }

    inline FILE* GetLogFile(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return NULL;
        }
        return (*iter).second;
    }

    inline void Close(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return;
        }
@@ -95,10 +98,7 @@ public:
        fclose((*iter).second);
        logMap.erase(iter);
    }

    inline std::mutex& GetLogMutex() { return logMutex; }

    // Singleton
    static LogManager* GetInstance()
@@ -106,21 +106,22 @@ public:
        static LogManager logManager;
        return &logManager;
    }

    private:
    std::map<string, FILE*> logMap;
    std::mutex logMutex;
};

#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#endif

// log time
typedef struct _LogTime
{
    string year;
    string month;
@@ -131,53 +132,53 @@ typedef struct _LogTime
    string millisecond; // ms
    string microsecond; // us
    string weekDay;
} LogTime;

inline LogTime GetTime()
{
    LogTime currentTime;
#if(defined WIN32 || defined _WIN32)
    SYSTEMTIME systemTime;
    GetLocalTime(&systemTime);
    char temp[8] = {0};
    sprintf(temp, "%04d", systemTime.wYear);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", systemTime.wMonth);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", systemTime.wDay);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", systemTime.wHour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", systemTime.wMinute);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", systemTime.wSecond);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", systemTime.wMilliseconds);
    currentTime.millisecond = string(temp);
    sprintf(temp, "%d", systemTime.wDayOfWeek);
    currentTime.weekDay = string(temp);
#else
    struct timeval tv;
    struct tm* p;
    gettimeofday(&tv, NULL);
    p = localtime(&tv.tv_sec);
    char temp[8] = {0};
    sprintf(temp, "%04d", 1900 + p->tm_year);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", 1 + p->tm_mon);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", p->tm_mday);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", p->tm_hour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", p->tm_min);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", p->tm_sec);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
    currentTime.millisecond = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
    currentTime.microsecond = string(temp);
@@ -187,61 +188,83 @@ inline LogTime GetTime()
    return currentTime;
}

#define LOG_TIME(logFile)                                                        \
    do                                                                           \
    {                                                                            \
        LogTime currentTime = GetTime();                                         \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "%s-%s-%s %s:%s:%s.%s\t",                                        \
                currentTime.year.c_str(),                                        \
                currentTime.month.c_str(),                                       \
                currentTime.day.c_str(),                                         \
                currentTime.hour.c_str(),                                        \
                currentTime.minute.c_str(),                                      \
                currentTime.second.c_str(),                                      \
                currentTime.millisecond.c_str());                                \
    } while(0)

#define LOG_INFO(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_DEBUG(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_ERROR(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_WARN(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#endif // __SIMPLE_LOG_H__
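The LOCK/UNLOCK macros above are no-ops unless LOG_MUTEX is defined before the header is included. A minimal sketch of the mutex-guarded path (the worker function is illustrative, and ./Log/ is assumed to already exist, as the header's own notes require):

#define LOG_MUTEX // must come before the include so LOCK/UNLOCK expand to real mutex calls
#include <SimpleLog.h>

#include <thread>

static void worker(int id)
{
    for(int i = 0; i < 3; ++i)
    {
        // Both threads write to log1.log; LOG_MUTEX keeps the lines from interleaving
        LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "worker %d, message %d\n", id, i);
    }
}

int main()
{
    LogManager::GetInstance()->Initialize("./Log/", "log1");
    std::thread t1(worker, 1), t2(worker, 2);
    t1.join();
    t2.join();
    LogManager::GetInstance()->Close("log1");
    return 0;
}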
#include <algorithm>
#include <cstring>
#include <fstream>
#include <stdexcept>

#include "./tokenization.h"
#include "utf8proc.h"

namespace cuBERT {

void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
    for(int i = 0; i < tokens.size(); ++i)
    {
        ids[i] = convert_token_to_id(tokens[i]);
    }
}

// trim from start (in place)
static inline void ltrim(std::string& s)
{
    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string& s)
{
    s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
            s.end());
}

// trim from both ends (in place)
static inline void trim(std::string& s)
{
    ltrim(s);
    rtrim(s);
}

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
    std::ifstream file(vocab_file);
    if(!file)
    {
        throw std::invalid_argument("Unable to open vocab file");
    }
    unsigned int index = 0;
    std::string line;
    while(std::getline(file, line))
    {
        trim(line);
        (*vocab)[line] = index;
        index++;
    }
    file.close();
}

inline bool _is_whitespace(int c, const char* cat)
{
    if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
    {
        return true;
    }
    return cat[0] == 'Z' && cat[1] == 's';
}

inline bool _is_control(int c, const char* cat)
{
    // These are technically control characters but we count them as whitespace
    // characters.
    if(c == '\t' || c == '\n' || c == '\r')
    {
        return false;
    }
    return 'C' == *cat;
}

inline bool _is_punctuation(int cp, const char* cat)
{
    // We treat all non-letter/number ASCII as punctuation.
    // Characters such as "^", "$", and "`" are not in the Unicode
    // Punctuation class but we treat them as punctuation anyways, for
    // consistency.
    if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
       (cp >= 123 && cp <= 126))
    {
        return true;
    }
    return 'P' == *cat;
}

bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }

bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }

bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }

bool BasicTokenizer::_is_chinese_char(int cp)
{
    // This defines a "chinese character" as anything in the CJK Unicode block:
    // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    //
    // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    // despite its name. The modern Korean Hangul alphabet is a different block,
    // as is Japanese Hiragana and Katakana. Those alphabets are used to write
    // space-separated words, so they are not treated specially and handled
    // like all of the other languages.
    return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
           (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
           (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
           (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}

void BasicTokenizer::tokenize(const char* text,
                              std::vector<std::string>* output_tokens,
                              size_t max_length)
{
    // This was added on November 1st, 2018 for the multilingual and Chinese
    // models. This is also applied to the English models now, but it doesn't
    // matter since the English models were not trained on any Chinese data
    // and generally don't have any Chinese data in them (there are Chinese
    // characters in the vocabulary because Wikipedia does have some Chinese
    // words in the English Wikipedia.).
    if(do_lower_case)
    {
        text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
    }

    size_t word_bytes = std::strlen(text);
    bool new_token = true;
    size_t subpos = 0;
    int cp;
    char dst[4];

    while(word_bytes > 0)
    {
        int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
        if(len < 0)
        {
            std::cerr << "UTF-8 decode error: " << text << std::endl;
            break;
        }
        if(do_lower_case)
        {
            cp = utf8proc_tolower(cp);
        }

        const char* cat = utf8proc_category_string(cp);
        if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
        {
            // pass
        }
        else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
        {
            // pass
        }
        else if(_is_whitespace(cp, cat))
        {
            new_token = true;
        }
        else
        {
            size_t dst_len = len;
            const char* dst_ptr = text + subpos;
            if(do_lower_case)
            {
                dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
                dst_ptr = dst;
            }
            if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
            {
                output_tokens->emplace_back(dst_ptr, dst_len);
                new_token = true;
            }
            else
            {
                if(new_token)
                {
                    output_tokens->emplace_back(dst_ptr, dst_len);
                    new_token = false;
                }
                else
                {
                    output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
                }
            }
        }

        word_bytes = word_bytes - len;
        subpos = subpos + len;

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }

    if(do_lower_case)
    {
        free((void*)text);
    }
}

void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
    if(token.size() > max_input_chars_per_word)
    { // FIXME: slightly different
        output_tokens->push_back(unk_token);
        return;
    }

    size_t output_tokens_len = output_tokens->size();
    for(size_t start = 0; start < token.size();)
    {
        bool is_bad = true;
        // TODO: can be optimized by prefix-tree
        for(size_t end = token.size(); start < end; --end)
        { // FIXME: slightly different
            std::string substr = start > 0 ? "##" + token.substr(start, end - start)
                                           : token.substr(start, end - start);
            if(vocab->count(substr))
            {
                is_bad = false;
                output_tokens->push_back(substr);
                start = end;
                break;
            }
        }

        if(is_bad)
        {
            output_tokens->resize(output_tokens_len);
            output_tokens->push_back(unk_token);
            return;
        }
    }
}

void FullTokenizer::tokenize(const char* text,
                             std::vector<std::string>* output_tokens,
                             size_t max_length)
{
    std::vector<std::string> tokens;
    tokens.reserve(max_length);
    basic_tokenizer->tokenize(text, &tokens, max_length);

    for(const auto& token : tokens)
    {
        wordpiece_tokenizer->tokenize(token, output_tokens);

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }
}

} // namespace cuBERT
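To make the control flow of BasicTokenizer::tokenize concrete, a small driver follows; the input string and expected split are illustrative, derived from the loop above (lower-casing, punctuation as standalone tokens, and one token per CJK character):

#include <iostream>
#include <string>
#include <vector>

#include <tokenization.h>

int main()
{
    // BasicTokenizer alone: lower-cases, skips control characters, and splits on
    // whitespace, punctuation, and CJK characters.
    cuBERT::BasicTokenizer basic(true);
    std::vector<std::string> tokens;
    basic.tokenize(u8"Hello, BERT! 你好", &tokens, 32);
    // Per the loop above this should yield: "hello" "," "bert" "!" "你" "好"
    for(const auto& t : tokens)
    {
        std::cout << t << std::endl;
    }
    return 0;
}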
#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

namespace cuBERT {

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);

/**
 * Checks whether `chars` is a whitespace character.
 * @param c
 * @return
 */
bool _is_whitespace(int c);

/**
 * Checks whether `chars` is a control character.
 * @param c
 * @return
 */
bool _is_control(int c);

/**
 * Checks whether `chars` is a punctuation character.
 * @param cp
 * @return
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer
{
    public:
    /**
     * Constructs a BasicTokenizer.
     * @param do_lower_case Whether to lower case the input.
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer& other) = delete;

    virtual ~BasicTokenizer() = default;

    /**
     * Tokenizes a piece of text.
     *
     * to_lower
     * _run_strip_accents Strips accents from a piece of text.
     * _clean_text Performs invalid character removal and whitespace cleanup on text.
     * _tokenize_chinese_chars Adds whitespace around any CJK character.
     * _run_split_on_punc Splits punctuation on a piece of text.
     * whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text
     * @param output_tokens
     */
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    private:
    const bool do_lower_case;

    /**
     * Checks whether CP is the codepoint of a CJK character.
     * @param cp
     * @return
     */
    inline static bool _is_chinese_char(int cp);
};

/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer
{
    public:
    explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
                                std::string unk_token = "[UNK]",
                                int max_input_chars_per_word = 200)
        : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
    {
    }

    WordpieceTokenizer(const WordpieceTokenizer& other) = delete;

    virtual ~WordpieceTokenizer() = default;

    /**
     * Tokenizes a piece of text into its word pieces.
     *
     * This uses a greedy longest-match-first algorithm to perform tokenization
     * using the given vocabulary.
     *
     * For example:
     *   input = "unaffable"
     *   output = ["un", "##aff", "##able"]
     *
     * @param text A single token or whitespace separated tokens. This should have
     * already been passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string& text, std::vector<std::string>* output_tokens);

    private:
    const std::unordered_map<std::string, uint64_t>* vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};

/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer
{
    public:
    FullTokenizer(const char* vocab_file, bool do_lower_case = true)
    {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    ~FullTokenizer()
    {
        // Delete before clearing the pointers; clearing them first would turn
        // the deletes into no-ops and leak the tokenizers and the vocab.
        delete wordpiece_tokenizer;
        wordpiece_tokenizer = NULL;
        delete basic_tokenizer;
        basic_tokenizer = NULL;
        delete vocab;
        vocab = NULL;
    }

    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    inline uint64_t convert_token_to_id(const std::string& token)
    {
        auto item = vocab->find(token);
        if(item == vocab->end())
        {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        }
        else
        {
            return item->second;
        }
    }

    void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);

    private:
    std::unordered_map<std::string, uint64_t>* vocab;
    BasicTokenizer* basic_tokenizer;
    WordpieceTokenizer* wordpiece_tokenizer;
};

} // namespace cuBERT
#endif // CUBERT_TOKENIZATION_H
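An end-to-end sketch of the declared interface: the vocab path is hypothetical, and the expected split comes from the WordPiece doc comment above (out-of-vocabulary words break into "##"-prefixed pieces; the vocab file is one token per line, with the line number becoming the id):

#include <cstdint>
#include <iostream>
#include <vector>

#include <tokenization.h>

int main()
{
    cuBERT::FullTokenizer tokenizer("vocab.txt"); // hypothetical vocab path
    std::vector<std::string> tokens;
    tokenizer.tokenize("unaffable", &tokens, 16);
    // Per the doc comment, a BERT-style vocab yields: "un" "##aff" "##able"
    std::vector<uint64_t> ids(tokens.size());
    tokenizer.convert_tokens_to_ids(tokens, ids.data());
    for(size_t i = 0; i < tokens.size(); ++i)
    {
        std::cout << tokens[i] << " -> " << ids[i] << std::endl;
    }
    return 0;
}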
This diff is collapsed.
This diff is collapsed.
#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <tokenization.h>

int main(int argc, char* argv[])
{
    // Load the Bert model
    migraphxSamples::Bert bert;
    migraphxSamples::ErrorCode errorCode = bert.Initialize();
    if(errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize Bert!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize Bert\n");

    int max_seq_length = 256;   // length of the sliding window
    int max_query_length = 64;  // maximum length of the question
    int batch_size = 1;         // batch size
    int n_best_size = 20;       // number of candidate indices
    int max_answer_length = 30; // maximum length of the answer

    // Context text
    const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
                         u8"computing that’s also programming-language independent. It brings a "
                         u8"philosophy of choice, minimalism and modular software development to "
                         u8"GPU computing. You are free to choose or even develop tools and a "
                         u8"language run time for your application. ROCm is built for scale, it "
                         u8"supports multi-GPU computing and has a rich system run time with the "
                         u8"critical features that large-scale application, compiler and "
                         u8"language-run-time development requires. Since the ROCm ecosystem is "
                         u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
                         u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
                         u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
                         u8"platform is continually optimized for performance and extensibility."};

    char question[100];
    std::vector<std::vector<long unsigned int>> input_ids;
@@ -35,14 +46,22 @@ int main(int argc, char *argv[])
    std::vector<float> end_position;
    std::string answer = {};

    cuBERT::FullTokenizer tokenizer =
        cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // tokenizer

    while(true)
    {
        // Preprocess the input
        std::cout << "question: ";
        cin.getline(question, 100);
        bert.Preprocessing(tokenizer,
                           batch_size,
                           max_seq_length,
                           text,
                           question,
                           input_ids,
                           input_masks,
                           segment_ids);

        // Inference
        bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
@@ -61,6 +80,6 @@ int main(int argc, char *argv[])
        end_position.clear();
        answer = {};
    }

    return 0;
}
\ No newline at end of file