"docs/vscode:/vscode.git/clone" did not exist on "e16e895d5364cdff243c2fb548e27fbc6366bf2b"
Commit 0f9dc829 authored by liucong

Reformat the C++ code

parent 824cfb81
@@ -12,90 +12,82 @@
namespace migraphxSamples
{
GPT2::GPT2() {}

GPT2::~GPT2() {}

ErrorCode GPT2::Initialize()
{
    // Get the model file
    std::string modelPath = "../Resource/GPT2_shici.onnx";

    // Set the maximum input shape
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["input"] = {1, 1000};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Get the model's input/output node information
    std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName = inputs.begin()->first;
    inputShape = inputs.begin()->second;

    // Set the model to GPU mode
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id = 0;       // GPU device id; defaults to device 0
    options.offload_copy = true; // enable offload_copy
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());

    return SUCCESS;
}

static bool CompareM(Predictions a, Predictions b) { return a.predictionvalue > b.predictionvalue; }

long unsigned int GPT2::Inference(const std::vector<long unsigned int>& input_id)
{
    long unsigned int input[1][input_id.size()];
    for(int j = 0; j < input_id.size(); ++j)
    {
        input[0][j] = input_id[j];
    }

    // Set the input shape
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1, input_id.size()});

    // Create the input data
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[0]),
                                              (long unsigned int*)input};

    // Run inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    // Get the properties of the output node
    migraphx::argument result = results[0];
    migraphx::shape outputShape = result.get_shape(); // shape of the output node
    int numberOfOutput = outputShape.elements();      // number of elements in the output node
    float* data = (float*)result.data();              // pointer to the output data

    // Save the inference results
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i = (input_id.size() - 1) * 22557; i < input_id.size() * 22557; ++i)
    {
        resultsOfPredictions[n].index = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }
@@ -110,8 +102,8 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
}

ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                              char* question,
                              std::vector<long unsigned int>& input_id)
{
    // Tokenize the question
    int max_seq_length = 1000;
@@ -121,7 +113,7 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
    // Save the encoded token ids
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for(int i = 0; i < tokens_question.size(); ++i)
    {
        input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i]));
    }
@@ -129,4 +121,4 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
    return SUCCESS;
}
} // namespace migraphxSamples
\ No newline at end of file
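A note on the Inference path: given CompareM above, the collapsed hunk presumably sorts resultsOfPredictions by predictionvalue and returns the top index, i.e. greedy next-token selection over the last position's 22557 logits. A minimal sketch of that selection, independent of the MIGraphX types (ArgMax and its parameters are illustrative names, not part of this repository):

#include <algorithm>
#include <cstddef>

// Greedy decoding: return the vocabulary index with the highest logit
// at the last input position (what CompareM plus a sort accomplish).
unsigned long ArgMax(const float* data, std::size_t seq_len, std::size_t vocab_size)
{
    // Logits of the last input position (vocab_size is 22557 above).
    const float* last = data + (seq_len - 1) * vocab_size;
    return (unsigned long)(std::max_element(last, last + vocab_size) - last);
}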
@@ -8,42 +8,42 @@
namespace migraphxSamples
{
typedef enum _ErrorCode
{
    SUCCESS = 0,
    MODEL_NOT_EXIST,
    CONFIG_FILE_NOT_EXIST,
    FAIL_TO_LOAD_MODEL,
    FAIL_TO_OPEN_CONFIG_FILE,
} ErrorCode;

typedef struct _Predictions
{
    long unsigned int index;
    float predictionvalue;
} Predictions;

class GPT2
{
    public:
    GPT2();
    ~GPT2();
    ErrorCode Initialize();
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                            char* question,
                            std::vector<long unsigned int>& input_id);
    long unsigned int Inference(const std::vector<long unsigned int>& input_id);

    private:
    migraphx::program net;
    std::string inputName;
    migraphx::shape inputShape;
};
} // namespace migraphxSamples

#endif
\ No newline at end of file
This diff is collapsed.
@@ -5,27 +5,27 @@
#include <string>
#include <vector>

namespace migraphxSamples
{
// Whether the path exists
bool Exists(const std::string& path);

// Whether the path is a directory
bool IsDirectory(const std::string& path);

// Whether the character is a path separator (Linux: '/', Windows: '\\')
bool IsPathSeparator(char c);

// Join two paths
std::string JoinPath(const std::string& base, const std::string& path);

// Create nested directories; note: the target directory must not already contain files
bool CreateDirectories(const std::string& directoryPath);

/** Generates the list of file names matching a given pattern (supports recursive traversal)
 *
 * pattern: the pattern, e.g. "*.jpg", "*.png", "*.jpg,*.png"
 * addPath: whether to include the parent path
 * Note:
@@ -36,35 +36,43 @@ bool CreateDirectories(const std::string &directoryPath);
 5. Subdirectory names are not returned
 *
 */
void GetFileNameList(const std::string& directory,
                     const std::string& pattern,
                     std::vector<std::string>& result,
                     bool recursive,
                     bool addPath);

// Differs from GetFileNameList in that, when addPath is true, subdirectory paths are also returned (directory names end with "/")
void GetFileNameList2(const std::string& directory,
                      const std::string& pattern,
                      std::vector<std::string>& result,
                      bool recursive,
                      bool addPath);

// Delete a file or directory, with support for recursive deletion
void Remove(const std::string& directory, const std::string& extension = "");

/** Gets the file name and extension of a path
 *
 * Example: if path is D:/1/1.txt, then GetFileName() is 1.txt, GetFileName_NoExtension() is 1, GetExtension() is .txt, and GetParentPath() is D:/1/
 */
std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string& path);

// Copy a file
bool CopyFile(const std::string srcPath, const std::string dstPath);

/** Copy a directory
 *
 * Example: CopyDirectories("D:/0/1/2/","E:/3/"); copies the directory D:/0/1/2/ into E:/3/ (i.e. after copying, the structure is E:/3/2/)
 * Note:
 1. The first argument must not end with "/"
 2. Hidden files are not copied
 */
bool CopyDirectories(std::string srcPath, const std::string dstPath);
} // namespace migraphxSamples

#endif
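As a quick illustration of the pattern parameter documented above, a minimal caller might look like this (the ./data/ directory and the surrounding main() are hypothetical, and the header name is assumed):

#include <iostream>
#include <string>
#include <vector>
#include "FileSystem.h" // assuming the header above

int main()
{
    std::vector<std::string> images;
    // Recursively collect *.jpg and *.png files under ./data/, keeping the parent path.
    migraphxSamples::GetFileNameList("./data/", "*.jpg,*.png", images, true, true);
    for(const std::string& path : images)
        std::cout << path << std::endl;
    return 0;
}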
@@ -8,7 +8,7 @@
#include <map>
#include <thread>
#include <mutex>

#if(defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>
@@ -16,13 +16,13 @@
using namespace std;

/** Simple logger
 *
 * Depends on no third-party library; including this single header is enough. Provides 4 log levels: INFO, DEBUG, WARN and ERROR.
 *
 * Example 1:
 // Initialize the logger, creating two log files log1.log and log2.log under ./Log/ (note: the directory ./Log/ must already exist, otherwise creating the logs fails)
 LogManager::GetInstance()->Initialize("./Log/","log1");
 LogManager::GetInstance()->Initialize("./Log/","log2");
@@ -34,11 +34,11 @@ using namespace std;
 // Close the logs
 LogManager::GetInstance()->Close("log1");
 LogManager::GetInstance()->Close("log2");

 * Example 2:
 // Log to the console
 string log = "Hello World";
 LOG_INFO(stdout, "%s\n", log.c_str());

 * Note:
 1. Requires C++11
@@ -50,44 +50,43 @@ using namespace std;
class LogManager
{
    private:
    LogManager() {}

    public:
    ~LogManager() {}

    inline void Initialize(const string& parentPath, const string& logName)
    {
        // An empty log name means logging to the console
        if(logName.size() == 0)
            return;

        // Look up the log file; create it if it does not exist yet
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            string pathOfLog = parentPath + logName + ".log";
            FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w: overwrite the file, a: append
            if(logFile != NULL)
            {
                logMap.insert(std::make_pair(logName, logFile));
            }
        }
    }

    inline FILE* GetLogFile(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return NULL;
        }
        return (*iter).second;
    }

    inline void Close(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return;
        }
@@ -95,10 +94,7 @@ public:
        fclose((*iter).second);
        logMap.erase(iter);
    }

    inline std::mutex& GetLogMutex() { return logMutex; }

    // Singleton
    static LogManager* GetInstance()
@@ -106,21 +102,22 @@ public:
        static LogManager logManager;
        return &logManager;
    }

    private:
    std::map<string, FILE*> logMap;
    std::mutex logMutex;
};

#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#endif

// log time
typedef struct _LogTime
{
    string year;
    string month;
@@ -131,53 +128,53 @@ typedef struct _LogTime
    string millisecond; // ms
    string microsecond; // us
    string weekDay;
} LogTime;

inline LogTime GetTime()
{
    LogTime currentTime;
#if(defined WIN32 || defined _WIN32)
    SYSTEMTIME systemTime;
    GetLocalTime(&systemTime);
    char temp[8] = {0};
    sprintf(temp, "%04d", systemTime.wYear);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", systemTime.wMonth);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", systemTime.wDay);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", systemTime.wHour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", systemTime.wMinute);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", systemTime.wSecond);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", systemTime.wMilliseconds);
    currentTime.millisecond = string(temp);
    sprintf(temp, "%d", systemTime.wDayOfWeek);
    currentTime.weekDay = string(temp);
#else
    struct timeval tv;
    struct tm* p;
    gettimeofday(&tv, NULL);
    p = localtime(&tv.tv_sec);
    char temp[8] = {0};
    sprintf(temp, "%04d", 1900 + p->tm_year);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", 1 + p->tm_mon);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", p->tm_mday);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", p->tm_hour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", p->tm_min);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", p->tm_sec);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
    currentTime.millisecond = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
    currentTime.microsecond = string(temp);
@@ -187,61 +184,83 @@ inline LogTime GetTime()
    return currentTime;
}

#define LOG_TIME(logFile)                               \
    do                                                  \
    {                                                   \
        LogTime currentTime = GetTime();                \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "%s-%s-%s %s:%s:%s.%s\t",               \
                currentTime.year.c_str(),               \
                currentTime.month.c_str(),              \
                currentTime.day.c_str(),                \
                currentTime.hour.c_str(),               \
                currentTime.minute.c_str(),             \
                currentTime.second.c_str(),             \
                currentTime.millisecond.c_str());       \
    } while(0)

#define LOG_INFO(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_DEBUG(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_ERROR(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_WARN(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#endif // __SIMPLE_LOG_H__
@@ -6,224 +6,257 @@
#include "./tokenization.h"

namespace cuBERT
{
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
    for(int i = 0; i < tokens.size(); ++i)
    {
        ids[i] = convert_token_to_id(tokens[i]);
    }
}

// trim from start (in place)
static inline void ltrim(std::string& s)
{
    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string& s)
{
    s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
            s.end());
}

// trim from both ends (in place)
static inline void trim(std::string& s)
{
    ltrim(s);
    rtrim(s);
}

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
    std::ifstream file(vocab_file);
    if(!file)
    {
        throw std::invalid_argument("Unable to open vocab file");
    }

    unsigned int index = 0;
    std::string line;
    while(std::getline(file, line))
    {
        trim(line);
        (*vocab)[line] = index;
        index++;
    }

    file.close();
}

inline bool _is_whitespace(int c, const char* cat)
{
    if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
    {
        return true;
    }
    return cat[0] == 'Z' && cat[1] == 's';
}

inline bool _is_control(int c, const char* cat)
{
    // These are technically control characters but we count them as whitespace characters.
    if(c == '\t' || c == '\n' || c == '\r')
    {
        return false;
    }
    return 'C' == *cat;
}

inline bool _is_punctuation(int cp, const char* cat)
{
    // We treat all non-letter/number ASCII as punctuation.
    // Characters such as "^", "$", and "`" are not in the Unicode
    // Punctuation class but we treat them as punctuation anyways, for
    // consistency.
    if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
       (cp >= 123 && cp <= 126))
    {
        return true;
    }
    return 'P' == *cat;
}

bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }

bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }

bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }

bool BasicTokenizer::_is_chinese_char(int cp)
{
    // This defines a "chinese character" as anything in the CJK Unicode block:
    // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    //
    // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    // despite its name. The modern Korean Hangul alphabet is a different block,
    // as is Japanese Hiragana and Katakana. Those alphabets are used to write
    // space-separated words, so they are not treated specially and handled
    // like all of the other languages.
    return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
           (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
           (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
           (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}

void BasicTokenizer::tokenize(const char* text,
                              std::vector<std::string>* output_tokens,
                              size_t max_length)
{
    // This was added on November 1st, 2018 for the multilingual and Chinese
    // models. This is also applied to the English models now, but it doesn't
    // matter since the English models were not trained on any Chinese data
    // and generally don't have any Chinese data in them (there are Chinese
    // characters in the vocabulary because Wikipedia does have some Chinese
    // words in the English Wikipedia.).
    if(do_lower_case)
    {
        text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
    }

    size_t word_bytes = std::strlen(text);
    bool new_token = true;
    size_t subpos = 0;
    int cp;
    char dst[4];

    while(word_bytes > 0)
    {
        int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
        if(len < 0)
        {
            std::cerr << "UTF-8 decode error: " << text << std::endl;
            break;
        }
        if(do_lower_case)
        {
            cp = utf8proc_tolower(cp);
        }

        const char* cat = utf8proc_category_string(cp);
        if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
        {
            // pass
        }
        else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
        {
            // pass
        }
        else if(_is_whitespace(cp, cat))
        {
            new_token = true;
        }
        else
        {
            size_t dst_len = len;
            const char* dst_ptr = text + subpos;
            if(do_lower_case)
            {
                dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
                dst_ptr = dst;
            }

            if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
            {
                output_tokens->emplace_back(dst_ptr, dst_len);
                new_token = true;
            }
            else
            {
                if(new_token)
                {
                    output_tokens->emplace_back(dst_ptr, dst_len);
                    new_token = false;
                }
                else
                {
                    output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
                }
            }
        }

        word_bytes = word_bytes - len;
        subpos = subpos + len;

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }

    if(do_lower_case)
    {
        free((void*)text);
    }
}

void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
    if(token.size() > max_input_chars_per_word)
    { // FIXME: slightly different
        output_tokens->push_back(unk_token);
        return;
    }

    size_t output_tokens_len = output_tokens->size();
    for(size_t start = 0; start < token.size();)
    {
        bool is_bad = true;
        // TODO: can be optimized by prefix-tree
        for(size_t end = token.size(); start < end; --end)
        { // FIXME: slightly different
            std::string substr = start > 0 ? "##" + token.substr(start, end - start)
                                           : token.substr(start, end - start);
            if(vocab->count(substr))
            {
                is_bad = false;
                output_tokens->push_back(substr);
                start = end;
                break;
            }
        }

        if(is_bad)
        {
            output_tokens->resize(output_tokens_len);
            output_tokens->push_back(unk_token);
            return;
        }
    }
}

void FullTokenizer::tokenize(const char* text,
                             std::vector<std::string>* output_tokens,
                             size_t max_length)
{
    std::vector<std::string> tokens;
    tokens.reserve(max_length);
    basic_tokenizer->tokenize(text, &tokens, max_length);

    for(const auto& token : tokens)
    {
        wordpiece_tokenizer->tokenize(token, output_tokens);

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }
}
} // namespace cuBERT
@@ -6,158 +6,172 @@
#include <unordered_map>
#include <iostream>

namespace cuBERT
{
void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);

/**
 * Checks whether `chars` is a whitespace character.
 * @param c
 * @return
 */
bool _is_whitespace(int c);

/**
 * Checks whether `chars` is a control character.
 * @param c
 * @return
 */
bool _is_control(int c);

/**
 * Checks whether `chars` is a punctuation character.
 * @param cp
 * @return
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer
{
    public:
    /**
     * Constructs a BasicTokenizer.
     * @param do_lower_case Whether to lower case the input.
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer& other) = delete;

    virtual ~BasicTokenizer() = default;

    /**
     * Tokenizes a piece of text.
     *
     * to_lower
     * _run_strip_accents Strips accents from a piece of text.
     * _clean_text Performs invalid character removal and whitespace cleanup on text.
     * _tokenize_chinese_chars Adds whitespace around any CJK character.
     * _run_split_on_punc Splits punctuation on a piece of text.
     * whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text
     * @param output_tokens
     */
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    private:
    const bool do_lower_case;

    /**
     * Checks whether CP is the codepoint of a CJK character.
     * @param cp
     * @return
     */
    inline static bool _is_chinese_char(int cp);
};

/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer
{
    public:
    explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
                                std::string unk_token = "[UNK]",
                                int max_input_chars_per_word = 200)
        : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
    {
    }

    WordpieceTokenizer(const WordpieceTokenizer& other) = delete;

    virtual ~WordpieceTokenizer() = default;

    /**
     * Tokenizes a piece of text into its word pieces.
     *
     * This uses a greedy longest-match-first algorithm to perform tokenization
     * using the given vocabulary.
     *
     * For example:
     * input = "unaffable"
     * output = ["un", "##aff", "##able"]
     *
     * @param text A single token or whitespace separated tokens. This should have already been
     * passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string& text, std::vector<std::string>* output_tokens);

    private:
    const std::unordered_map<std::string, uint64_t>* vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};

/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer
{
    public:
    FullTokenizer(const char* vocab_file, bool do_lower_case = true)
    {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    ~FullTokenizer()
    {
        // Delete before clearing the pointers; nulling first would leak the objects.
        delete wordpiece_tokenizer;
        wordpiece_tokenizer = NULL;
        delete basic_tokenizer;
        basic_tokenizer = NULL;
        delete vocab;
        vocab = NULL;
    }

    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    inline uint64_t convert_token_to_id(const std::string& token)
    {
        auto item = vocab->find(token);
        if(item == vocab->end())
        {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        }
        else
        {
            return item->second;
        }
    }

    void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);

    private:
    std::unordered_map<std::string, uint64_t>* vocab;
    BasicTokenizer* basic_tokenizer;
    WordpieceTokenizer* wordpiece_tokenizer;
};
} // namespace cuBERT

#endif // CUBERT_TOKENIZATION_H
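To show how the three tokenizers compose end to end, here is a minimal usage sketch of FullTokenizer (the surrounding main() is illustrative; the vocab path is the one used by this repo's main.cpp, and the exact word pieces depend on that vocabulary):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
// assuming the header above is available as tokenization.h
// #include "tokenization.h"

int main()
{
    // A BERT-style vocab file, one token per line.
    cuBERT::FullTokenizer tokenizer("../Resource/vocab_shici.txt");

    // Greedy longest-match-first WordPiece split, e.g. ["un", "##aff", "##able"].
    std::vector<std::string> tokens;
    tokenizer.tokenize("unaffable", &tokens, 128);

    // Map each word piece to its vocabulary id.
    std::vector<uint64_t> ids(tokens.size());
    tokenizer.convert_tokens_to_ids(tokens, ids.data());

    for(size_t i = 0; i < tokens.size(); ++i)
        std::cout << tokens[i] << " -> " << ids[i] << std::endl;
    return 0;
}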
This diff is collapsed.
This diff is collapsed.
@@ -12,7 +12,7 @@ int main()
    // Load the GPT2 model
    migraphxSamples::GPT2 gpt2;
    migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
    if(errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);
@@ -25,7 +25,7 @@ int main()
    std::string buf;
    std::vector<std::string> output;
    infile.open("../Resource/vocab_shici.txt");
    while(std::getline(infile, buf))
    {
        output.push_back(buf);
    }
@@ -37,7 +37,7 @@ int main()
    std::vector<std::string> result;

    std::cout << "Start matching verses with GPT2; enter CTRL + Z to exit" << std::endl;
    while(true)
    {
        // Preprocess the input
        std::cout << "question: ";
@@ -45,7 +45,7 @@ int main()
        gpt2.Preprocessing(tokenizer, question, input_id);

        // Inference
        for(int i = 0; i < 50; ++i)
        {
            long unsigned int outputs = gpt2.Inference(input_id);
            if(outputs == 102)
@@ -57,7 +57,7 @@ int main()
        }

        // Map the token ids back to characters
        for(int i = 0; i < score.size(); ++i)
        {
            result.push_back(output[score[i]]);
        }
@@ -65,12 +65,12 @@ int main()
        // Print the result
        std::cout << "chatbot: ";
        std::cout << question;
        for(int j = 0; j < result.size(); ++j)
        {
            std::cout << result[j];
        }
        std::cout << std::endl;

        // Clear the data
        input_id.clear();
        result.clear();