Commit 0f9dc829 authored by liucong

Reformat the C++ code style

parent 824cfb81
@@ -12,90 +12,82 @@
namespace migraphxSamples
{
GPT2::GPT2()
{
}
GPT2::~GPT2()
{
GPT2::GPT2() {}
}
GPT2::~GPT2() {}
ErrorCode GPT2::Initialize()
{
// Path to the model file
std::string modelPath="../Resource/GPT2_shici.onnx";
std::string modelPath = "../Resource/GPT2_shici.onnx";
// Set the maximum input shape
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"] = {1,1000};
onnx_options.map_input_dims["input"] = {1, 1000};
// Load the model
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
return MODEL_NOT_EXIST;
}
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());
// Get the model's input/output node information
std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
inputName = inputs.begin()->first;
inputShape = inputs.begin()->second;
// Set the model to run on the GPU
migraphx::target gpuTarget = migraphx::gpu::target{};
// Compile the model
migraphx::compile_options options;
options.device_id = 0; // select the GPU device, device 0 by default
options.offload_copy = true; // enable offload_copy
net.compile(gpuTarget,options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
options.device_id = 0;      // select the GPU device, device 0 by default
options.offload_copy = true; // enable offload_copy
net.compile(gpuTarget, options);
LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());
return SUCCESS;
}
static bool CompareM(Predictions a, Predictions b)
{
return a.predictionvalue > b.predictionvalue;
}
static bool CompareM(Predictions a, Predictions b) { return a.predictionvalue > b.predictionvalue; }
long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id)
long unsigned int GPT2::Inference(const std::vector<long unsigned int>& input_id)
{
long unsigned int input[1][input_id.size()];
for (int j=0;j<input_id.size();++j)
for(int j = 0; j < input_id.size(); ++j)
{
input[0][j] = input_id[j];
}
// Set the input shape
std::vector<std::vector<std::size_t>> inputShapes;
inputShapes.push_back({1,input_id.size()});
inputShapes.push_back({1, input_id.size()});
// Create the input data
std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName]=migraphx::argument{migraphx::shape(inputShape.type(),inputShapes[0]),(long unsigned int*)input};
inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[0]),
(long unsigned int*)input};
// Run inference
std::vector<migraphx::argument> results = net.eval(inputData);
// Get the output node's properties
migraphx::argument result = results[0];
migraphx::shape outputShape = result.get_shape(); // shape of the output node
int numberOfOutput=outputShape.elements(); // number of elements in the output node
float *data = (float *)result.data(); // pointer to the output node data
migraphx::argument result = results[0];
migraphx::shape outputShape = result.get_shape(); // shape of the output node
int numberOfOutput = outputShape.elements();      // number of elements in the output node
float* data = (float*)result.data();              // pointer to the output node data
// Save the inference results
long unsigned int n = 0;
std::vector<Predictions> resultsOfPredictions(22557);
for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i)
for(int i = (input_id.size() - 1) * 22557; i < input_id.size() * 22557; ++i)
{
resultsOfPredictions[n].index = n;
resultsOfPredictions[n].index = n;
resultsOfPredictions[n].predictionvalue = data[i];
++n;
}
@@ -110,8 +102,8 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
}
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
char *question,
std::vector<long unsigned int> &input_id)
char* question,
std::vector<long unsigned int>& input_id)
{
// Tokenization
int max_seq_length = 1000;
@@ -121,7 +113,7 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
// Save the encoded token ids
input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
for (int i=0;i<tokens_question.size();++i)
for(int i = 0; i < tokens_question.size(); ++i)
{
input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i]));
}
@@ -129,4 +121,4 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
return SUCCESS;
}
}
\ No newline at end of file
} // namespace migraphxSamples
\ No newline at end of file
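For reference, below is a minimal standalone sketch of the MIGraphX flow that GPT2::Initialize() and GPT2::Inference() implement above: parse the ONNX model with a fixed maximum input shape, compile it for the GPU, then evaluate it with a named input argument. It mirrors this sample's own MIGraphX usage; the header paths and members such as get_inputs(), compile_options::device_id and offload_copy are assumptions taken from this repo's MIGraphX build and may differ in other MIGraphX versions, and token id 101 is only a placeholder.

#include <migraphx/onnx.hpp>        // assumed header locations for this MIGraphX build
#include <migraphx/gpu/target.hpp>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main()
{
    // Parse the ONNX model with a fixed maximum input shape, as Initialize() does
    migraphx::onnx_options onnxOptions;
    onnxOptions.map_input_dims["input"] = {1, 1000};
    migraphx::program net = migraphx::parse_onnx("../Resource/GPT2_shici.onnx", onnxOptions);

    // Compile for GPU device 0; offload_copy copies host buffers to and from the device automatically
    migraphx::compile_options compileOptions;
    compileOptions.device_id    = 0;
    compileOptions.offload_copy = true;
    net.compile(migraphx::gpu::target{}, compileOptions);

    // Build a 1x1 batch of token ids and wrap it in a migraphx::argument, as Inference() does
    long unsigned int input[1][1] = {{101}};
    migraphx::shape modelInput = net.get_inputs().begin()->second;
    std::vector<std::size_t> dims = {1, 1};
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData["input"] = migraphx::argument{migraphx::shape(modelInput.type(), dims),
                                            (long unsigned int*)input};

    // Evaluate and read the logits back as floats (one row of 22557 values per input token)
    std::vector<migraphx::argument> results = net.eval(inputData);
    float* logits = (float*)results[0].data();
    std::printf("output elements: %zu, first logit: %f\n",
                results[0].get_shape().elements(), logits[0]);
    return 0;
}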
@@ -8,42 +8,42 @@
namespace migraphxSamples
{
typedef enum _ErrorCode
{
SUCCESS=0,
MODEL_NOT_EXIST,
CONFIG_FILE_NOT_EXIST,
FAIL_TO_LOAD_MODEL,
FAIL_TO_OPEN_CONFIG_FILE,
}ErrorCode;
typedef struct _Predictions
{
long unsigned int index;
float predictionvalue;
}Predictions;
typedef enum _ErrorCode
{
SUCCESS = 0,
MODEL_NOT_EXIST,
CONFIG_FILE_NOT_EXIST,
FAIL_TO_LOAD_MODEL,
FAIL_TO_OPEN_CONFIG_FILE,
} ErrorCode;
typedef struct _Predictions
{
long unsigned int index;
float predictionvalue;
} Predictions;
class GPT2
{
public:
public:
GPT2();
~GPT2();
ErrorCode Initialize();
ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
char *question,
std::vector<long unsigned int> &input_id);
char* question,
std::vector<long unsigned int>& input_id);
long unsigned int Inference(const std::vector<long unsigned int> &input_id);
long unsigned int Inference(const std::vector<long unsigned int>& input_id);
private:
private:
migraphx::program net;
std::string inputName;
migraphx::shape inputShape;
};
}
} // namespace migraphxSamples
#endif
\ No newline at end of file
This diff is collapsed.
@@ -5,27 +5,27 @@
#include <string>
#include <vector>
namespace migraphxSamples
{
// Whether the path exists
bool Exists(const std::string &path);
bool Exists(const std::string& path);
// Whether the path is a directory
bool IsDirectory(const std::string &path);
bool IsDirectory(const std::string& path);
// Whether the character is a path separator ('/' on Linux, '\\' on Windows)
bool IsPathSeparator(char c);
// Join two paths
std::string JoinPath(const std::string &base, const std::string &path);
std::string JoinPath(const std::string& base, const std::string& path);
// Create nested directories. Note: the target directory must not already contain files
bool CreateDirectories(const std::string &directoryPath);
bool CreateDirectories(const std::string& directoryPath);
/** Generate a list of file names matching the given pattern (recursive traversal supported)
 *
 * pattern: pattern string, e.g. "*.jpg", "*.png", "*.jpg,*.png"
 * addPath: whether to include the parent path
 * Notes:
@@ -36,35 +36,43 @@ bool CreateDirectories(const std::string &directoryPath);
5. Subdirectory names are not returned
*
*/
void GetFileNameList(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath);
void GetFileNameList(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// Differs from GetFileNameList in that, if there are subdirectories and addPath is true, the subdirectory paths are returned as well (directory names end with "/")
void GetFileNameList2(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath);
void GetFileNameList2(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// Delete a file or directory; recursive deletion is supported
void Remove(const std::string &directory, const std::string &extension="");
void Remove(const std::string& directory, const std::string& extension = "");
/** Get a path's file name and extension
*
*
* Example: if path is D:/1/1.txt, then GetFileName() is 1.txt, GetFileName_NoExtension() is 1, GetExtension() is .txt, and GetParentPath() is D:/1/
*/
std::string GetFileName(const std::string &path);
std::string GetFileName_NoExtension(const std::string &path);
std::string GetExtension(const std::string &path);
std::string GetParentPath(const std::string &path);
*/
std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string& path);
// Copy a file
bool CopyFile(const std::string srcPath,const std::string dstPath);
bool CopyFile(const std::string srcPath, const std::string dstPath);
/** Copy a directory
 *
 * Example: CopyDirectories("D:/0/1/2/","E:/3/"); copies the directory D:/0/1/2/ into E:/3/ (so the resulting layout is E:/3/2/)
 * Notes:
 1. The first argument must not end with "/"
 2. Hidden files are not copied
 */
bool CopyDirectories(std::string srcPath,const std::string dstPath);
bool CopyDirectories(std::string srcPath, const std::string dstPath);
}
} // namespace migraphxSamples
#endif
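The declarations above are largely self-explanatory, but a short usage sketch may help; it assumes this header is named FileSystem.h (the include name is not visible in this diff) and uses a placeholder directory and pattern.

#include "FileSystem.h" // assumed header name for the declarations above
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    using namespace migraphxSamples;

    std::string dir = "../Resource/"; // placeholder directory
    if(!Exists(dir) || !IsDirectory(dir))
    {
        std::printf("%s is not an existing directory\n", dir.c_str());
        return -1;
    }

    // recursive = true, addPath = true: full paths of every matching file under dir
    std::vector<std::string> files;
    GetFileNameList(dir, "*.onnx,*.txt", files, true, true);

    for(const std::string& path : files)
    {
        std::printf("%s -> name=%s, ext=%s, parent=%s\n",
                    path.c_str(),
                    GetFileName(path).c_str(),
                    GetExtension(path).c_str(),
                    GetParentPath(path).c_str());
    }
    return 0;
}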
@@ -8,7 +8,7 @@
#include <map>
#include <thread>
#include <mutex>
#if (defined WIN32 || defined _WIN32)
#if(defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>
@@ -16,13 +16,13 @@
using namespace std;
/** Simple logger
 *
 * Does not depend on any third-party library; including this single header is enough. Four log levels are provided: INFO, DEBUG, WARN and ERROR.
 *
 * Example 1:
 // Initialize the logger: create two log files, log1.log and log2.log, under ./Log/ (note: the ./Log/ directory must already exist, otherwise creation fails)
LogManager::GetInstance()->Initialize("./Log/","log1");
LogManager::GetInstance()->Initialize("./Log/","log2");
@@ -34,11 +34,11 @@ using namespace std;
// Close the logs
LogManager::GetInstance()->Close("log1");
LogManager::GetInstance()->Close("log2");
* Example 2:
// Write the log to the console
string log = "Hello World";
LOG_INFO(stdout, "%s\n", log.c_str());
LOG_INFO(stdout, "%s\n", log.c_str());
* Notes:
1. Requires C++11
@@ -50,44 +50,43 @@ using namespace std;
class LogManager
{
private:
LogManager(){}
private:
LogManager() {}
public:
~LogManager(){}
inline void Initialize(const string &parentPath,const string &logName)
public:
~LogManager() {}
inline void Initialize(const string& parentPath, const string& logName)
{
// An empty log name means output goes to the console
if(logName.size()==0)
if(logName.size() == 0)
return;
// Look up the log file; create it if it does not exist yet
std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if (iter == logMap.end())
if(iter == logMap.end())
{
string pathOfLog = parentPath+ logName + ".log";
FILE *logFile = fopen(pathOfLog.c_str(), "a"); // "w": overwrite the existing file, "a": append
if(logFile!=NULL)
string pathOfLog = parentPath + logName + ".log";
FILE* logFile = fopen(pathOfLog.c_str(), "a"); // "w": overwrite the existing file, "a": append
if(logFile != NULL)
{
logMap.insert(std::make_pair(logName, logFile));
}
}
}
inline FILE* GetLogFile(const string &logName)
inline FILE* GetLogFile(const string& logName)
{
std::map<string, FILE*>::const_iterator iter=logMap.find(logName);
if(iter==logMap.end())
std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter == logMap.end())
{
return NULL;
}
return (*iter).second;
}
inline void Close(const string &logName)
inline void Close(const string& logName)
{
std::map<string, FILE*>::const_iterator iter=logMap.find(logName);
if(iter==logMap.end())
std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter == logMap.end())
{
return;
}
@@ -95,10 +94,7 @@ public:
fclose((*iter).second);
logMap.erase(iter);
}
inline std::mutex &GetLogMutex()
{
return logMutex;
}
inline std::mutex& GetLogMutex() { return logMutex; }
// Singleton
static LogManager* GetInstance()
@@ -106,21 +102,22 @@
static LogManager logManager;
return &logManager;
}
private:
private:
std::map<string, FILE*> logMap;
std::mutex logMutex;
};
#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#define LOCK
#define UNLOCK
#endif
// log time
typedef struct _LogTime
typedef struct _LogTime
{
string year;
string month;
@@ -131,53 +128,53 @@ typedef struct _LogTime
string millisecond; // ms
string microsecond; // us
string weekDay;
}LogTime;
} LogTime;
inline LogTime GetTime()
{
LogTime currentTime;
#if (defined WIN32 || defined _WIN32)
#if(defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime;
GetLocalTime(&systemTime);
char temp[8] = { 0 };
char temp[8] = {0};
sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp);
currentTime.year = string(temp);
sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp);
currentTime.month = string(temp);
sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp);
currentTime.day = string(temp);
sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp);
currentTime.hour = string(temp);
sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp);
currentTime.minute = string(temp);
sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp);
currentTime.second = string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp);
currentTime.millisecond = string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp);
currentTime.weekDay = string(temp);
#else
struct timeval tv;
struct tm *p;
struct timeval tv;
struct tm* p;
gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec);
char temp[8]={0};
sprintf(temp,"%04d",1900+p->tm_year);
currentTime.year=string(temp);
sprintf(temp,"%02d",1+p->tm_mon);
currentTime.month=string(temp);
sprintf(temp,"%02d",p->tm_mday);
currentTime.day=string(temp);
sprintf(temp,"%02d",p->tm_hour);
currentTime.hour=string(temp);
sprintf(temp,"%02d",p->tm_min);
currentTime.minute=string(temp);
sprintf(temp,"%02d",p->tm_sec);
currentTime.second=string(temp);
sprintf(temp,"%03d",(int)(tv.tv_usec/1000));
char temp[8] = {0};
sprintf(temp, "%04d", 1900 + p->tm_year);
currentTime.year = string(temp);
sprintf(temp, "%02d", 1 + p->tm_mon);
currentTime.month = string(temp);
sprintf(temp, "%02d", p->tm_mday);
currentTime.day = string(temp);
sprintf(temp, "%02d", p->tm_hour);
currentTime.hour = string(temp);
sprintf(temp, "%02d", p->tm_min);
currentTime.minute = string(temp);
sprintf(temp, "%02d", p->tm_sec);
currentTime.second = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
currentTime.millisecond = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
currentTime.microsecond = string(temp);
@@ -187,61 +184,83 @@ inline LogTime GetTime()
return currentTime;
}
#define LOG_TIME(logFile) \
do\
{\
LogTime currentTime=GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), "%s-%s-%s %s:%s:%s.%s\t",currentTime.year.c_str(),currentTime.month.c_str(),currentTime.day.c_str(),currentTime.hour.c_str(),currentTime.minute.c_str(),currentTime.second.c_str(),currentTime.millisecond.c_str()); \
}while (0)
#define LOG_INFO(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_DEBUG(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "DEBUG\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_ERROR(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "ERROR\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_WARN(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "WARN\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_TIME(logFile) \
do \
{ \
LogTime currentTime = GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"%s-%s-%s %s:%s:%s.%s\t", \
currentTime.year.c_str(), \
currentTime.month.c_str(), \
currentTime.day.c_str(), \
currentTime.hour.c_str(), \
currentTime.minute.c_str(), \
currentTime.second.c_str(), \
currentTime.millisecond.c_str()); \
} while(0)
#endif // __SIMPLE_LOG_H__
#define LOG_INFO(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_DEBUG(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_ERROR(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_WARN(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#endif // __SIMPLE_LOG_H__
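The comments at the top of this header already show the intended usage; for convenience, here is a compact sketch that combines file and console logging (the header file name is assumed, and as noted above the ./Log/ directory must already exist or the log file is not created).

#include "SimpleLog.h" // assumed header name; the include guard above is __SIMPLE_LOG_H__
#include <cstdio>
#include <string>

int main()
{
    // File logging: creates (or appends to) ./Log/log1.log
    LogManager::GetInstance()->Initialize("./Log/", "log1");
    FILE* log1 = LogManager::GetInstance()->GetLogFile("log1");
    LOG_INFO(log1, "loaded model %s\n", "GPT2_shici.onnx");
    LOG_WARN(log1, "sequence length %d exceeds the maximum %d\n", 1200, 1000);
    LogManager::GetInstance()->Close("log1");

    // Console logging: pass stdout instead of a log file handle
    std::string msg = "Hello World";
    LOG_INFO(stdout, "%s\n", msg.c_str());
    return 0;
}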
@@ -6,224 +6,257 @@
#include "./tokenization.h"
namespace cuBERT {
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) {
for (int i = 0; i < tokens.size(); ++i) {
ids[i] = convert_token_to_id(tokens[i]);
}
namespace cuBERT
{
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
for(int i = 0; i < tokens.size(); ++i)
{
ids[i] = convert_token_to_id(tokens[i]);
}
}
// trim from start (in place)
static inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) {
return !std::isspace(ch);
}));
}
static inline void ltrim(std::string& s)
{
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
}
// trim from end (in place)
static inline void rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) {
return !std::isspace(ch);
}).base(), s.end());
}
static inline void rtrim(std::string& s)
{
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
s.end());
}
// trim from both ends (in place)
static inline void trim(std::string &s) {
ltrim(s);
rtrim(s);
}
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) {
std::ifstream file(vocab_file);
if (!file) {
throw std::invalid_argument("Unable to open vocab file");
}
unsigned int index = 0;
std::string line;
while (std::getline(file, line)) {
trim(line);
(*vocab)[line] = index;
index++;
}
static inline void trim(std::string& s)
{
ltrim(s);
rtrim(s);
}
file.close();
void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
std::ifstream file(vocab_file);
if(!file)
{
throw std::invalid_argument("Unable to open vocab file");
}
inline bool _is_whitespace(int c, const char *cat) {
if (c == ' ' || c == '\t' || c == '\n' || c == '\r') {
return true;
}
return cat[0] == 'Z' && cat[1] == 's';
unsigned int index = 0;
std::string line;
while(std::getline(file, line))
{
trim(line);
(*vocab)[line] = index;
index++;
}
inline bool _is_control(int c, const char *cat) {
// These are technically control characters but we count them as whitespace characters.
if (c == '\t' || c == '\n' || c == '\r') {
return false;
}
return 'C' == *cat;
}
file.close();
}
inline bool _is_punctuation(int cp, const char *cat) {
// We treat all non-letter/number ASCII as punctuation.
// Characters such as "^", "$", and "`" are not in the Unicode
// Punctuation class but we treat them as punctuation anyways, for
// consistency.
if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) ||
(cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
return true;
}
return 'P' == *cat;
inline bool _is_whitespace(int c, const char* cat)
{
if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
{
return true;
}
return cat[0] == 'Z' && cat[1] == 's';
}
bool _is_whitespace(int c) {
return _is_whitespace(c, utf8proc_category_string(c));
inline bool _is_control(int c, const char* cat)
{
// These are technically control characters but we count them as whitespace characters.
if(c == '\t' || c == '\n' || c == '\r')
{
return false;
}
return 'C' == *cat;
}
bool _is_control(int c) {
return _is_control(c, utf8proc_category_string(c));
inline bool _is_punctuation(int cp, const char* cat)
{
// We treat all non-letter/number ASCII as punctuation.
// Characters such as "^", "$", and "`" are not in the Unicode
// Punctuation class but we treat them as punctuation anyways, for
// consistency.
if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
(cp >= 123 && cp <= 126))
{
return true;
}
return 'P' == *cat;
}
bool _is_punctuation(int cp) {
return _is_punctuation(cp, utf8proc_category_string(cp));
}
bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }
bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }
bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }
bool BasicTokenizer::_is_chinese_char(int cp)
{
// This defines a "chinese character" as anything in the CJK Unicode block:
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
//
// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
// despite its name. The modern Korean Hangul alphabet is a different block,
// as is Japanese Hiragana and Katakana. Those alphabets are used to write
// space-separated words, so they are not treated specially and handled
// like the all of the other languages.
return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
(cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
(cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}
bool BasicTokenizer::_is_chinese_char(int cp) {
// This defines a "chinese character" as anything in the CJK Unicode block:
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
//
// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
// despite its name. The modern Korean Hangul alphabet is a different block,
// as is Japanese Hiragana and Katakana. Those alphabets are used to write
// space-separated words, so they are not treated specially and handled
// like the all of the other languages.
return (cp >= 0x4E00 && cp <= 0x9FFF) ||
(cp >= 0x3400 && cp <= 0x4DBF) ||
(cp >= 0x20000 && cp <= 0x2A6DF) ||
(cp >= 0x2A700 && cp <= 0x2B73F) ||
(cp >= 0x2B740 && cp <= 0x2B81F) ||
(cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) ||
(cp >= 0x2F800 && cp <= 0x2FA1F);
void BasicTokenizer::tokenize(const char* text,
std::vector<std::string>* output_tokens,
size_t max_length)
{
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if(do_lower_case)
{
text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
}
void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if (do_lower_case) {
text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text);
size_t word_bytes = std::strlen(text);
bool new_token = true;
size_t subpos = 0;
int cp;
char dst[4];
while(word_bytes > 0)
{
int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
if(len < 0)
{
std::cerr << "UTF-8 decode error: " << text << std::endl;
break;
}
if(do_lower_case)
{
cp = utf8proc_tolower(cp);
}
size_t word_bytes = std::strlen(text);
bool new_token = true;
size_t subpos = 0;
int cp;
char dst[4];
while (word_bytes > 0) {
int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp);
if (len < 0) {
std::cerr << "UTF-8 decode error: " << text << std::endl;
break;
}
if (do_lower_case) {
cp = utf8proc_tolower(cp);
const char* cat = utf8proc_category_string(cp);
if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
{
// pass
}
else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
{
// pass
}
else if(_is_whitespace(cp, cat))
{
new_token = true;
}
else
{
size_t dst_len = len;
const char* dst_ptr = text + subpos;
if(do_lower_case)
{
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
dst_ptr = dst;
}
const char *cat = utf8proc_category_string(cp);
if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) {
// pass
} else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') {
// pass
} else if (_is_whitespace(cp, cat)) {
if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
{
output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true;
} else {
size_t dst_len = len;
const char *dst_ptr = text + subpos;
if (do_lower_case) {
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst);
dst_ptr = dst;
}
if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) {
}
else
{
if(new_token)
{
output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true;
} else {
if (new_token) {
output_tokens->emplace_back(dst_ptr, dst_len);
new_token = false;
} else {
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
}
new_token = false;
}
else
{
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
}
}
}
word_bytes = word_bytes - len;
subpos = subpos + len;
word_bytes = word_bytes - len;
subpos = subpos + len;
// early terminate
if (output_tokens->size() >= max_length) {
break;
}
// early terminate
if(output_tokens->size() >= max_length)
{
break;
}
}
if (do_lower_case) {
free((void *) text);
}
if(do_lower_case)
{
free((void*)text);
}
}
void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
if(token.size() > max_input_chars_per_word)
{ // FIXME: slightly different
output_tokens->push_back(unk_token);
return;
}
size_t output_tokens_len = output_tokens->size();
for(size_t start = 0; start < token.size();)
{
bool is_bad = true;
// TODO: can be optimized by prefix-tree
for(size_t end = token.size(); start < end; --end)
{ // FIXME: slightly different
std::string substr = start > 0 ? "##" + token.substr(start, end - start)
: token.substr(start, end - start);
if(vocab->count(substr))
{
is_bad = false;
output_tokens->push_back(substr);
start = end;
break;
}
}
void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) {
if (token.size() > max_input_chars_per_word) { // FIXME: slightly different
if(is_bad)
{
output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token);
return;
}
size_t output_tokens_len = output_tokens->size();
for (size_t start = 0; start < token.size();) {
bool is_bad = true;
// TODO: can be optimized by prefix-tree
for (size_t end = token.size(); start < end; --end) { // FIXME: slightly different
std::string substr = start > 0
? "##" + token.substr(start, end - start)
: token.substr(start, end - start);
if (vocab->count(substr)) {
is_bad = false;
output_tokens->push_back(substr);
start = end;
break;
}
}
if (is_bad) {
output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token);
return;
}
}
}
}
void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
std::vector<std::string> tokens;
tokens.reserve(max_length);
basic_tokenizer->tokenize(text, &tokens, max_length);
for (const auto &token : tokens) {
wordpiece_tokenizer->tokenize(token, output_tokens);
// early terminate
if (output_tokens->size() >= max_length) {
break;
}
void FullTokenizer::tokenize(const char* text,
std::vector<std::string>* output_tokens,
size_t max_length)
{
std::vector<std::string> tokens;
tokens.reserve(max_length);
basic_tokenizer->tokenize(text, &tokens, max_length);
for(const auto& token : tokens)
{
wordpiece_tokenizer->tokenize(token, output_tokens);
// early terminate
if(output_tokens->size() >= max_length)
{
break;
}
}
}
} // namespace cuBERT
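As a small illustration of the greedy longest-match-first behaviour of WordpieceTokenizer::tokenize above, the sketch below builds a toy vocabulary (the ids are arbitrary) and reproduces the "unaffable" -> un / ##aff / ##able example documented in tokenization.h.

#include "tokenization.h" // this repo's header, included as ./tokenization.h above
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
    // Toy vocabulary; the ids only matter for convert_token_to_id, not for tokenize()
    std::unordered_map<std::string, uint64_t> vocab = {
        {"un", 0}, {"##aff", 1}, {"##able", 2}, {"[UNK]", 3}};
    cuBERT::WordpieceTokenizer tokenizer(&vocab);

    std::vector<std::string> pieces;
    tokenizer.tokenize("unaffable", &pieces); // greedy longest-match-first against the vocab
    for(const std::string& p : pieces)
    {
        std::cout << p << std::endl; // prints: un, ##aff, ##able (one per line)
    }
    return 0;
}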
@@ -6,158 +6,172 @@
#include <unordered_map>
#include <iostream>
namespace cuBERT {
namespace cuBERT
{
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab);
void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
/**
* Checks whether `chars` is a whitespace character.
* @param c
* @return
*/
bool _is_whitespace(int c);
bool _is_whitespace(int c);
/**
* Checks whether `chars` is a control character.
* @param c
* @return
*/
bool _is_control(int c);
bool _is_control(int c);
/**
* Checks whether `chars` is a punctuation character.
* @param cp
* @return
*/
bool _is_punctuation(int cp);
bool _is_punctuation(int cp);
/**
* Runs basic tokenization (punctuation splitting, lower casing, etc.).
*/
class BasicTokenizer {
class BasicTokenizer
{
public:
/**
* Constructs a BasicTokenizer.
* @param do_lower_case Whether to lower case the input.
*/
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer &other) = delete;
virtual ~BasicTokenizer() = default;
/**
* Tokenizes a piece of text.
*
* to_lower
* _run_strip_accents Strips accents from a piece of text.
* _clean_text Performs invalid character removal and whitespace cleanup on text.
* _tokenize_chinese_chars Adds whitespace around any CJK character.
* _run_split_on_punc Splits punctuation on a piece of text.
* whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text.
*
* @param text
* @param output_tokens
*/
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);
/**
* Constructs a BasicTokenizer.
* @param do_lower_case Whether to lower case the input.
*/
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer& other) = delete;
virtual ~BasicTokenizer() = default;
/**
* Tokenizes a piece of text.
*
* to_lower
* _run_strip_accents Strips accents from a piece of text.
* _clean_text Performs invalid character removal and whitespace cleanup on text.
* _tokenize_chinese_chars Adds whitespace around any CJK character.
* _run_split_on_punc Splits punctuation on a piece of text.
* whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text.
*
* @param text
* @param output_tokens
*/
void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
private:
const bool do_lower_case;
const bool do_lower_case;
/**
* Checks whether CP is the codepoint of a CJK character.
* @param cp
* @return
*/
inline static bool _is_chinese_char(int cp);
};
/**
* Checks whether CP is the codepoint of a CJK character.
* @param cp
* @return
*/
inline static bool _is_chinese_char(int cp);
};
/**
* Runs WordPiece tokenziation.
*/
class WordpieceTokenizer {
class WordpieceTokenizer
{
public:
explicit WordpieceTokenizer(
std::unordered_map<std::string, uint64_t> *vocab,
std::string unk_token = "[UNK]",
int max_input_chars_per_word = 200
) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {}
WordpieceTokenizer(const WordpieceTokenizer &other) = delete;
virtual ~WordpieceTokenizer() = default;
/**
* Tokenizes a piece of text into its word pieces.
*
* This uses a greedy longest-match-first algorithm to perform tokenization
* using the given vocabulary.
*
* For example:
* input = "unaffable"
* output = ["un", "##aff", "##able"]
*
* @param text A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer.
* @param output_tokens A list of wordpiece tokens.
*/
void tokenize(const std::string &text, std::vector<std::string> *output_tokens);
explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
std::string unk_token = "[UNK]",
int max_input_chars_per_word = 200)
: vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
{
}
WordpieceTokenizer(const WordpieceTokenizer& other) = delete;
virtual ~WordpieceTokenizer() = default;
/**
* Tokenizes a piece of text into its word pieces.
*
* This uses a greedy longest-match-first algorithm to perform tokenization
* using the given vocabulary.
*
* For example:
* input = "unaffable"
* output = ["un", "##aff", "##able"]
*
* @param text A single token or whitespace separated tokens. This should have already been
* passed through `BasicTokenizer.
* @param output_tokens A list of wordpiece tokens.
*/
void tokenize(const std::string& text, std::vector<std::string>* output_tokens);
private:
const std::unordered_map<std::string, uint64_t> *vocab;
const std::string unk_token;
const int max_input_chars_per_word;
};
const std::unordered_map<std::string, uint64_t>* vocab;
const std::string unk_token;
const int max_input_chars_per_word;
};
/**
* Runs end-to-end tokenziation.
*/
class FullTokenizer {
class FullTokenizer
{
public:
FullTokenizer(const char *vocab_file, bool do_lower_case = true) {
vocab = new std::unordered_map<std::string, uint64_t>();
load_vocab(vocab_file, vocab);
basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab);
FullTokenizer(const char* vocab_file, bool do_lower_case = true)
{
vocab = new std::unordered_map<std::string, uint64_t>();
load_vocab(vocab_file, vocab);
basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab);
}
~FullTokenizer()
{
if(wordpiece_tokenizer != NULL)
{
wordpiece_tokenizer = NULL;
}
delete wordpiece_tokenizer;
~FullTokenizer() {
if (wordpiece_tokenizer != NULL){
wordpiece_tokenizer = NULL;
}
delete wordpiece_tokenizer;
if (basic_tokenizer != NULL){
basic_tokenizer = NULL;
}
delete basic_tokenizer;
if (vocab != NULL){
vocab = NULL;
}
delete vocab;
if(basic_tokenizer != NULL)
{
basic_tokenizer = NULL;
}
delete basic_tokenizer;
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);
inline uint64_t convert_token_to_id(const std::string &token) {
auto item = vocab->find(token);
if (item == vocab->end()) {
std::cerr << "vocab missing key: " << token << std::endl;
return 0;
} else {
return item->second;
}
if(vocab != NULL)
{
vocab = NULL;
}
delete vocab;
}
void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
inline uint64_t convert_token_to_id(const std::string& token)
{
auto item = vocab->find(token);
if(item == vocab->end())
{
std::cerr << "vocab missing key: " << token << std::endl;
return 0;
}
else
{
return item->second;
}
}
void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids);
void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);
private:
std::unordered_map<std::string, uint64_t> *vocab;
BasicTokenizer *basic_tokenizer;
WordpieceTokenizer *wordpiece_tokenizer;
};
std::unordered_map<std::string, uint64_t>* vocab;
BasicTokenizer* basic_tokenizer;
WordpieceTokenizer* wordpiece_tokenizer;
};
}
} // namespace cuBERT
#endif //CUBERT_TOKENIZATION_H
#endif // CUBERT_TOKENIZATION_H
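A hedged end-to-end sketch of FullTokenizer as it is used by GPT2::Preprocessing and main.cpp in this repo: load the vocabulary, tokenize a question, then prepend [CLS] and map every token to its id. The vocabulary path is the one used elsewhere in this sample; the input string is only a placeholder.

#include "tokenization.h"
#include <cstdio>
#include <string>
#include <vector>

int main()
{
    // Vocabulary file shipped with the GPT2 sample in this repo
    cuBERT::FullTokenizer tokenizer("../Resource/vocab_shici.txt");

    std::vector<std::string> tokens;
    tokenizer.tokenize("床前明月光", &tokens, 1000); // max_length matches the model's maximum input shape

    // Mirror GPT2::Preprocessing: prepend [CLS], then append one id per token
    std::vector<long unsigned int> input_id;
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for(const std::string& t : tokens)
    {
        input_id.push_back(tokenizer.convert_token_to_id(t));
    }
    std::printf("%zu tokens -> %zu ids\n", tokens.size(), input_id.size());
    return 0;
}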
This diff is collapsed.
This diff is collapsed.
@@ -12,7 +12,7 @@ int main()
// Load the GPT2 model
migraphxSamples::GPT2 gpt2;
migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
if (errorCode != migraphxSamples::SUCCESS)
if(errorCode != migraphxSamples::SUCCESS)
{
LOG_ERROR(stdout, "fail to initialize GPT2!\n");
exit(-1);
@@ -25,7 +25,7 @@ int main()
std::string buf;
std::vector<std::string> output;
infile.open("../Resource/vocab_shici.txt");
while (std::getline(infile,buf))
while(std::getline(infile, buf))
{
output.push_back(buf);
}
@@ -37,7 +37,7 @@ int main()
std::vector<std::string> result;
std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
while (true)
while(true)
{
// Preprocess the input
std::cout << "question: ";
@@ -45,7 +45,7 @@ int main()
gpt2.Preprocessing(tokenizer, question, input_id);
// Run inference
for(int i=0;i<50;++i)
for(int i = 0; i < 50; ++i)
{
long unsigned int outputs = gpt2.Inference(input_id);
if(outputs == 102)
@@ -57,7 +57,7 @@ int main()
}
// Map the ids back to characters
for(int i=0;i<score.size();++i)
for(int i = 0; i < score.size(); ++i)
{
result.push_back(output[score[i]]);
}
@@ -65,12 +65,12 @@ int main()
// Print the result
std::cout << "chatbot: ";
std::cout << question;
for(int j=0; j<result.size();++j)
for(int j = 0; j < result.size(); ++j)
{
std::cout << result[j];
}
std::cout << std::endl;
// Clear the buffers
input_id.clear();
result.clear();