Commit 0f9dc829 authored by liucong's avatar liucong
Browse files

重新格式化Cppd代码格式

parent 824cfb81
...@@ -12,39 +12,33 @@ ...@@ -12,39 +12,33 @@
namespace migraphxSamples namespace migraphxSamples
{ {
GPT2::GPT2() GPT2::GPT2() {}
{
} GPT2::~GPT2() {}
GPT2::~GPT2()
{
}
ErrorCode GPT2::Initialize() ErrorCode GPT2::Initialize()
{ {
// 获取模型文件 // 获取模型文件
std::string modelPath="../Resource/GPT2_shici.onnx"; std::string modelPath = "../Resource/GPT2_shici.onnx";
// 设置最大输入shape // 设置最大输入shape
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"] = {1,1000}; onnx_options.map_input_dims["input"] = {1, 1000};
// 加载模型 // 加载模型
if(!Exists(modelPath)) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
inputName=inputs.begin()->first; inputName = inputs.begin()->first;
inputShape=inputs.begin()->second; inputShape = inputs.begin()->second;
// 设置模型为GPU模式 // 设置模型为GPU模式
migraphx::target gpuTarget = migraphx::gpu::target{}; migraphx::target gpuTarget = migraphx::gpu::target{};
...@@ -53,33 +47,31 @@ ErrorCode GPT2::Initialize() ...@@ -53,33 +47,31 @@ ErrorCode GPT2::Initialize()
migraphx::compile_options options; migraphx::compile_options options;
options.device_id = 0; // 设置GPU设备,默认为0号设备 options.device_id = 0; // 设置GPU设备,默认为0号设备
options.offload_copy = true; // 设置offload_copy options.offload_copy = true; // 设置offload_copy
net.compile(gpuTarget,options); net.compile(gpuTarget, options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());
return SUCCESS; return SUCCESS;
} }
static bool CompareM(Predictions a, Predictions b) static bool CompareM(Predictions a, Predictions b) { return a.predictionvalue > b.predictionvalue; }
{
return a.predictionvalue > b.predictionvalue;
}
long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id) long unsigned int GPT2::Inference(const std::vector<long unsigned int>& input_id)
{ {
long unsigned int input[1][input_id.size()]; long unsigned int input[1][input_id.size()];
for (int j=0;j<input_id.size();++j) for(int j = 0; j < input_id.size(); ++j)
{ {
input[0][j] = input_id[j]; input[0][j] = input_id[j];
} }
// 设置输入shape // 设置输入shape
std::vector<std::vector<std::size_t>> inputShapes; std::vector<std::vector<std::size_t>> inputShapes;
inputShapes.push_back({1,input_id.size()}); inputShapes.push_back({1, input_id.size()});
// 创建输入数据 // 创建输入数据
std::unordered_map<std::string, migraphx::argument> inputData; std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName]=migraphx::argument{migraphx::shape(inputShape.type(),inputShapes[0]),(long unsigned int*)input}; inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[0]),
(long unsigned int*)input};
// 推理 // 推理
std::vector<migraphx::argument> results = net.eval(inputData); std::vector<migraphx::argument> results = net.eval(inputData);
...@@ -87,13 +79,13 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id ...@@ -87,13 +79,13 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
// 获取输出节点的属性 // 获取输出节点的属性
migraphx::argument result = results[0]; migraphx::argument result = results[0];
migraphx::shape outputShape = result.get_shape(); // 输出节点的shape migraphx::shape outputShape = result.get_shape(); // 输出节点的shape
int numberOfOutput=outputShape.elements(); // 输出节点元素的个数 int numberOfOutput = outputShape.elements(); // 输出节点元素的个数
float *data = (float *)result.data(); // 输出节点数据指针 float* data = (float*)result.data(); // 输出节点数据指针
// 保存推理结果 // 保存推理结果
long unsigned int n = 0; long unsigned int n = 0;
std::vector<Predictions> resultsOfPredictions(22557); std::vector<Predictions> resultsOfPredictions(22557);
for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i) for(int i = (input_id.size() - 1) * 22557; i < input_id.size() * 22557; ++i)
{ {
resultsOfPredictions[n].index = n; resultsOfPredictions[n].index = n;
resultsOfPredictions[n].predictionvalue = data[i]; resultsOfPredictions[n].predictionvalue = data[i];
...@@ -110,8 +102,8 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id ...@@ -110,8 +102,8 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
} }
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
char *question, char* question,
std::vector<long unsigned int> &input_id) std::vector<long unsigned int>& input_id)
{ {
// 分词操作 // 分词操作
int max_seq_length = 1000; int max_seq_length = 1000;
...@@ -121,7 +113,7 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -121,7 +113,7 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
// 保存编码信息 // 保存编码信息
input_id.push_back(tokenizer.convert_token_to_id("[CLS]")); input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
for (int i=0;i<tokens_question.size();++i) for(int i = 0; i < tokens_question.size(); ++i)
{ {
input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i])); input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i]));
} }
...@@ -129,4 +121,4 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -129,4 +121,4 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
return SUCCESS; return SUCCESS;
} }
} } // namespace migraphxSamples
\ No newline at end of file \ No newline at end of file
...@@ -8,25 +8,25 @@ ...@@ -8,25 +8,25 @@
namespace migraphxSamples namespace migraphxSamples
{ {
typedef enum _ErrorCode typedef enum _ErrorCode
{ {
SUCCESS=0, SUCCESS = 0,
MODEL_NOT_EXIST, MODEL_NOT_EXIST,
CONFIG_FILE_NOT_EXIST, CONFIG_FILE_NOT_EXIST,
FAIL_TO_LOAD_MODEL, FAIL_TO_LOAD_MODEL,
FAIL_TO_OPEN_CONFIG_FILE, FAIL_TO_OPEN_CONFIG_FILE,
}ErrorCode; } ErrorCode;
typedef struct _Predictions typedef struct _Predictions
{ {
long unsigned int index; long unsigned int index;
float predictionvalue; float predictionvalue;
}Predictions; } Predictions;
class GPT2 class GPT2
{ {
public: public:
GPT2(); GPT2();
~GPT2(); ~GPT2();
...@@ -34,16 +34,16 @@ public: ...@@ -34,16 +34,16 @@ public:
ErrorCode Initialize(); ErrorCode Initialize();
ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
char *question, char* question,
std::vector<long unsigned int> &input_id); std::vector<long unsigned int>& input_id);
long unsigned int Inference(const std::vector<long unsigned int> &input_id); long unsigned int Inference(const std::vector<long unsigned int>& input_id);
private: private:
migraphx::program net; migraphx::program net;
std::string inputName; std::string inputName;
migraphx::shape inputShape; migraphx::shape inputShape;
}; };
} } // namespace migraphxSamples
#endif #endif
\ No newline at end of file
...@@ -28,32 +28,32 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -28,32 +28,32 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
{ {
std::string::size_type pos; std::string::size_type pos;
std::vector<std::string> result; std::vector<std::string> result;
str+=separator;//扩展字符串以方便操作 str += separator; // 扩展字符串以方便操作
int size=str.size(); int size = str.size();
for(int i=0; i<size; i++) for(int i = 0; i < size; i++)
{ {
pos=str.find(separator,i); pos = str.find(separator, i);
if(pos<size) if(pos < size)
{ {
std::string s=str.substr(i,pos-i); std::string s = str.substr(i, pos - i);
result.push_back(s); result.push_back(s);
i=pos+separator.size()-1; i = pos + separator.size() - 1;
} }
} }
return result; return result;
} }
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
const char dir_separators[] = "/\\"; const char dir_separators[] = "/\\";
struct dirent struct dirent
{ {
const char* d_name; const char* d_name;
}; };
struct DIR struct DIR
{ {
#ifdef WINRT #ifdef WINRT
WIN32_FIND_DATAW data; WIN32_FIND_DATAW data;
#else #else
...@@ -62,17 +62,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -62,17 +62,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
HANDLE handle; HANDLE handle;
dirent ent; dirent ent;
#ifdef WINRT #ifdef WINRT
DIR() { } DIR() {}
~DIR() ~DIR()
{ {
if (ent.d_name) if(ent.d_name)
delete[] ent.d_name; delete[] ent.d_name;
} }
#endif #endif
}; };
DIR* opendir(const char* path) DIR* opendir(const char* path)
{ {
DIR* dir = new DIR; DIR* dir = new DIR;
dir->ent.d_name = 0; dir->ent.d_name = 0;
#ifdef WINRT #ifdef WINRT
...@@ -80,27 +80,31 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -80,27 +80,31 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
wchar_t wfull_path[MAX_PATH]; wchar_t wfull_path[MAX_PATH];
size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH); size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
dir->handle = ::FindFirstFileExW(wfull_path, FindExInfoStandard, dir->handle = ::FindFirstFileExW(
&dir->data, FindExSearchNameMatch, NULL, 0); wfull_path, FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0);
#else #else
dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(), dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(),
FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0); FindExInfoStandard,
&dir->data,
FindExSearchNameMatch,
NULL,
0);
#endif #endif
if (dir->handle == INVALID_HANDLE_VALUE) if(dir->handle == INVALID_HANDLE_VALUE)
{ {
/*closedir will do all cleanup*/ /*closedir will do all cleanup*/
delete dir; delete dir;
return 0; return 0;
} }
return dir; return dir;
} }
dirent* readdir(DIR* dir) dirent* readdir(DIR* dir)
{ {
#ifdef WINRT #ifdef WINRT
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileW(dir->handle, &dir->data) != TRUE) if(::FindNextFileW(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
size_t asize = wcstombs(NULL, dir->data.cFileName, 0); size_t asize = wcstombs(NULL, dir->data.cFileName, 0);
...@@ -110,33 +114,33 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -110,33 +114,33 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
wcstombs(aname, dir->data.cFileName, asize); wcstombs(aname, dir->data.cFileName, asize);
dir->ent.d_name = aname; dir->ent.d_name = aname;
#else #else
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileA(dir->handle, &dir->data) != TRUE) if(::FindNextFileA(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
dir->ent.d_name = dir->data.cFileName; dir->ent.d_name = dir->data.cFileName;
#endif #endif
return &dir->ent; return &dir->ent;
} }
void closedir(DIR* dir) void closedir(DIR* dir)
{ {
::FindClose(dir->handle); ::FindClose(dir->handle);
delete dir; delete dir;
} }
#else #else
# include <dirent.h> #include <dirent.h>
# include <sys/stat.h> #include <sys/stat.h>
const char dir_separators[] = "/"; const char dir_separators[] = "/";
#endif #endif
static bool isDir(const string &path, DIR* dir) static bool isDir(const string& path, DIR* dir)
{ {
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
DWORD attributes; DWORD attributes;
BOOL status = TRUE; BOOL status = TRUE;
if (dir) if(dir)
attributes = dir->data.dwFileAttributes; attributes = dir->data.dwFileAttributes;
else else
{ {
...@@ -156,20 +160,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -156,20 +160,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
(void)dir; (void)dir;
struct stat stat_buf; struct stat stat_buf;
if (0 != stat(path.c_str(), &stat_buf)) if(0 != stat(path.c_str(), &stat_buf))
return false; return false;
int is_dir = S_ISDIR(stat_buf.st_mode); int is_dir = S_ISDIR(stat_buf.st_mode);
return is_dir != 0; return is_dir != 0;
#endif #endif
} }
bool IsDirectory(const string &path) bool IsDirectory(const string& path) { return isDir(path, NULL); }
{
return isDir(path, NULL);
}
bool Exists(const string& path) bool Exists(const string& path)
{ {
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
BOOL status = TRUE; BOOL status = TRUE;
...@@ -190,28 +191,25 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -190,28 +191,25 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
struct stat stat_buf; struct stat stat_buf;
return (0 == stat(path.c_str(), &stat_buf)); return (0 == stat(path.c_str(), &stat_buf));
#endif #endif
} }
bool IsPathSeparator(char c) bool IsPathSeparator(char c) { return c == '/' || c == '\\'; }
{
return c == '/' || c == '\\';
}
string JoinPath(const string& base, const string& path) string JoinPath(const string& base, const string& path)
{ {
if (base.empty()) if(base.empty())
return path; return path;
if (path.empty()) if(path.empty())
return base; return base;
bool baseSep = IsPathSeparator(base[base.size() - 1]); bool baseSep = IsPathSeparator(base[base.size() - 1]);
bool pathSep = IsPathSeparator(path[0]); bool pathSep = IsPathSeparator(path[0]);
string result; string result;
if (baseSep && pathSep) if(baseSep && pathSep)
{ {
result = base + path.substr(1); result = base + path.substr(1);
} }
else if (!baseSep && !pathSep) else if(!baseSep && !pathSep)
{ {
result = base + PATH_SEPARATOR + path; result = base + PATH_SEPARATOR + path;
} }
...@@ -220,15 +218,15 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -220,15 +218,15 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
result = base + path; result = base + path;
} }
return result; return result;
} }
static bool wildcmp(const char *string, const char *wild) static bool wildcmp(const char* string, const char* wild)
{ {
const char *cp = 0, *mp = 0; const char *cp = 0, *mp = 0;
while ((*string) && (*wild != '*')) while((*string) && (*wild != '*'))
{ {
if ((*wild != *string) && (*wild != '?')) if((*wild != *string) && (*wild != '?'))
{ {
return false; return false;
} }
...@@ -237,11 +235,11 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -237,11 +235,11 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
string++; string++;
} }
while (*string) while(*string)
{ {
if (*wild == '*') if(*wild == '*')
{ {
if (!*++wild) if(!*++wild)
{ {
return true; return true;
} }
...@@ -249,7 +247,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -249,7 +247,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
mp = wild; mp = wild;
cp = string + 1; cp = string + 1;
} }
else if ((*wild == *string) || (*wild == '?')) else if((*wild == *string) || (*wild == '?'))
{ {
wild++; wild++;
string++; string++;
...@@ -261,47 +259,52 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -261,47 +259,52 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
while (*wild == '*') while(*wild == '*')
{ {
wild++; wild++;
} }
return *wild == 0; return *wild == 0;
} }
static void glob_rec(const string &directory, const string& wildchart, std::vector<string>& result, static void glob_rec(const string& directory,
bool recursive, bool includeDirectories, const string& pathPrefix) const string& wildchart,
{ std::vector<string>& result,
DIR *dir; bool recursive,
bool includeDirectories,
const string& pathPrefix)
{
DIR* dir;
if ((dir = opendir(directory.c_str())) != 0) if((dir = opendir(directory.c_str())) != 0)
{ {
/* find all the files and directories within directory */ /* find all the files and directories within directory */
try try
{ {
struct dirent *ent; struct dirent* ent;
while ((ent = readdir(dir)) != 0) while((ent = readdir(dir)) != 0)
{ {
const char* name = ent->d_name; const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0)) if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
(name[0] == '.' && name[1] == '.' && name[2] == 0))
continue; continue;
string path = JoinPath(directory, name); string path = JoinPath(directory, name);
string entry = JoinPath(pathPrefix, name); string entry = JoinPath(pathPrefix, name);
if (isDir(path, dir)) if(isDir(path, dir))
{ {
if (recursive) if(recursive)
glob_rec(path, wildchart, result, recursive, includeDirectories, entry); glob_rec(path, wildchart, result, recursive, includeDirectories, entry);
if (!includeDirectories) if(!includeDirectories)
continue; continue;
} }
if (wildchart.empty() || wildcmp(name, wildchart.c_str())) if(wildchart.empty() || wildcmp(name, wildchart.c_str()))
result.push_back(entry); result.push_back(entry);
} }
} }
catch (...) catch(...)
{ {
closedir(dir); closedir(dir);
throw; throw;
...@@ -312,23 +315,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -312,23 +315,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
{ {
printf("could not open directory: %s", directory.c_str()); printf("could not open directory: %s", directory.c_str());
} }
} }
void GetFileNameList(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath) void GetFileNameList(const string& directory,
{ const string& pattern,
std::vector<string>& result,
bool recursive,
bool addPath)
{
// split pattern // split pattern
vector<string> patterns=SplitString(pattern,","); vector<string> patterns = SplitString(pattern, ",");
result.clear(); result.clear();
for(int i=0;i<patterns.size();++i) for(int i = 0; i < patterns.size(); ++i)
{ {
string eachPattern=patterns[i]; string eachPattern = patterns[i];
std::vector<string> eachResult; std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory); glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j=0;j<eachResult.size();++j) for(int j = 0; j < eachResult.size(); ++j)
{ {
if (IsDirectory(eachResult[j])) if(IsDirectory(eachResult[j]))
continue; continue;
if(addPath) if(addPath)
{ {
...@@ -341,41 +348,45 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -341,41 +348,45 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
std::sort(result.begin(), result.end()); std::sort(result.begin(), result.end());
} }
void GetFileNameList2(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath) void GetFileNameList2(const string& directory,
{ const string& pattern,
std::vector<string>& result,
bool recursive,
bool addPath)
{
// split pattern // split pattern
vector<string> patterns = SplitString(pattern, ","); vector<string> patterns = SplitString(pattern, ",");
result.clear(); result.clear();
for (int i = 0; i<patterns.size(); ++i) for(int i = 0; i < patterns.size(); ++i)
{ {
string eachPattern = patterns[i]; string eachPattern = patterns[i];
std::vector<string> eachResult; std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory); glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for (int j = 0; j<eachResult.size(); ++j) for(int j = 0; j < eachResult.size(); ++j)
{ {
string filePath = eachResult[j]; string filePath = eachResult[j];
if (IsDirectory(filePath)) if(IsDirectory(filePath))
{ {
filePath = filePath + "/"; filePath = filePath + "/";
for (int k = 0; k < filePath.size(); ++k) for(int k = 0; k < filePath.size(); ++k)
{ {
if (IsPathSeparator(filePath[k])) if(IsPathSeparator(filePath[k]))
{ {
filePath[k] = '/'; filePath[k] = '/';
} }
} }
} }
if (addPath) if(addPath)
{ {
result.push_back(filePath); result.push_back(filePath);
} }
else else
{ {
if (!IsDirectory(filePath)) if(!IsDirectory(filePath))
{ {
result.push_back(GetFileName(filePath)); result.push_back(GetFileName(filePath));
} }
...@@ -383,19 +394,19 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -383,19 +394,19 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
std::sort(result.begin(), result.end()); std::sort(result.begin(), result.end());
} }
void RemoveAll(const string& path) void RemoveAll(const string& path)
{ {
if (!Exists(path)) if(!Exists(path))
return; return;
if (IsDirectory(path)) if(IsDirectory(path))
{ {
std::vector<string> entries; std::vector<string> entries;
GetFileNameList2(path, string(), entries, false, true); GetFileNameList2(path, string(), entries, false, true);
for (size_t i = 0; i < entries.size(); i++) for(size_t i = 0; i < entries.size(); i++)
{ {
const string& e = entries[i]; const string& e = entries[i];
RemoveAll(e); RemoveAll(e);
...@@ -405,7 +416,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -405,7 +416,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
bool result = rmdir(path.c_str()) == 0; bool result = rmdir(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{ {
printf("can't remove directory: %s\n", path.c_str()); printf("can't remove directory: %s\n", path.c_str());
} }
...@@ -417,50 +428,50 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -417,50 +428,50 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
bool result = unlink(path.c_str()) == 0; bool result = unlink(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{ {
printf("can't remove file: %s\n", path.c_str()); printf("can't remove file: %s\n", path.c_str());
} }
} }
} }
void Remove(const string &directory, const string &extension) void Remove(const string& directory, const string& extension)
{ {
DIR *dir; DIR* dir;
static int numberOfFiles = 0; static int numberOfFiles = 0;
if ((dir = opendir(directory.c_str())) != 0) if((dir = opendir(directory.c_str())) != 0)
{ {
/* find all the files and directories within directory */ /* find all the files and directories within directory */
try try
{ {
struct dirent *ent; struct dirent* ent;
while ((ent = readdir(dir)) != 0) while((ent = readdir(dir)) != 0)
{ {
const char* name = ent->d_name; const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0)) if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
(name[0] == '.' && name[1] == '.' && name[2] == 0))
continue; continue;
string path = JoinPath(directory, name); string path = JoinPath(directory, name);
if (isDir(path, dir)) if(isDir(path, dir))
{ {
Remove(path, extension); Remove(path, extension);
} }
// �ж���չ�� // �ж���չ��
if (extension.empty() || wildcmp(name, extension.c_str())) if(extension.empty() || wildcmp(name, extension.c_str()))
{ {
RemoveAll(path); RemoveAll(path);
++numberOfFiles; ++numberOfFiles;
printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles); printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
} }
} }
} }
catch (...) catch(...)
{ {
closedir(dir); closedir(dir);
throw; throw;
...@@ -474,49 +485,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -474,49 +485,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
// ����RemoveAllɾ��Ŀ¼ // ����RemoveAllɾ��Ŀ¼
RemoveAll(directory); RemoveAll(directory);
} }
string GetFileName(const string &path) string GetFileName(const string& path)
{ {
string fileName; string fileName;
int indexOfPathSeparator = -1; int indexOfPathSeparator = -1;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (IsPathSeparator(path[i])) if(IsPathSeparator(path[i]))
{ {
fileName = path.substr(i + 1, path.size() - i - 1); fileName = path.substr(i + 1, path.size() - i - 1);
indexOfPathSeparator = i; indexOfPathSeparator = i;
break; break;
} }
} }
if (indexOfPathSeparator == -1) if(indexOfPathSeparator == -1)
{ {
fileName = path; fileName = path;
} }
return fileName; return fileName;
} }
string GetFileName_NoExtension(const string &path) string GetFileName_NoExtension(const string& path)
{ {
string fileName=GetFileName(path); string fileName = GetFileName(path);
string fileName_NoExtension; string fileName_NoExtension;
for(int i=fileName.size()-1;i>0;--i) for(int i = fileName.size() - 1; i > 0; --i)
{ {
if(fileName[i]=='.') if(fileName[i] == '.')
{ {
fileName_NoExtension=fileName.substr(0,i); fileName_NoExtension = fileName.substr(0, i);
break; break;
} }
} }
return fileName_NoExtension; return fileName_NoExtension;
} }
string GetExtension(const string &path) string GetExtension(const string& path)
{ {
string fileName; string fileName;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (path[i]=='.') if(path[i] == '.')
{ {
fileName = path.substr(i, path.size() - i); fileName = path.substr(i, path.size() - i);
break; break;
...@@ -524,56 +535,55 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -524,56 +535,55 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
return fileName; return fileName;
}
} string GetParentPath(const string& path)
{
string GetParentPath(const string &path)
{
string fileName; string fileName;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (IsPathSeparator(path[i])) if(IsPathSeparator(path[i]))
{ {
fileName = path.substr(0, i+1); fileName = path.substr(0, i + 1);
break; break;
} }
} }
return fileName; return fileName;
} }
static bool CreateDirectory(const string &path) static bool CreateDirectory(const string& path)
{ {
#if defined WIN32 || defined _WIN32 || defined WINCE #if defined WIN32 || defined _WIN32 || defined WINCE
#ifdef WINRT #ifdef WINRT
wchar_t wpath[MAX_PATH]; wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH); size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
int result = CreateDirectoryA(wpath, NULL) ? 0 : -1; int result = CreateDirectoryA(wpath, NULL) ? 0 : -1;
#else #else
int result = _mkdir(path.c_str()); int result = _mkdir(path.c_str());
#endif #endif
#elif defined __linux__ || defined __APPLE__ #elif defined __linux__ || defined __APPLE__
int result = mkdir(path.c_str(), 0777); int result = mkdir(path.c_str(), 0777);
#else #else
int result = -1; int result = -1;
#endif #endif
if (result == -1) if(result == -1)
{ {
return IsDirectory(path); return IsDirectory(path);
} }
return true; return true;
} }
bool CreateDirectories(const string &directoryPath) bool CreateDirectories(const string& directoryPath)
{ {
string path = directoryPath; string path = directoryPath;
for (;;) for(;;)
{ {
char last_char = path.empty() ? 0 : path[path.length() - 1]; char last_char = path.empty() ? 0 : path[path.length() - 1];
if (IsPathSeparator(last_char)) if(IsPathSeparator(last_char))
{ {
path = path.substr(0, path.length() - 1); path = path.substr(0, path.length() - 1);
continue; continue;
...@@ -581,35 +591,35 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -581,35 +591,35 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
break; break;
} }
if (path.empty() || path == "./" || path == ".\\" || path == ".") if(path.empty() || path == "./" || path == ".\\" || path == ".")
return true; return true;
if (IsDirectory(path)) if(IsDirectory(path))
return true; return true;
size_t pos = path.rfind('/'); size_t pos = path.rfind('/');
if (pos == string::npos) if(pos == string::npos)
pos = path.rfind('\\'); pos = path.rfind('\\');
if (pos != string::npos) if(pos != string::npos)
{ {
string parent_directory = path.substr(0, pos); string parent_directory = path.substr(0, pos);
if (!parent_directory.empty()) if(!parent_directory.empty())
{ {
if (!CreateDirectories(parent_directory)) if(!CreateDirectories(parent_directory))
return false; return false;
} }
} }
return CreateDirectory(path); return CreateDirectory(path);
} }
bool CopyFile(const string srcPath, const string dstPath) bool CopyFile(const string srcPath, const string dstPath)
{ {
std::ifstream srcFile(srcPath,ios::binary); std::ifstream srcFile(srcPath, ios::binary);
std::ofstream dstFile(dstPath,ios::binary); std::ofstream dstFile(dstPath, ios::binary);
if(!srcFile.is_open()) if(!srcFile.is_open())
{ {
printf("can not open %s\n",srcPath.c_str()); printf("can not open %s\n", srcPath.c_str());
return false; return false;
} }
if(!dstFile.is_open()) if(!dstFile.is_open())
...@@ -617,27 +627,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -617,27 +627,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
printf("can not open %s\n", dstPath.c_str()); printf("can not open %s\n", dstPath.c_str());
return false; return false;
} }
if(srcPath==dstPath) if(srcPath == dstPath)
{ {
printf("src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
} }
char buffer[2048]; char buffer[2048];
unsigned int numberOfBytes=0; unsigned int numberOfBytes = 0;
while(srcFile) while(srcFile)
{ {
srcFile.read(buffer,2048); srcFile.read(buffer, 2048);
dstFile.write(buffer,srcFile.gcount()); dstFile.write(buffer, srcFile.gcount());
numberOfBytes+=srcFile.gcount(); numberOfBytes += srcFile.gcount();
} }
srcFile.close(); srcFile.close();
dstFile.close(); dstFile.close();
return true; return true;
} }
bool CopyDirectories(string srcPath, const string dstPath) bool CopyDirectories(string srcPath, const string dstPath)
{ {
if(srcPath==dstPath) if(srcPath == dstPath)
{ {
printf("src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
...@@ -649,45 +659,41 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -649,45 +659,41 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
vector<string> fileNameList; vector<string> fileNameList;
GetFileNameList2(srcPath, "", fileNameList, true, true); GetFileNameList2(srcPath, "", fileNameList, true, true);
string parentPathOfSrc=GetParentPath(srcPath); string parentPathOfSrc = GetParentPath(srcPath);
int length=parentPathOfSrc.size(); int length = parentPathOfSrc.size();
// create all directories // create all directories
for(int i=0;i<fileNameList.size();++i) for(int i = 0; i < fileNameList.size(); ++i)
{ {
// create directory // create directory
string srcFilePath=fileNameList[i]; string srcFilePath = fileNameList[i];
string subStr=srcFilePath.substr(length,srcFilePath.size()-length); string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath=dstPath+subStr; string dstFilePath = dstPath + subStr;
string parentPathOfDst=GetParentPath(dstFilePath); string parentPathOfDst = GetParentPath(dstFilePath);
CreateDirectories(parentPathOfDst); CreateDirectories(parentPathOfDst);
} }
// copy file // copy file
for(int i=0;i<fileNameList.size();++i) for(int i = 0; i < fileNameList.size(); ++i)
{ {
string srcFilePath=fileNameList[i]; string srcFilePath = fileNameList[i];
if (IsDirectory(srcFilePath)) if(IsDirectory(srcFilePath))
{ {
continue; continue;
} }
string subStr=srcFilePath.substr(length,srcFilePath.size()-length); string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath=dstPath+subStr; string dstFilePath = dstPath + subStr;
// copy file // copy file
CopyFile(srcFilePath,dstFilePath); CopyFile(srcFilePath, dstFilePath);
// process // process
double process = (1.0*(i + 1) / fileNameList.size()) * 100; double process = (1.0 * (i + 1) / fileNameList.size()) * 100;
printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process); printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
} }
printf("all done!(the number of files:%d)\n", fileNameList.size()); printf("all done!(the number of files:%d)\n", fileNameList.size());
return true; return true;
}
} }
} // namespace migraphxSamples
...@@ -10,19 +10,19 @@ namespace migraphxSamples ...@@ -10,19 +10,19 @@ namespace migraphxSamples
{ {
// 路径是否存在 // 路径是否存在
bool Exists(const std::string &path); bool Exists(const std::string& path);
// 路径是否为目录 // 路径是否为目录
bool IsDirectory(const std::string &path); bool IsDirectory(const std::string& path);
// 是否是路径分隔符(Linux:‘/’,Windows:’\\’) // 是否是路径分隔符(Linux:‘/’,Windows:’\\’)
bool IsPathSeparator(char c); bool IsPathSeparator(char c);
// 路径拼接 // 路径拼接
std::string JoinPath(const std::string &base, const std::string &path); std::string JoinPath(const std::string& base, const std::string& path);
// 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的 // 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的
bool CreateDirectories(const std::string &directoryPath); bool CreateDirectories(const std::string& directoryPath);
/** 生成符合指定模式的文件名列表(支持递归遍历) /** 生成符合指定模式的文件名列表(支持递归遍历)
* *
...@@ -36,25 +36,33 @@ bool CreateDirectories(const std::string &directoryPath); ...@@ -36,25 +36,33 @@ bool CreateDirectories(const std::string &directoryPath);
5. 不能返回子目录名 5. 不能返回子目录名
* *
*/ */
void GetFileNameList(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/") // 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/")
void GetFileNameList2(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList2(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 删除文件或者目录,支持递归删除 // 删除文件或者目录,支持递归删除
void Remove(const std::string &directory, const std::string &extension=""); void Remove(const std::string& directory, const std::string& extension = "");
/** 获取路径的文件名和扩展名 /** 获取路径的文件名和扩展名
* *
* 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/ * 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/
*/ */
std::string GetFileName(const std::string &path); std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string &path); std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string &path); std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string &path); std::string GetParentPath(const std::string& path);
// 拷贝文件 // 拷贝文件
bool CopyFile(const std::string srcPath,const std::string dstPath); bool CopyFile(const std::string srcPath, const std::string dstPath);
/** 拷贝目录 /** 拷贝目录
* *
...@@ -63,8 +71,8 @@ bool CopyFile(const std::string srcPath,const std::string dstPath); ...@@ -63,8 +71,8 @@ bool CopyFile(const std::string srcPath,const std::string dstPath);
1.第一个参数的最后不能加”/” 1.第一个参数的最后不能加”/”
2.不能拷贝隐藏文件 2.不能拷贝隐藏文件
*/ */
bool CopyDirectories(std::string srcPath,const std::string dstPath); bool CopyDirectories(std::string srcPath, const std::string dstPath);
} } // namespace migraphxSamples
#endif #endif
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <map> #include <map>
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#if (defined WIN32 || defined _WIN32) #if(defined WIN32 || defined _WIN32)
#include <Windows.h> #include <Windows.h>
#else #else
#include <sys/time.h> #include <sys/time.h>
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
using namespace std; using namespace std;
/** 简易日志 /** 简易日志
* *
* 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。 * 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
* *
* 示例1: * 示例1:
// 初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败) //
初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败)
LogManager::GetInstance()->Initialize("./Log/","log1"); LogManager::GetInstance()->Initialize("./Log/","log1");
LogManager::GetInstance()->Initialize("./Log/","log2"); LogManager::GetInstance()->Initialize("./Log/","log2");
...@@ -50,44 +50,43 @@ using namespace std; ...@@ -50,44 +50,43 @@ using namespace std;
class LogManager class LogManager
{ {
private: private:
LogManager(){} LogManager() {}
public: public:
~LogManager(){} ~LogManager() {}
inline void Initialize(const string &parentPath,const string &logName) inline void Initialize(const string& parentPath, const string& logName)
{ {
// 日志名为空表示输出到控制台 // 日志名为空表示输出到控制台
if(logName.size()==0) if(logName.size() == 0)
return; return;
// 查找该日志文件,如果没有则创建 // 查找该日志文件,如果没有则创建
std::map<string, FILE*>::const_iterator iter = logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if (iter == logMap.end()) if(iter == logMap.end())
{ {
string pathOfLog = parentPath+ logName + ".log"; string pathOfLog = parentPath + logName + ".log";
FILE *logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加 FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加
if(logFile!=NULL) if(logFile != NULL)
{ {
logMap.insert(std::make_pair(logName, logFile)); logMap.insert(std::make_pair(logName, logFile));
} }
} }
} }
inline FILE* GetLogFile(const string &logName) inline FILE* GetLogFile(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return NULL; return NULL;
} }
return (*iter).second; return (*iter).second;
} }
inline void Close(const string &logName) inline void Close(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return; return;
} }
...@@ -95,10 +94,7 @@ public: ...@@ -95,10 +94,7 @@ public:
fclose((*iter).second); fclose((*iter).second);
logMap.erase(iter); logMap.erase(iter);
} }
inline std::mutex &GetLogMutex() inline std::mutex& GetLogMutex() { return logMutex; }
{
return logMutex;
}
// Singleton // Singleton
static LogManager* GetInstance() static LogManager* GetInstance()
...@@ -106,17 +102,18 @@ public: ...@@ -106,17 +102,18 @@ public:
static LogManager logManager; static LogManager logManager;
return &logManager; return &logManager;
} }
private:
private:
std::map<string, FILE*> logMap; std::map<string, FILE*> logMap;
std::mutex logMutex; std::mutex logMutex;
}; };
#ifdef LOG_MUTEX #ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock() #define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock() #define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else #else
#define LOCK #define LOCK
#define UNLOCK #define UNLOCK
#endif #endif
// log time // log time
...@@ -131,53 +128,53 @@ typedef struct _LogTime ...@@ -131,53 +128,53 @@ typedef struct _LogTime
string millisecond; // ms string millisecond; // ms
string microsecond; // us string microsecond; // us
string weekDay; string weekDay;
}LogTime; } LogTime;
inline LogTime GetTime() inline LogTime GetTime()
{ {
LogTime currentTime; LogTime currentTime;
#if (defined WIN32 || defined _WIN32) #if(defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime; SYSTEMTIME systemTime;
GetLocalTime(&systemTime); GetLocalTime(&systemTime);
char temp[8] = { 0 }; char temp[8] = {0};
sprintf(temp, "%04d", systemTime.wYear); sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp, "%02d", systemTime.wMonth); sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp, "%02d", systemTime.wDay); sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp, "%02d", systemTime.wHour); sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp, "%02d", systemTime.wMinute); sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp, "%02d", systemTime.wSecond); sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds); sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek); sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp); currentTime.weekDay = string(temp);
#else #else
struct timeval tv; struct timeval tv;
struct tm *p; struct tm* p;
gettimeofday(&tv, NULL); gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec); p = localtime(&tv.tv_sec);
char temp[8]={0}; char temp[8] = {0};
sprintf(temp,"%04d",1900+p->tm_year); sprintf(temp, "%04d", 1900 + p->tm_year);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp,"%02d",1+p->tm_mon); sprintf(temp, "%02d", 1 + p->tm_mon);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp,"%02d",p->tm_mday); sprintf(temp, "%02d", p->tm_mday);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp,"%02d",p->tm_hour); sprintf(temp, "%02d", p->tm_hour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp,"%02d",p->tm_min); sprintf(temp, "%02d", p->tm_min);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp,"%02d",p->tm_sec); sprintf(temp, "%02d", p->tm_sec);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp,"%03d",(int)(tv.tv_usec/1000)); sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
currentTime.millisecond = string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec % 1000)); sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
currentTime.microsecond = string(temp); currentTime.microsecond = string(temp);
...@@ -188,60 +185,82 @@ inline LogTime GetTime() ...@@ -188,60 +185,82 @@ inline LogTime GetTime()
} }
#define LOG_TIME(logFile) \ #define LOG_TIME(logFile) \
do\ do \
{\ { \
LogTime currentTime=GetTime(); \ LogTime currentTime = GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), "%s-%s-%s %s:%s:%s.%s\t",currentTime.year.c_str(),currentTime.month.c_str(),currentTime.day.c_str(),currentTime.hour.c_str(),currentTime.minute.c_str(),currentTime.second.c_str(),currentTime.millisecond.c_str()); \ fprintf(((logFile == NULL) ? stdout : logFile), \
}while (0) "%s-%s-%s %s:%s:%s.%s\t", \
currentTime.year.c_str(), \
currentTime.month.c_str(), \
#define LOG_INFO(logFile,logInfo, ...) \ currentTime.day.c_str(), \
do\ currentTime.hour.c_str(), \
{\ currentTime.minute.c_str(), \
currentTime.second.c_str(), \
currentTime.millisecond.c_str()); \
} while(0)
#define LOG_INFO(logFile, logInfo, ...) \
do \
{ \
LOCK; \ LOCK; \
LOG_TIME(logFile); \ LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_DEBUG(logFile,logInfo, ...) \ #define LOG_DEBUG(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "DEBUG\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_ERROR(logFile,logInfo, ...) \ #define LOG_ERROR(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "ERROR\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_WARN(logFile,logInfo, ...) \ #define LOG_WARN(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "WARN\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#endif // __SIMPLE_LOG_H__ #endif // __SIMPLE_LOG_H__
...@@ -6,119 +6,125 @@ ...@@ -6,119 +6,125 @@
#include "./tokenization.h" #include "./tokenization.h"
namespace cuBERT
{
namespace cuBERT { void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) { for(int i = 0; i < tokens.size(); ++i)
for (int i = 0; i < tokens.size(); ++i) { {
ids[i] = convert_token_to_id(tokens[i]); ids[i] = convert_token_to_id(tokens[i]);
} }
} }
// trim from start (in place) // trim from start (in place)
static inline void ltrim(std::string &s) { static inline void ltrim(std::string& s)
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { {
return !std::isspace(ch); s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
})); }
}
// trim from end (in place) // trim from end (in place)
static inline void rtrim(std::string &s) { static inline void rtrim(std::string& s)
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { {
return !std::isspace(ch); s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
}).base(), s.end()); s.end());
} }
// trim from both ends (in place) // trim from both ends (in place)
static inline void trim(std::string &s) { static inline void trim(std::string& s)
{
ltrim(s); ltrim(s);
rtrim(s); rtrim(s);
} }
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) { void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
std::ifstream file(vocab_file); std::ifstream file(vocab_file);
if (!file) { if(!file)
{
throw std::invalid_argument("Unable to open vocab file"); throw std::invalid_argument("Unable to open vocab file");
} }
unsigned int index = 0; unsigned int index = 0;
std::string line; std::string line;
while (std::getline(file, line)) { while(std::getline(file, line))
{
trim(line); trim(line);
(*vocab)[line] = index; (*vocab)[line] = index;
index++; index++;
} }
file.close(); file.close();
} }
inline bool _is_whitespace(int c, const char *cat) { inline bool _is_whitespace(int c, const char* cat)
if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { {
if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
{
return true; return true;
} }
return cat[0] == 'Z' && cat[1] == 's'; return cat[0] == 'Z' && cat[1] == 's';
} }
inline bool _is_control(int c, const char *cat) { inline bool _is_control(int c, const char* cat)
{
// These are technically control characters but we count them as whitespace characters. // These are technically control characters but we count them as whitespace characters.
if (c == '\t' || c == '\n' || c == '\r') { if(c == '\t' || c == '\n' || c == '\r')
{
return false; return false;
} }
return 'C' == *cat; return 'C' == *cat;
} }
inline bool _is_punctuation(int cp, const char *cat) { inline bool _is_punctuation(int cp, const char* cat)
// We treat all non-letter/number ASCII as punctuation. {
// Characters such as "^", "$", and "`" are not in the Unicode // We treat all non-letter/number ASCII as punctuation.
// Punctuation class but we treat them as punctuation anyways, for // Characters such as "^", "$", and "`" are not in the Unicode
// consistency. // Punctuation class but we treat them as punctuation anyways, for
if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || // consistency.
(cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) { if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
(cp >= 123 && cp <= 126))
{
return true; return true;
} }
return 'P' == *cat; return 'P' == *cat;
} }
bool _is_whitespace(int c) {
return _is_whitespace(c, utf8proc_category_string(c));
}
bool _is_control(int c) {
return _is_control(c, utf8proc_category_string(c));
}
bool _is_punctuation(int cp) {
return _is_punctuation(cp, utf8proc_category_string(cp));
}
bool BasicTokenizer::_is_chinese_char(int cp) { bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }
// This defines a "chinese character" as anything in the CJK Unicode block:
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }
//
// Note that the CJK Unicode block is NOT all Japanese and Korean characters, bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }
// despite its name. The modern Korean Hangul alphabet is a different block,
// as is Japanese Hiragana and Katakana. Those alphabets are used to write bool BasicTokenizer::_is_chinese_char(int cp)
// space-separated words, so they are not treated specially and handled {
// like the all of the other languages. // This defines a "chinese character" as anything in the CJK Unicode block:
return (cp >= 0x4E00 && cp <= 0x9FFF) || // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
(cp >= 0x3400 && cp <= 0x4DBF) || //
(cp >= 0x20000 && cp <= 0x2A6DF) || // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
(cp >= 0x2A700 && cp <= 0x2B73F) || // despite its name. The modern Korean Hangul alphabet is a different block,
(cp >= 0x2B740 && cp <= 0x2B81F) || // as is Japanese Hiragana and Katakana. Those alphabets are used to write
(cp >= 0x2B820 && cp <= 0x2CEAF) || // space-separated words, so they are not treated specially and handled
(cp >= 0xF900 && cp <= 0xFAFF) || // like the all of the other languages.
(cp >= 0x2F800 && cp <= 0x2FA1F); return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
} (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
(cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}
void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { void BasicTokenizer::tokenize(const char* text,
// This was added on November 1st, 2018 for the multilingual and Chinese std::vector<std::string>* output_tokens,
// models. This is also applied to the English models now, but it doesn't size_t max_length)
// matter since the English models were not trained on any Chinese data {
// and generally don't have any Chinese data in them (there are Chinese // This was added on November 1st, 2018 for the multilingual and Chinese
// characters in the vocabulary because Wikipedia does have some Chinese // models. This is also applied to the English models now, but it doesn't
// words in the English Wikipedia.). // matter since the English models were not trained on any Chinese data
if (do_lower_case) { // and generally don't have any Chinese data in them (there are Chinese
text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text); // characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if(do_lower_case)
{
text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
} }
size_t word_bytes = std::strlen(text); size_t word_bytes = std::strlen(text);
...@@ -127,39 +133,56 @@ namespace cuBERT { ...@@ -127,39 +133,56 @@ namespace cuBERT {
int cp; int cp;
char dst[4]; char dst[4];
while (word_bytes > 0) { while(word_bytes > 0)
int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp); {
if (len < 0) { int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
if(len < 0)
{
std::cerr << "UTF-8 decode error: " << text << std::endl; std::cerr << "UTF-8 decode error: " << text << std::endl;
break; break;
} }
if (do_lower_case) { if(do_lower_case)
{
cp = utf8proc_tolower(cp); cp = utf8proc_tolower(cp);
} }
const char *cat = utf8proc_category_string(cp); const char* cat = utf8proc_category_string(cp);
if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) { if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
{
// pass // pass
} else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') { }
else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
{
// pass // pass
} else if (_is_whitespace(cp, cat)) { }
else if(_is_whitespace(cp, cat))
{
new_token = true; new_token = true;
} else { }
else
{
size_t dst_len = len; size_t dst_len = len;
const char *dst_ptr = text + subpos; const char* dst_ptr = text + subpos;
if (do_lower_case) { if(do_lower_case)
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst); {
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
dst_ptr = dst; dst_ptr = dst;
} }
if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) { if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
{
output_tokens->emplace_back(dst_ptr, dst_len); output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true; new_token = true;
} else { }
if (new_token) { else
{
if(new_token)
{
output_tokens->emplace_back(dst_ptr, dst_len); output_tokens->emplace_back(dst_ptr, dst_len);
new_token = false; new_token = false;
} else { }
else
{
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len); output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
} }
} }
...@@ -169,33 +192,38 @@ namespace cuBERT { ...@@ -169,33 +192,38 @@ namespace cuBERT {
subpos = subpos + len; subpos = subpos + len;
// early terminate // early terminate
if (output_tokens->size() >= max_length) { if(output_tokens->size() >= max_length)
{
break; break;
} }
} }
if (do_lower_case) { if(do_lower_case)
free((void *) text); {
} free((void*)text);
} }
}
void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) { {
if (token.size() > max_input_chars_per_word) { // FIXME: slightly different if(token.size() > max_input_chars_per_word)
{ // FIXME: slightly different
output_tokens->push_back(unk_token); output_tokens->push_back(unk_token);
return; return;
} }
size_t output_tokens_len = output_tokens->size(); size_t output_tokens_len = output_tokens->size();
for (size_t start = 0; start < token.size();) { for(size_t start = 0; start < token.size();)
{
bool is_bad = true; bool is_bad = true;
// TODO: can be optimized by prefix-tree // TODO: can be optimized by prefix-tree
for (size_t end = token.size(); start < end; --end) { // FIXME: slightly different for(size_t end = token.size(); start < end; --end)
std::string substr = start > 0 { // FIXME: slightly different
? "##" + token.substr(start, end - start) std::string substr = start > 0 ? "##" + token.substr(start, end - start)
: token.substr(start, end - start); : token.substr(start, end - start);
if (vocab->count(substr)) { if(vocab->count(substr))
{
is_bad = false; is_bad = false;
output_tokens->push_back(substr); output_tokens->push_back(substr);
start = end; start = end;
...@@ -203,27 +231,32 @@ namespace cuBERT { ...@@ -203,27 +231,32 @@ namespace cuBERT {
} }
} }
if (is_bad) { if(is_bad)
{
output_tokens->resize(output_tokens_len); output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token); output_tokens->push_back(unk_token);
return; return;
} }
} }
} }
void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { void FullTokenizer::tokenize(const char* text,
std::vector<std::string>* output_tokens,
size_t max_length)
{
std::vector<std::string> tokens; std::vector<std::string> tokens;
tokens.reserve(max_length); tokens.reserve(max_length);
basic_tokenizer->tokenize(text, &tokens, max_length); basic_tokenizer->tokenize(text, &tokens, max_length);
for (const auto &token : tokens) { for(const auto& token : tokens)
{
wordpiece_tokenizer->tokenize(token, output_tokens); wordpiece_tokenizer->tokenize(token, output_tokens);
// early terminate // early terminate
if (output_tokens->size() >= max_length) { if(output_tokens->size() >= max_length)
{
break; break;
} }
} }
}
} }
} // namespace cuBERT
...@@ -6,35 +6,37 @@ ...@@ -6,35 +6,37 @@
#include <unordered_map> #include <unordered_map>
#include <iostream> #include <iostream>
namespace cuBERT { namespace cuBERT
{
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab); void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
/** /**
* Checks whether `chars` is a whitespace character. * Checks whether `chars` is a whitespace character.
* @param c * @param c
* @return * @return
*/ */
bool _is_whitespace(int c); bool _is_whitespace(int c);
/** /**
* Checks whether `chars` is a control character. * Checks whether `chars` is a control character.
* @param c * @param c
* @return * @return
*/ */
bool _is_control(int c); bool _is_control(int c);
/** /**
* Checks whether `chars` is a punctuation character. * Checks whether `chars` is a punctuation character.
* @param cp * @param cp
* @return * @return
*/ */
bool _is_punctuation(int cp); bool _is_punctuation(int cp);
/** /**
* Runs basic tokenization (punctuation splitting, lower casing, etc.). * Runs basic tokenization (punctuation splitting, lower casing, etc.).
*/ */
class BasicTokenizer { class BasicTokenizer
{
public: public:
/** /**
* Constructs a BasicTokenizer. * Constructs a BasicTokenizer.
...@@ -42,7 +44,7 @@ namespace cuBERT { ...@@ -42,7 +44,7 @@ namespace cuBERT {
*/ */
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {} explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer &other) = delete; BasicTokenizer(const BasicTokenizer& other) = delete;
virtual ~BasicTokenizer() = default; virtual ~BasicTokenizer() = default;
...@@ -59,7 +61,7 @@ namespace cuBERT { ...@@ -59,7 +61,7 @@ namespace cuBERT {
* @param text * @param text
* @param output_tokens * @param output_tokens
*/ */
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
private: private:
const bool do_lower_case; const bool do_lower_case;
...@@ -70,20 +72,22 @@ namespace cuBERT { ...@@ -70,20 +72,22 @@ namespace cuBERT {
* @return * @return
*/ */
inline static bool _is_chinese_char(int cp); inline static bool _is_chinese_char(int cp);
}; };
/** /**
* Runs WordPiece tokenziation. * Runs WordPiece tokenziation.
*/ */
class WordpieceTokenizer { class WordpieceTokenizer
{
public: public:
explicit WordpieceTokenizer( explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
std::unordered_map<std::string, uint64_t> *vocab,
std::string unk_token = "[UNK]", std::string unk_token = "[UNK]",
int max_input_chars_per_word = 200 int max_input_chars_per_word = 200)
) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {} : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
{
}
WordpieceTokenizer(const WordpieceTokenizer &other) = delete; WordpieceTokenizer(const WordpieceTokenizer& other) = delete;
virtual ~WordpieceTokenizer() = default; virtual ~WordpieceTokenizer() = default;
...@@ -97,67 +101,77 @@ namespace cuBERT { ...@@ -97,67 +101,77 @@ namespace cuBERT {
* input = "unaffable" * input = "unaffable"
* output = ["un", "##aff", "##able"] * output = ["un", "##aff", "##able"]
* *
* @param text A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer. * @param text A single token or whitespace separated tokens. This should have already been
* passed through `BasicTokenizer.
* @param output_tokens A list of wordpiece tokens. * @param output_tokens A list of wordpiece tokens.
*/ */
void tokenize(const std::string &text, std::vector<std::string> *output_tokens); void tokenize(const std::string& text, std::vector<std::string>* output_tokens);
private: private:
const std::unordered_map<std::string, uint64_t> *vocab; const std::unordered_map<std::string, uint64_t>* vocab;
const std::string unk_token; const std::string unk_token;
const int max_input_chars_per_word; const int max_input_chars_per_word;
}; };
/** /**
* Runs end-to-end tokenziation. * Runs end-to-end tokenziation.
*/ */
class FullTokenizer { class FullTokenizer
{
public: public:
FullTokenizer(const char *vocab_file, bool do_lower_case = true) { FullTokenizer(const char* vocab_file, bool do_lower_case = true)
{
vocab = new std::unordered_map<std::string, uint64_t>(); vocab = new std::unordered_map<std::string, uint64_t>();
load_vocab(vocab_file, vocab); load_vocab(vocab_file, vocab);
basic_tokenizer = new BasicTokenizer(do_lower_case); basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab); wordpiece_tokenizer = new WordpieceTokenizer(vocab);
} }
~FullTokenizer() { ~FullTokenizer()
if (wordpiece_tokenizer != NULL){ {
if(wordpiece_tokenizer != NULL)
{
wordpiece_tokenizer = NULL; wordpiece_tokenizer = NULL;
} }
delete wordpiece_tokenizer; delete wordpiece_tokenizer;
if (basic_tokenizer != NULL){ if(basic_tokenizer != NULL)
{
basic_tokenizer = NULL; basic_tokenizer = NULL;
} }
delete basic_tokenizer; delete basic_tokenizer;
if (vocab != NULL){ if(vocab != NULL)
{
vocab = NULL; vocab = NULL;
} }
delete vocab; delete vocab;
} }
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
inline uint64_t convert_token_to_id(const std::string &token) { inline uint64_t convert_token_to_id(const std::string& token)
{
auto item = vocab->find(token); auto item = vocab->find(token);
if (item == vocab->end()) { if(item == vocab->end())
{
std::cerr << "vocab missing key: " << token << std::endl; std::cerr << "vocab missing key: " << token << std::endl;
return 0; return 0;
} else { }
else
{
return item->second; return item->second;
} }
} }
void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids); void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);
private: private:
std::unordered_map<std::string, uint64_t> *vocab; std::unordered_map<std::string, uint64_t>* vocab;
BasicTokenizer *basic_tokenizer; BasicTokenizer* basic_tokenizer;
WordpieceTokenizer *wordpiece_tokenizer; WordpieceTokenizer* wordpiece_tokenizer;
}; };
} } // namespace cuBERT
#endif //CUBERT_TOKENIZATION_H #endif // CUBERT_TOKENIZATION_H
/* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */ /* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P.
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Jones, and other contributors. Copyright (c) 2009 Public Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
* Please notice the copyright statement in the file "utf8proc_data.c". * Please notice the copyright statement in the file "utf8proc_data.c".
*/ */
/* /*
* File name: utf8proc.c * File name: utf8proc.c
* *
...@@ -40,36 +39,26 @@ ...@@ -40,36 +39,26 @@
* Implementation of libutf8proc. * Implementation of libutf8proc.
*/ */
#include "utf8proc.h" #include "utf8proc.h"
#ifndef SSIZE_MAX #ifndef SSIZE_MAX
#define SSIZE_MAX ((size_t)SIZE_MAX/2) #define SSIZE_MAX ((size_t)SIZE_MAX / 2)
#endif #endif
#ifndef UINT16_MAX #ifndef UINT16_MAX
# define UINT16_MAX 65535U #define UINT16_MAX 65535U
#endif #endif
#include "utf8proc_data.c" #include "utf8proc_data.c"
UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = { UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0};
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0 };
#define UTF8PROC_HANGUL_SBASE 0xAC00 #define UTF8PROC_HANGUL_SBASE 0xAC00
#define UTF8PROC_HANGUL_LBASE 0x1100 #define UTF8PROC_HANGUL_LBASE 0x1100
...@@ -96,150 +85,182 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = { ...@@ -96,150 +85,182 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
be different, being based on ABI compatibility.): */ be different, being based on ABI compatibility.): */
#define STRINGIZEx(x) #x #define STRINGIZEx(x) #x
#define STRINGIZE(x) STRINGIZEx(x) #define STRINGIZE(x) STRINGIZEx(x)
UTF8PROC_DLLEXPORT const char *utf8proc_version(void) { UTF8PROC_DLLEXPORT const char* utf8proc_version(void)
{
return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) ""; return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) "";
} }
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void) { UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void) { return "15.0.0"; }
return "15.0.0";
}
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode) { UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode)
switch (errcode) { {
case UTF8PROC_ERROR_NOMEM: switch(errcode)
return "Memory for processing UTF-8 data could not be allocated."; {
case UTF8PROC_ERROR_OVERFLOW: case UTF8PROC_ERROR_NOMEM: return "Memory for processing UTF-8 data could not be allocated.";
return "UTF-8 string is too long to be processed."; case UTF8PROC_ERROR_OVERFLOW: return "UTF-8 string is too long to be processed.";
case UTF8PROC_ERROR_INVALIDUTF8: case UTF8PROC_ERROR_INVALIDUTF8: return "Invalid UTF-8 string";
return "Invalid UTF-8 string"; case UTF8PROC_ERROR_NOTASSIGNED: return "Unassigned Unicode code point found in UTF-8 string.";
case UTF8PROC_ERROR_NOTASSIGNED: case UTF8PROC_ERROR_INVALIDOPTS: return "Invalid options for UTF-8 processing chosen.";
return "Unassigned Unicode code point found in UTF-8 string."; default: return "An unknown error occurred while processing UTF-8 data.";
case UTF8PROC_ERROR_INVALIDOPTS:
return "Invalid options for UTF-8 processing chosen.";
default:
return "An unknown error occurred while processing UTF-8 data.";
} }
} }
#define utf_cont(ch) (((ch) & 0xc0) == 0x80) #define utf_cont(ch) (((ch) & 0xc0) == 0x80)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *dst utf8proc_ssize_t strlen,
) { utf8proc_int32_t* dst)
{
utf8proc_int32_t uc; utf8proc_int32_t uc;
const utf8proc_uint8_t *end; const utf8proc_uint8_t* end;
*dst = -1; *dst = -1;
if (!strlen) return 0; if(!strlen)
return 0;
end = str + ((strlen < 0) ? 4 : strlen); end = str + ((strlen < 0) ? 4 : strlen);
uc = *str++; uc = *str++;
if (uc < 0x80) { if(uc < 0x80)
{
*dst = uc; *dst = uc;
return 1; return 1;
} }
// Must be between 0xc2 and 0xf4 inclusive to be valid // Must be between 0xc2 and 0xf4 inclusive to be valid
if ((utf8proc_uint32_t)(uc - 0xc2) > (0xf4-0xc2)) return UTF8PROC_ERROR_INVALIDUTF8; if((utf8proc_uint32_t)(uc - 0xc2) > (0xf4 - 0xc2))
if (uc < 0xe0) { // 2-byte sequence return UTF8PROC_ERROR_INVALIDUTF8;
if(uc < 0xe0)
{ // 2-byte sequence
// Must have valid continuation character // Must have valid continuation character
if (str >= end || !utf_cont(*str)) return UTF8PROC_ERROR_INVALIDUTF8; if(str >= end || !utf_cont(*str))
*dst = ((uc & 0x1f)<<6) | (*str & 0x3f); return UTF8PROC_ERROR_INVALIDUTF8;
*dst = ((uc & 0x1f) << 6) | (*str & 0x3f);
return 2; return 2;
} }
if (uc < 0xf0) { // 3-byte sequence if(uc < 0xf0)
if ((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1])) { // 3-byte sequence
if((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1]))
return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
// Check for surrogate chars // Check for surrogate chars
if (uc == 0xed && *str > 0x9f) if(uc == 0xed && *str > 0x9f)
return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
uc = ((uc & 0xf)<<12) | ((*str & 0x3f)<<6) | (str[1] & 0x3f); uc = ((uc & 0xf) << 12) | ((*str & 0x3f) << 6) | (str[1] & 0x3f);
if (uc < 0x800) if(uc < 0x800)
return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
*dst = uc; *dst = uc;
return 3; return 3;
} }
// 4-byte sequence // 4-byte sequence
// Must have 3 valid continuation characters // Must have 3 valid continuation characters
if ((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2])) if((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2]))
return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
// Make sure in correct range (0x10000 - 0x10ffff) // Make sure in correct range (0x10000 - 0x10ffff)
if (uc == 0xf0) { if(uc == 0xf0)
if (*str < 0x90) return UTF8PROC_ERROR_INVALIDUTF8; {
} else if (uc == 0xf4) { if(*str < 0x90)
if (*str > 0x8f) return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
}
else if(uc == 0xf4)
{
if(*str > 0x8f)
return UTF8PROC_ERROR_INVALIDUTF8;
} }
*dst = ((uc & 7)<<18) | ((*str & 0x3f)<<12) | ((str[1] & 0x3f)<<6) | (str[2] & 0x3f); *dst = ((uc & 7) << 18) | ((*str & 0x3f) << 12) | ((str[1] & 0x3f) << 6) | (str[2] & 0x3f);
return 4; return 4;
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc) { UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc)
return (((utf8proc_uint32_t)uc)-0xd800 > 0x07ff) && ((utf8proc_uint32_t)uc < 0x110000); {
return (((utf8proc_uint32_t)uc) - 0xd800 > 0x07ff) && ((utf8proc_uint32_t)uc < 0x110000);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
if (uc < 0x00) { {
if(uc < 0x00)
{
return 0; return 0;
} else if (uc < 0x80) { }
dst[0] = (utf8proc_uint8_t) uc; else if(uc < 0x80)
{
dst[0] = (utf8proc_uint8_t)uc;
return 1; return 1;
} else if (uc < 0x800) { }
else if(uc < 0x800)
{
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6)); dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 2; return 2;
// Note: we allow encoding 0xd800-0xdfff here, so as not to change // Note: we allow encoding 0xd800-0xdfff here, so as not to change
// the API, however, these are actually invalid in UTF-8 // the API, however, these are actually invalid in UTF-8
} else if (uc < 0x10000) { }
else if(uc < 0x10000)
{
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12)); dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3; return 3;
} else if (uc < 0x110000) { }
else if(uc < 0x110000)
{
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18)); dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4; return 4;
} else return 0; }
else
return 0;
} }
/* internal version used for inserting 0xff bytes between graphemes */ /* internal version used for inserting 0xff bytes between graphemes */
static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) { static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
if (uc < 0x00) { {
if (uc == -1) { /* internal value used for grapheme breaks */ if(uc < 0x00)
{
if(uc == -1)
{ /* internal value used for grapheme breaks */
dst[0] = (utf8proc_uint8_t)0xFF; dst[0] = (utf8proc_uint8_t)0xFF;
return 1; return 1;
} }
return 0; return 0;
} else if (uc < 0x80) { }
else if(uc < 0x80)
{
dst[0] = (utf8proc_uint8_t)uc; dst[0] = (utf8proc_uint8_t)uc;
return 1; return 1;
} else if (uc < 0x800) { }
else if(uc < 0x800)
{
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6)); dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 2; return 2;
} else if (uc < 0x10000) { }
else if(uc < 0x10000)
{
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12)); dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3; return 3;
} else if (uc < 0x110000) { }
else if(uc < 0x110000)
{
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18)); dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4; return 4;
} else return 0; }
else
return 0;
} }
/* internal "unsafe" version that does not check whether uc is in range */ /* internal "unsafe" version that does not check whether uc is in range */
static const utf8proc_property_t *unsafe_get_property(utf8proc_int32_t uc) { static const utf8proc_property_t* unsafe_get_property(utf8proc_int32_t uc)
{
/* ASSERT: uc >= 0 && uc < 0x110000 */ /* ASSERT: uc >= 0 && uc < 0x110000 */
return utf8proc_properties + ( return utf8proc_properties +
utf8proc_stage2table[ (utf8proc_stage2table[utf8proc_stage1table[uc >> 8] + (uc & 0xFF)]);
utf8proc_stage1table[uc >> 8] + (uc & 0xFF)
]
);
} }
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t uc) { UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t uc)
{
return uc < 0 || uc >= 0x110000 ? utf8proc_properties : unsafe_get_property(uc); return uc < 0 || uc >= 0x110000 ? utf8proc_properties : unsafe_get_property(uc);
} }
...@@ -258,41 +279,59 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -258,41 +279,59 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
See the special support in grapheme_break_extended, for required bookkeeping by the caller. See the special support in grapheme_break_extended, for required bookkeeping by the caller.
*/ */
static utf8proc_bool grapheme_break_simple(int lbc, int tbc) { static utf8proc_bool grapheme_break_simple(int lbc, int tbc)
return {
(lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1 return (lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1
(lbc == UTF8PROC_BOUNDCLASS_CR && // GB3 (lbc == UTF8PROC_BOUNDCLASS_CR && // GB3
tbc == UTF8PROC_BOUNDCLASS_LF) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LF)
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB4 ? false
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB5 : // ---
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
: // GB4
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
: // GB5
(lbc == UTF8PROC_BOUNDCLASS_L && // GB6 (lbc == UTF8PROC_BOUNDCLASS_L && // GB6
(tbc == UTF8PROC_BOUNDCLASS_L || // --- (tbc == UTF8PROC_BOUNDCLASS_L || // ---
tbc == UTF8PROC_BOUNDCLASS_V || // --- tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_LV || // --- tbc == UTF8PROC_BOUNDCLASS_LV || // ---
tbc == UTF8PROC_BOUNDCLASS_LVT)) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LVT))
? false
: // ---
((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7 ((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7
lbc == UTF8PROC_BOUNDCLASS_V) && // --- lbc == UTF8PROC_BOUNDCLASS_V) && // ---
(tbc == UTF8PROC_BOUNDCLASS_V || // --- (tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_T)) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_T))
? false
: // ---
((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8 ((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8
lbc == UTF8PROC_BOUNDCLASS_T) && // --- lbc == UTF8PROC_BOUNDCLASS_T) && // ---
tbc == UTF8PROC_BOUNDCLASS_T) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_T)
? false
: // ---
(tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9 (tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9
tbc == UTF8PROC_BOUNDCLASS_ZWJ || // --- tbc == UTF8PROC_BOUNDCLASS_ZWJ || // ---
tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a
lbc == UTF8PROC_BOUNDCLASS_PREPEND) ? false : // GB9b lbc == UTF8PROC_BOUNDCLASS_PREPEND)
? false
: // GB9b
(lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional handling below) (lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional handling below)
tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) ? false : // ---- tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13 (requires additional handling below) ? false
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) ? false : // ---- : // ----
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13 (requires additional
// handling below)
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
? false
: // ----
true; // GB999 true; // GB999
} }
static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t *state) static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t* state)
{ {
if (state) { if(state)
{
int lbc_override; int lbc_override;
if (*state == UTF8PROC_BOUNDCLASS_START) if(*state == UTF8PROC_BOUNDCLASS_START)
*state = lbc_override = lbc; *state = lbc_override = lbc;
else else
lbc_override = *state; lbc_override = *state;
...@@ -303,13 +342,14 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t ...@@ -303,13 +342,14 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t
// second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break // second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break
// after that character according to GB999 (unless of course such a break is // after that character according to GB999 (unless of course such a break is
// forbidden by a different rule such as GB9). // forbidden by a different rule such as GB9).
if (*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) if(*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
*state = UTF8PROC_BOUNDCLASS_OTHER; *state = UTF8PROC_BOUNDCLASS_OTHER;
// Special support for GB11 (emoji extend* zwj / emoji) // Special support for GB11 (emoji extend* zwj / emoji)
else if (*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) { else if(*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
if (tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji {
if(tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji
*state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC; *state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC;
else if (tbc == UTF8PROC_BOUNDCLASS_ZWJ) else if(tbc == UTF8PROC_BOUNDCLASS_ZWJ)
*state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo *state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo
else else
*state = tbc; *state = tbc;
...@@ -323,24 +363,25 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t ...@@ -323,24 +363,25 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t
return grapheme_break_simple(lbc, tbc); return grapheme_break_simple(lbc, tbc);
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t c1,
utf8proc_int32_t c1, utf8proc_int32_t c2, utf8proc_int32_t *state) { utf8proc_int32_t c2,
utf8proc_int32_t* state)
{
return grapheme_break_extended(utf8proc_get_property(c1)->boundclass, return grapheme_break_extended(
utf8proc_get_property(c2)->boundclass, utf8proc_get_property(c1)->boundclass, utf8proc_get_property(c2)->boundclass, state);
state);
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t c1, utf8proc_int32_t c2)
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break( {
utf8proc_int32_t c1, utf8proc_int32_t c2) {
return utf8proc_grapheme_break_stateful(c1, c2, NULL); return utf8proc_grapheme_break_stateful(c1, c2, NULL);
} }
static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry) static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t** entry)
{ {
utf8proc_int32_t entry_cp = **entry; utf8proc_int32_t entry_cp = **entry;
if ((entry_cp & 0xF800) == 0xD800) { if((entry_cp & 0xF800) == 0xD800)
{
*entry = *entry + 1; *entry = *entry + 1;
entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF); entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF);
entry_cp += 0x10000; entry_cp += 0x10000;
...@@ -350,25 +391,35 @@ static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry) ...@@ -350,25 +391,35 @@ static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry)
static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex) static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex)
{ {
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex]; const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex];
return seqindex_decode_entry(&entry); return seqindex_decode_entry(&entry);
} }
static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) { static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex,
utf8proc_int32_t* dst,
utf8proc_ssize_t bufsize,
utf8proc_option_t options,
int* last_boundclass)
{
utf8proc_ssize_t written = 0; utf8proc_ssize_t written = 0;
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex & 0x3FFF]; const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex & 0x3FFF];
int len = seqindex >> 14; int len = seqindex >> 14;
if (len >= 3) { if(len >= 3)
{
len = *entry; len = *entry;
entry++; entry++;
} }
for (; len >= 0; entry++, len--) { for(; len >= 0; entry++, len--)
{
utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry); utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry);
written += utf8proc_decompose_char(entry_cp, dst+written, written += utf8proc_decompose_char(entry_cp,
(bufsize > written) ? (bufsize - written) : 0, options, dst + written,
(bufsize > written) ? (bufsize - written) : 0,
options,
last_boundclass); last_boundclass);
if (written < 0) return UTF8PROC_ERROR_OVERFLOW; if(written < 0)
return UTF8PROC_ERROR_OVERFLOW;
} }
return written; return written;
} }
...@@ -393,190 +444,254 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c) ...@@ -393,190 +444,254 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c)
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c) UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c)
{ {
const utf8proc_property_t *p = utf8proc_get_property(c); const utf8proc_property_t* p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->lowercase_seqindex == UINT16_MAX; return p->lowercase_seqindex != p->uppercase_seqindex && p->lowercase_seqindex == UINT16_MAX;
} }
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c) UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c)
{ {
const utf8proc_property_t *p = utf8proc_get_property(c); const utf8proc_property_t* p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->uppercase_seqindex == UINT16_MAX && p->category != UTF8PROC_CATEGORY_LT; return p->lowercase_seqindex != p->uppercase_seqindex && p->uppercase_seqindex == UINT16_MAX &&
p->category != UTF8PROC_CATEGORY_LT;
} }
/* return a character width analogous to wcwidth (except portable and /* return a character width analogous to wcwidth (except portable and
hopefully less buggy than most system wcwidth functions). */ hopefully less buggy than most system wcwidth functions). */
UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c)
{
return utf8proc_get_property(c)->charwidth; return utf8proc_get_property(c)->charwidth;
} }
UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c)
return (utf8proc_category_t) utf8proc_get_property(c)->category; {
return (utf8proc_category_t)utf8proc_get_property(c)->category;
} }
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t c)
static const char s[][3] = {"Cn","Lu","Ll","Lt","Lm","Lo","Mn","Mc","Me","Nd","Nl","No","Pc","Pd","Ps","Pe","Pi","Pf","Po","Sm","Sc","Sk","So","Zs","Zl","Zp","Cc","Cf","Cs","Co"}; {
static const char s[][3] = {"Cn", "Lu", "Ll", "Lt", "Lm", "Lo", "Mn", "Mc", "Me", "Nd",
"Nl", "No", "Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po", "Sm",
"Sc", "Sk", "So", "Zs", "Zl", "Zp", "Cc", "Cf", "Cs", "Co"};
return s[utf8proc_category(c)]; return s[utf8proc_category(c)];
} }
#define utf8proc_decompose_lump(replacement_uc) \ #define utf8proc_decompose_lump(replacement_uc) \
return utf8proc_decompose_char((replacement_uc), dst, bufsize, \ return utf8proc_decompose_char( \
options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass) (replacement_uc), dst, bufsize, options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc,
const utf8proc_property_t *property; utf8proc_int32_t* dst,
utf8proc_ssize_t bufsize,
utf8proc_option_t options,
int* last_boundclass)
{
const utf8proc_property_t* property;
utf8proc_propval_t category; utf8proc_propval_t category;
utf8proc_int32_t hangul_sindex; utf8proc_int32_t hangul_sindex;
if (uc < 0 || uc >= 0x110000) return UTF8PROC_ERROR_NOTASSIGNED; if(uc < 0 || uc >= 0x110000)
return UTF8PROC_ERROR_NOTASSIGNED;
property = unsafe_get_property(uc); property = unsafe_get_property(uc);
category = property->category; category = property->category;
hangul_sindex = uc - UTF8PROC_HANGUL_SBASE; hangul_sindex = uc - UTF8PROC_HANGUL_SBASE;
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) { if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT) { {
if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT)
{
utf8proc_int32_t hangul_tindex; utf8proc_int32_t hangul_tindex;
if (bufsize >= 1) { if(bufsize >= 1)
dst[0] = UTF8PROC_HANGUL_LBASE + {
hangul_sindex / UTF8PROC_HANGUL_NCOUNT; dst[0] = UTF8PROC_HANGUL_LBASE + hangul_sindex / UTF8PROC_HANGUL_NCOUNT;
if (bufsize >= 2) dst[1] = UTF8PROC_HANGUL_VBASE + if(bufsize >= 2)
dst[1] = UTF8PROC_HANGUL_VBASE +
(hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT; (hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT;
} }
hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT; hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT;
if (!hangul_tindex) return 2; if(!hangul_tindex)
if (bufsize >= 3) dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex; return 2;
if(bufsize >= 3)
dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex;
return 3; return 3;
} }
} }
if (options & UTF8PROC_REJECTNA) { if(options & UTF8PROC_REJECTNA)
if (!category) return UTF8PROC_ERROR_NOTASSIGNED; {
if(!category)
return UTF8PROC_ERROR_NOTASSIGNED;
} }
if (options & UTF8PROC_IGNORE) { if(options & UTF8PROC_IGNORE)
if (property->ignorable) return 0; {
if(property->ignorable)
return 0;
} }
if (options & UTF8PROC_STRIPNA) { if(options & UTF8PROC_STRIPNA)
if (!category) return 0; {
if(!category)
return 0;
} }
if (options & UTF8PROC_LUMP) { if(options & UTF8PROC_LUMP)
if (category == UTF8PROC_CATEGORY_ZS) utf8proc_decompose_lump(0x0020); {
if (uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8) if(category == UTF8PROC_CATEGORY_ZS)
utf8proc_decompose_lump(0x0020);
if(uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8)
utf8proc_decompose_lump(0x0027); utf8proc_decompose_lump(0x0027);
if (category == UTF8PROC_CATEGORY_PD || uc == 0x2212) if(category == UTF8PROC_CATEGORY_PD || uc == 0x2212)
utf8proc_decompose_lump(0x002D); utf8proc_decompose_lump(0x002D);
if (uc == 0x2044 || uc == 0x2215) utf8proc_decompose_lump(0x002F); if(uc == 0x2044 || uc == 0x2215)
if (uc == 0x2236) utf8proc_decompose_lump(0x003A); utf8proc_decompose_lump(0x002F);
if (uc == 0x2039 || uc == 0x2329 || uc == 0x3008) if(uc == 0x2236)
utf8proc_decompose_lump(0x003A);
if(uc == 0x2039 || uc == 0x2329 || uc == 0x3008)
utf8proc_decompose_lump(0x003C); utf8proc_decompose_lump(0x003C);
if (uc == 0x203A || uc == 0x232A || uc == 0x3009) if(uc == 0x203A || uc == 0x232A || uc == 0x3009)
utf8proc_decompose_lump(0x003E); utf8proc_decompose_lump(0x003E);
if (uc == 0x2216) utf8proc_decompose_lump(0x005C); if(uc == 0x2216)
if (uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303) utf8proc_decompose_lump(0x005C);
if(uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303)
utf8proc_decompose_lump(0x005E); utf8proc_decompose_lump(0x005E);
if (category == UTF8PROC_CATEGORY_PC || uc == 0x02CD) if(category == UTF8PROC_CATEGORY_PC || uc == 0x02CD)
utf8proc_decompose_lump(0x005F); utf8proc_decompose_lump(0x005F);
if (uc == 0x02CB) utf8proc_decompose_lump(0x0060); if(uc == 0x02CB)
if (uc == 0x2223) utf8proc_decompose_lump(0x007C); utf8proc_decompose_lump(0x0060);
if (uc == 0x223C) utf8proc_decompose_lump(0x007E); if(uc == 0x2223)
if ((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS)) { utf8proc_decompose_lump(0x007C);
if (category == UTF8PROC_CATEGORY_ZL || if(uc == 0x223C)
category == UTF8PROC_CATEGORY_ZP) utf8proc_decompose_lump(0x007E);
if((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS))
{
if(category == UTF8PROC_CATEGORY_ZL || category == UTF8PROC_CATEGORY_ZP)
utf8proc_decompose_lump(0x000A); utf8proc_decompose_lump(0x000A);
} }
} }
if (options & UTF8PROC_STRIPMARK) { if(options & UTF8PROC_STRIPMARK)
if (category == UTF8PROC_CATEGORY_MN || {
category == UTF8PROC_CATEGORY_MC || if(category == UTF8PROC_CATEGORY_MN || category == UTF8PROC_CATEGORY_MC ||
category == UTF8PROC_CATEGORY_ME) return 0; category == UTF8PROC_CATEGORY_ME)
return 0;
} }
if (options & UTF8PROC_CASEFOLD) { if(options & UTF8PROC_CASEFOLD)
if (property->casefold_seqindex != UINT16_MAX) { {
return seqindex_write_char_decomposed(property->casefold_seqindex, dst, bufsize, options, last_boundclass); if(property->casefold_seqindex != UINT16_MAX)
{
return seqindex_write_char_decomposed(
property->casefold_seqindex, dst, bufsize, options, last_boundclass);
} }
} }
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) { if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
if (property->decomp_seqindex != UINT16_MAX && {
(!property->decomp_type || (options & UTF8PROC_COMPAT))) { if(property->decomp_seqindex != UINT16_MAX &&
return seqindex_write_char_decomposed(property->decomp_seqindex, dst, bufsize, options, last_boundclass); (!property->decomp_type || (options & UTF8PROC_COMPAT)))
{
return seqindex_write_char_decomposed(
property->decomp_seqindex, dst, bufsize, options, last_boundclass);
} }
} }
if (options & UTF8PROC_CHARBOUND) { if(options & UTF8PROC_CHARBOUND)
{
utf8proc_bool boundary; utf8proc_bool boundary;
int tbc = property->boundclass; int tbc = property->boundclass;
boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass); boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass);
if (boundary) { if(boundary)
if (bufsize >= 1) dst[0] = -1; /* sentinel value for grapheme break */ {
if (bufsize >= 2) dst[1] = uc; if(bufsize >= 1)
dst[0] = -1; /* sentinel value for grapheme break */
if(bufsize >= 2)
dst[1] = uc;
return 2; return 2;
} }
} }
if (bufsize >= 1) *dst = uc; if(bufsize >= 1)
*dst = uc;
return 1; return 1;
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options utf8proc_int32_t* buffer,
) { utf8proc_ssize_t bufsize,
utf8proc_option_t options)
{
return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, NULL, NULL); return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, NULL, NULL);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_int32_t* buffer,
utf8proc_custom_func custom_func, void *custom_data utf8proc_ssize_t bufsize,
) { utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data)
{
/* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */ /* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
if ((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE)) if((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS; return UTF8PROC_ERROR_INVALIDOPTS;
if ((options & UTF8PROC_STRIPMARK) && if((options & UTF8PROC_STRIPMARK) && !(options & UTF8PROC_COMPOSE) &&
!(options & UTF8PROC_COMPOSE) && !(options & UTF8PROC_DECOMPOSE)) !(options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS; return UTF8PROC_ERROR_INVALIDOPTS;
{ {
utf8proc_int32_t uc; utf8proc_int32_t uc;
utf8proc_ssize_t rpos = 0; utf8proc_ssize_t rpos = 0;
utf8proc_ssize_t decomp_result; utf8proc_ssize_t decomp_result;
int boundclass = UTF8PROC_BOUNDCLASS_START; int boundclass = UTF8PROC_BOUNDCLASS_START;
while (1) { while(1)
if (options & UTF8PROC_NULLTERM) { {
if(options & UTF8PROC_NULLTERM)
{
rpos += utf8proc_iterate(str + rpos, -1, &uc); rpos += utf8proc_iterate(str + rpos, -1, &uc);
/* checking of return value is not necessary, /* checking of return value is not necessary,
as 'uc' is < 0 in case of error */ as 'uc' is < 0 in case of error */
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8; if(uc < 0)
if (rpos < 0) return UTF8PROC_ERROR_OVERFLOW; return UTF8PROC_ERROR_INVALIDUTF8;
if (uc == 0) break; if(rpos < 0)
} else { return UTF8PROC_ERROR_OVERFLOW;
if (rpos >= strlen) break; if(uc == 0)
break;
}
else
{
if(rpos >= strlen)
break;
rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc); rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc);
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8; if(uc < 0)
return UTF8PROC_ERROR_INVALIDUTF8;
} }
if (custom_func != NULL) { if(custom_func != NULL)
{
uc = custom_func(uc, custom_data); /* user-specified custom mapping */ uc = custom_func(uc, custom_data); /* user-specified custom mapping */
} }
decomp_result = utf8proc_decompose_char( decomp_result = utf8proc_decompose_char(
uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options, uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options, &boundclass);
&boundclass if(decomp_result < 0)
); return decomp_result;
if (decomp_result < 0) return decomp_result;
wpos += decomp_result; wpos += decomp_result;
/* prohibiting integer overflows due to too long strings: */ /* prohibiting integer overflows due to too long strings: */
if (wpos < 0 || if(wpos < 0 || wpos > (utf8proc_ssize_t)(SSIZE_MAX / sizeof(utf8proc_int32_t) / 2))
wpos > (utf8proc_ssize_t)(SSIZE_MAX/sizeof(utf8proc_int32_t)/2))
return UTF8PROC_ERROR_OVERFLOW; return UTF8PROC_ERROR_OVERFLOW;
} }
} }
if ((options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) && bufsize >= wpos) { if((options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE)) && bufsize >= wpos)
{
utf8proc_ssize_t pos = 0; utf8proc_ssize_t pos = 0;
while (pos < wpos-1) { while(pos < wpos - 1)
{
utf8proc_int32_t uc1, uc2; utf8proc_int32_t uc1, uc2;
const utf8proc_property_t *property1, *property2; const utf8proc_property_t *property1, *property2;
uc1 = buffer[pos]; uc1 = buffer[pos];
uc2 = buffer[pos+1]; uc2 = buffer[pos + 1];
property1 = unsafe_get_property(uc1); property1 = unsafe_get_property(uc1);
property2 = unsafe_get_property(uc2); property2 = unsafe_get_property(uc2);
if (property1->combining_class > property2->combining_class && if(property1->combining_class > property2->combining_class &&
property2->combining_class > 0) { property2->combining_class > 0)
{
buffer[pos] = uc2; buffer[pos] = uc2;
buffer[pos+1] = uc1; buffer[pos + 1] = uc1;
if (pos > 0) pos--; else pos++; if(pos > 0)
} else { pos--;
else
pos++;
}
else
{
pos++; pos++;
} }
} }
...@@ -584,59 +699,84 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( ...@@ -584,59 +699,84 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
return wpos; return wpos;
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options)
{
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */ /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */
if (options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC)) { if(options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC))
{
utf8proc_ssize_t rpos; utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
utf8proc_int32_t uc; utf8proc_int32_t uc;
for (rpos = 0; rpos < length; rpos++) { for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
if (uc == 0x000D && rpos < length-1 && buffer[rpos+1] == 0x000A) rpos++; if(uc == 0x000D && rpos < length - 1 && buffer[rpos + 1] == 0x000A)
if (uc == 0x000A || uc == 0x000D || uc == 0x0085 || rpos++;
((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C))) { if(uc == 0x000A || uc == 0x000D || uc == 0x0085 ||
if (options & UTF8PROC_NLF2LS) { ((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C)))
if (options & UTF8PROC_NLF2PS) { {
if(options & UTF8PROC_NLF2LS)
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x000A; buffer[wpos++] = 0x000A;
} else { }
else
{
buffer[wpos++] = 0x2028; buffer[wpos++] = 0x2028;
} }
} else { }
if (options & UTF8PROC_NLF2PS) { else
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x2029; buffer[wpos++] = 0x2029;
} else { }
else
{
buffer[wpos++] = 0x0020; buffer[wpos++] = 0x0020;
} }
} }
} else if ((options & UTF8PROC_STRIPCC) && }
(uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0))) { else if((options & UTF8PROC_STRIPCC) && (uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0)))
if (uc == 0x0009) buffer[wpos++] = 0x0020; {
} else { if(uc == 0x0009)
buffer[wpos++] = 0x0020;
}
else
{
buffer[wpos++] = uc; buffer[wpos++] = uc;
} }
} }
length = wpos; length = wpos;
} }
if (options & UTF8PROC_COMPOSE) { if(options & UTF8PROC_COMPOSE)
utf8proc_int32_t *starter = NULL; {
utf8proc_int32_t* starter = NULL;
utf8proc_int32_t current_char; utf8proc_int32_t current_char;
const utf8proc_property_t *starter_property = NULL, *current_property; const utf8proc_property_t *starter_property = NULL, *current_property;
utf8proc_propval_t max_combining_class = -1; utf8proc_propval_t max_combining_class = -1;
utf8proc_ssize_t rpos; utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
utf8proc_int32_t composition; utf8proc_int32_t composition;
for (rpos = 0; rpos < length; rpos++) { for(rpos = 0; rpos < length; rpos++)
{
current_char = buffer[rpos]; current_char = buffer[rpos];
current_property = unsafe_get_property(current_char); current_property = unsafe_get_property(current_char);
if (starter && current_property->combining_class > max_combining_class) { if(starter && current_property->combining_class > max_combining_class)
{
/* combination perhaps possible */ /* combination perhaps possible */
utf8proc_int32_t hangul_lindex; utf8proc_int32_t hangul_lindex;
utf8proc_int32_t hangul_sindex; utf8proc_int32_t hangul_sindex;
hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE; hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE;
if (hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT) { if(hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT)
{
utf8proc_int32_t hangul_vindex; utf8proc_int32_t hangul_vindex;
hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE; hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE;
if (hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT) { if(hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT)
{
*starter = UTF8PROC_HANGUL_SBASE + *starter = UTF8PROC_HANGUL_SBASE +
(hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) * (hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) *
UTF8PROC_HANGUL_TCOUNT; UTF8PROC_HANGUL_TCOUNT;
...@@ -645,33 +785,42 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -645,33 +785,42 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
} }
} }
hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE; hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE;
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT && if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT &&
(hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0) { (hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0)
{
utf8proc_int32_t hangul_tindex; utf8proc_int32_t hangul_tindex;
hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE; hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE;
if (hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT) { if(hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT)
{
*starter += hangul_tindex; *starter += hangul_tindex;
starter_property = NULL; starter_property = NULL;
continue; continue;
} }
} }
if (!starter_property) { if(!starter_property)
{
starter_property = unsafe_get_property(*starter); starter_property = unsafe_get_property(*starter);
} }
if (starter_property->comb_index < 0x8000 && if(starter_property->comb_index < 0x8000 &&
current_property->comb_index != UINT16_MAX && current_property->comb_index != UINT16_MAX &&
current_property->comb_index >= 0x8000) { current_property->comb_index >= 0x8000)
{
int sidx = starter_property->comb_index; int sidx = starter_property->comb_index;
int idx = current_property->comb_index & 0x3FFF; int idx = current_property->comb_index & 0x3FFF;
if (idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1] ) { if(idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1])
{
idx += sidx + 2 - utf8proc_combinations[sidx]; idx += sidx + 2 - utf8proc_combinations[sidx];
if (current_property->comb_index & 0x4000) { if(current_property->comb_index & 0x4000)
composition = (utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx+1]; {
} else composition =
(utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx + 1];
}
else
composition = utf8proc_combinations[idx]; composition = utf8proc_combinations[idx];
if (composition > 0 && (!(options & UTF8PROC_STABLE) || if(composition > 0 && (!(options & UTF8PROC_STABLE) ||
!(unsafe_get_property(composition)->comp_exclusion))) { !(unsafe_get_property(composition)->comp_exclusion)))
{
*starter = composition; *starter = composition;
starter_property = NULL; starter_property = NULL;
continue; continue;
...@@ -680,11 +829,15 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -680,11 +829,15 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
} }
} }
buffer[wpos] = current_char; buffer[wpos] = current_char;
if (current_property->combining_class) { if(current_property->combining_class)
if (current_property->combining_class > max_combining_class) { {
if(current_property->combining_class > max_combining_class)
{
max_combining_class = current_property->combining_class; max_combining_class = current_property->combining_class;
} }
} else { }
else
{
starter = buffer + wpos; starter = buffer + wpos;
starter_property = NULL; starter_property = NULL;
max_combining_class = -1; max_combining_class = -1;
...@@ -696,97 +849,125 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -696,97 +849,125 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
return length; return length;
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options)
{
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored
ASSERT: 'buffer' has one spare byte of free space at the end! */ ASSERT: 'buffer' has one spare byte of free space at the end! */
length = utf8proc_normalize_utf32(buffer, length, options); length = utf8proc_normalize_utf32(buffer, length, options);
if (length < 0) return length; if(length < 0)
return length;
{ {
utf8proc_ssize_t rpos, wpos = 0; utf8proc_ssize_t rpos, wpos = 0;
utf8proc_int32_t uc; utf8proc_int32_t uc;
if (options & UTF8PROC_CHARBOUND) { if(options & UTF8PROC_CHARBOUND)
for (rpos = 0; rpos < length; rpos++) { {
for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
wpos += charbound_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); wpos += charbound_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
}
} }
} else { else
for (rpos = 0; rpos < length; rpos++) { {
for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
} }
} }
((utf8proc_uint8_t *)buffer)[wpos] = 0; ((utf8proc_uint8_t*)buffer)[wpos] = 0;
return wpos; return wpos;
} }
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
) { utf8proc_uint8_t** dstptr,
utf8proc_option_t options)
{
return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL); return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
) { utf8proc_option_t options,
utf8proc_int32_t *buffer; utf8proc_custom_func custom_func,
void* custom_data)
{
utf8proc_int32_t* buffer;
utf8proc_ssize_t result; utf8proc_ssize_t result;
*dstptr = NULL; *dstptr = NULL;
result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data); result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data);
if (result < 0) return result; if(result < 0)
buffer = (utf8proc_int32_t *) malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1); return result;
if (!buffer) return UTF8PROC_ERROR_NOMEM; buffer = (utf8proc_int32_t*)malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1);
result = utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data); if(!buffer)
if (result < 0) { return UTF8PROC_ERROR_NOMEM;
result =
utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data);
if(result < 0)
{
free(buffer); free(buffer);
return result; return result;
} }
result = utf8proc_reencode(buffer, result, options); result = utf8proc_reencode(buffer, result, options);
if (result < 0) { if(result < 0)
{
free(buffer); free(buffer);
return result; return result;
} }
{ {
utf8proc_int32_t *newptr; utf8proc_int32_t* newptr;
newptr = (utf8proc_int32_t *) realloc(buffer, (size_t)result+1); newptr = (utf8proc_int32_t*)realloc(buffer, (size_t)result + 1);
if (newptr) buffer = newptr; if(newptr)
buffer = newptr;
} }
*dstptr = (utf8proc_uint8_t *)buffer; *dstptr = (utf8proc_uint8_t*)buffer;
return result; return result;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT); utf8proc_map(str,
0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT); utf8proc_map(
str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT | UTF8PROC_CASEFOLD | UTF8PROC_IGNORE); utf8proc_map(str,
0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT |
UTF8PROC_CASEFOLD | UTF8PROC_IGNORE);
return retval; return retval;
} }
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P.
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Jones, and other contributors. Copyright (c) 2009 Public Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
* DEALINGS IN THE SOFTWARE. * DEALINGS IN THE SOFTWARE.
*/ */
/** /**
* @mainpage * @mainpage
* *
...@@ -37,15 +36,20 @@ ...@@ -37,15 +36,20 @@
* The features of utf8proc include: * The features of utf8proc include:
* *
* - Transformation of strings (@ref utf8proc_map) to: * - Transformation of strings (@ref utf8proc_map) to:
* - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE) Unicode combining characters (http://en.wikipedia.org/wiki/Combining_character) * - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE) Unicode combining
* characters (http://en.wikipedia.org/wiki/Combining_character)
* - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT) * - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT)
* - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref UTF8PROC_STRIPMARK) * - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters (@ref
* UTF8PROC_STRIPCC), or combining characters such as accents (@ref UTF8PROC_STRIPMARK)
* - case-folding (@ref UTF8PROC_CASEFOLD) * - case-folding (@ref UTF8PROC_CASEFOLD)
* - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref utf8proc_NFKD, @ref utf8proc_NFKC * - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref utf8proc_NFKD, @ref
* utf8proc_NFKC
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref UTF8PROC_CHARBOUND) * - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref UTF8PROC_CHARBOUND)
* - Character-width computation: @ref utf8proc_charwidth * - Character-width computation: @ref utf8proc_charwidth
* - Classification of characters by Unicode category: @ref utf8proc_category and @ref utf8proc_category_string * - Classification of characters by Unicode category: @ref utf8proc_category and @ref
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate) Unicode codepoints to/from UTF-8. * utf8proc_category_string
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate) Unicode codepoints
* to/from UTF-8.
*/ */
/** @file */ /** @file */
...@@ -70,7 +74,8 @@ ...@@ -70,7 +74,8 @@
/** @{ */ /** @{ */
/** The MAJOR version number (increased when backwards API compatibility is broken). */ /** The MAJOR version number (increased when backwards API compatibility is broken). */
#define UTF8PROC_VERSION_MAJOR 2 #define UTF8PROC_VERSION_MAJOR 2
/** The MINOR version number (increased when new functionality is added in a backwards-compatible manner). */ /** The MINOR version number (increased when new functionality is added in a backwards-compatible
* manner). */
#define UTF8PROC_VERSION_MINOR 8 #define UTF8PROC_VERSION_MINOR 8
/** The PATCH version (increased for fixes that do not change the API). */ /** The PATCH version (increased for fixes that do not change the API). */
#define UTF8PROC_VERSION_PATCH 0 #define UTF8PROC_VERSION_PATCH 0
...@@ -86,28 +91,28 @@ typedef short utf8proc_int16_t; ...@@ -86,28 +91,28 @@ typedef short utf8proc_int16_t;
typedef unsigned short utf8proc_uint16_t; typedef unsigned short utf8proc_uint16_t;
typedef int utf8proc_int32_t; typedef int utf8proc_int32_t;
typedef unsigned int utf8proc_uint32_t; typedef unsigned int utf8proc_uint32_t;
# ifdef _WIN64 #ifdef _WIN64
typedef __int64 utf8proc_ssize_t; typedef __int64 utf8proc_ssize_t;
typedef unsigned __int64 utf8proc_size_t; typedef unsigned __int64 utf8proc_size_t;
# else #else
typedef int utf8proc_ssize_t; typedef int utf8proc_ssize_t;
typedef unsigned int utf8proc_size_t; typedef unsigned int utf8proc_size_t;
# endif #endif
# ifndef __cplusplus #ifndef __cplusplus
// emulate C99 bool // emulate C99 bool
typedef unsigned char utf8proc_bool; typedef unsigned char utf8proc_bool;
# ifndef __bool_true_false_are_defined #ifndef __bool_true_false_are_defined
# define false 0 #define false 0
# define true 1 #define true 1
# define __bool_true_false_are_defined 1 #define __bool_true_false_are_defined 1
# endif #endif
# else #else
typedef bool utf8proc_bool; typedef bool utf8proc_bool;
# endif #endif
#else #else
# include <stddef.h> #include <stddef.h>
# include <stdbool.h> #include <stdbool.h>
# include <inttypes.h> #include <inttypes.h>
typedef int8_t utf8proc_int8_t; typedef int8_t utf8proc_int8_t;
typedef uint8_t utf8proc_uint8_t; typedef uint8_t utf8proc_uint8_t;
typedef int16_t utf8proc_int16_t; typedef int16_t utf8proc_int16_t;
...@@ -121,19 +126,19 @@ typedef bool utf8proc_bool; ...@@ -121,19 +126,19 @@ typedef bool utf8proc_bool;
#include <limits.h> #include <limits.h>
#ifdef UTF8PROC_STATIC #ifdef UTF8PROC_STATIC
# define UTF8PROC_DLLEXPORT #define UTF8PROC_DLLEXPORT
#else #else
# ifdef _WIN32 #ifdef _WIN32
# ifdef UTF8PROC_EXPORTS #ifdef UTF8PROC_EXPORTS
# define UTF8PROC_DLLEXPORT __declspec(dllexport) #define UTF8PROC_DLLEXPORT __declspec(dllexport)
# else #else
# define UTF8PROC_DLLEXPORT __declspec(dllimport) #define UTF8PROC_DLLEXPORT __declspec(dllimport)
# endif #endif
# elif __GNUC__ >= 4 #elif __GNUC__ >= 4
# define UTF8PROC_DLLEXPORT __attribute__ ((visibility("default"))) #define UTF8PROC_DLLEXPORT __attribute__((visibility("default")))
# else #else
# define UTF8PROC_DLLEXPORT #define UTF8PROC_DLLEXPORT
# endif #endif
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
...@@ -143,33 +148,34 @@ extern "C" { ...@@ -143,33 +148,34 @@ extern "C" {
/** /**
* Option flags used by several functions in the library. * Option flags used by several functions in the library.
*/ */
typedef enum { typedef enum
{
/** The given UTF-8 input is NULL terminated. */ /** The given UTF-8 input is NULL terminated. */
UTF8PROC_NULLTERM = (1<<0), UTF8PROC_NULLTERM = (1 << 0),
/** Unicode Versioning Stability has to be respected. */ /** Unicode Versioning Stability has to be respected. */
UTF8PROC_STABLE = (1<<1), UTF8PROC_STABLE = (1 << 1),
/** Compatibility decomposition (i.e. formatting information is lost). */ /** Compatibility decomposition (i.e. formatting information is lost). */
UTF8PROC_COMPAT = (1<<2), UTF8PROC_COMPAT = (1 << 2),
/** Return a result with decomposed characters. */ /** Return a result with decomposed characters. */
UTF8PROC_COMPOSE = (1<<3), UTF8PROC_COMPOSE = (1 << 3),
/** Return a result with decomposed characters. */ /** Return a result with decomposed characters. */
UTF8PROC_DECOMPOSE = (1<<4), UTF8PROC_DECOMPOSE = (1 << 4),
/** Strip "default ignorable characters" such as SOFT-HYPHEN or ZERO-WIDTH-SPACE. */ /** Strip "default ignorable characters" such as SOFT-HYPHEN or ZERO-WIDTH-SPACE. */
UTF8PROC_IGNORE = (1<<5), UTF8PROC_IGNORE = (1 << 5),
/** Return an error, if the input contains unassigned codepoints. */ /** Return an error, if the input contains unassigned codepoints. */
UTF8PROC_REJECTNA = (1<<6), UTF8PROC_REJECTNA = (1 << 6),
/** /**
* Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a * Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a
* line break, and should be converted to the codepoint for line * line break, and should be converted to the codepoint for line
* separation (LS). * separation (LS).
*/ */
UTF8PROC_NLF2LS = (1<<7), UTF8PROC_NLF2LS = (1 << 7),
/** /**
* Indicating that NLF-sequences are representing a paragraph break, and * Indicating that NLF-sequences are representing a paragraph break, and
* should be converted to the codepoint for paragraph separation * should be converted to the codepoint for paragraph separation
* (PS). * (PS).
*/ */
UTF8PROC_NLF2PS = (1<<8), UTF8PROC_NLF2PS = (1 << 8),
/** Indicating that the meaning of NLF-sequences is unknown. */ /** Indicating that the meaning of NLF-sequences is unknown. */
UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS), UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS),
/** Strips and/or convers control characters. /** Strips and/or convers control characters.
...@@ -179,17 +185,17 @@ typedef enum { ...@@ -179,17 +185,17 @@ typedef enum {
* are treated as a NLF-sequence in this case. All other control * are treated as a NLF-sequence in this case. All other control
* characters are simply removed. * characters are simply removed.
*/ */
UTF8PROC_STRIPCC = (1<<9), UTF8PROC_STRIPCC = (1 << 9),
/** /**
* Performs unicode case folding, to be able to do a case-insensitive * Performs unicode case folding, to be able to do a case-insensitive
* string comparison. * string comparison.
*/ */
UTF8PROC_CASEFOLD = (1<<10), UTF8PROC_CASEFOLD = (1 << 10),
/** /**
* Inserts 0xFF bytes at the beginning of each sequence which is * Inserts 0xFF bytes at the beginning of each sequence which is
* representing a single grapheme cluster (see UAX#29). * representing a single grapheme cluster (see UAX#29).
*/ */
UTF8PROC_CHARBOUND = (1<<11), UTF8PROC_CHARBOUND = (1 << 11),
/** Lumps certain characters together. /** Lumps certain characters together.
* *
* E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details. * E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details.
...@@ -197,18 +203,18 @@ typedef enum { ...@@ -197,18 +203,18 @@ typedef enum {
* If NLF2LF is set, this includes a transformation of paragraph and * If NLF2LF is set, this includes a transformation of paragraph and
* line separators to ASCII line-feed (LF). * line separators to ASCII line-feed (LF).
*/ */
UTF8PROC_LUMP = (1<<12), UTF8PROC_LUMP = (1 << 12),
/** Strips all character markings. /** Strips all character markings.
* *
* This includes non-spacing, spacing and enclosing (i.e. accents). * This includes non-spacing, spacing and enclosing (i.e. accents).
* @note This option works only with @ref UTF8PROC_COMPOSE or * @note This option works only with @ref UTF8PROC_COMPOSE or
* @ref UTF8PROC_DECOMPOSE * @ref UTF8PROC_DECOMPOSE
*/ */
UTF8PROC_STRIPMARK = (1<<13), UTF8PROC_STRIPMARK = (1 << 13),
/** /**
* Strip unassigned codepoints. * Strip unassigned codepoints.
*/ */
UTF8PROC_STRIPNA = (1<<14), UTF8PROC_STRIPNA = (1 << 14),
} utf8proc_option_t; } utf8proc_option_t;
/** @name Error codes /** @name Error codes
...@@ -233,7 +239,8 @@ typedef enum { ...@@ -233,7 +239,8 @@ typedef enum {
typedef utf8proc_int16_t utf8proc_propval_t; typedef utf8proc_int16_t utf8proc_propval_t;
/** Struct containing information about a codepoint. */ /** Struct containing information about a codepoint. */
typedef struct utf8proc_property_struct { typedef struct utf8proc_property_struct
{
/** /**
* Unicode category. * Unicode category.
* @see utf8proc_category_t. * @see utf8proc_category_t.
...@@ -256,28 +263,29 @@ typedef struct utf8proc_property_struct { ...@@ -256,28 +263,29 @@ typedef struct utf8proc_property_struct {
utf8proc_uint16_t lowercase_seqindex; utf8proc_uint16_t lowercase_seqindex;
utf8proc_uint16_t titlecase_seqindex; utf8proc_uint16_t titlecase_seqindex;
utf8proc_uint16_t comb_index; utf8proc_uint16_t comb_index;
unsigned bidi_mirrored:1; unsigned bidi_mirrored : 1;
unsigned comp_exclusion:1; unsigned comp_exclusion : 1;
/** /**
* Can this codepoint be ignored? * Can this codepoint be ignored?
* *
* Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is * Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is
* passed as an option. * passed as an option.
*/ */
unsigned ignorable:1; unsigned ignorable : 1;
unsigned control_boundary:1; unsigned control_boundary : 1;
/** The width of the codepoint. */ /** The width of the codepoint. */
unsigned charwidth:2; unsigned charwidth : 2;
unsigned pad:2; unsigned pad : 2;
/** /**
* Boundclass. * Boundclass.
* @see utf8proc_boundclass_t. * @see utf8proc_boundclass_t.
*/ */
unsigned boundclass:8; unsigned boundclass : 8;
} utf8proc_property_t; } utf8proc_property_t;
/** Unicode categories. */ /** Unicode categories. */
typedef enum { typedef enum
{
UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */ UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */
UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */ UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */
UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */ UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */
...@@ -311,7 +319,8 @@ typedef enum { ...@@ -311,7 +319,8 @@ typedef enum {
} utf8proc_category_t; } utf8proc_category_t;
/** Bidirectional character classes. */ /** Bidirectional character classes. */
typedef enum { typedef enum
{
UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */ UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */
UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */ UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */
UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */ UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */
...@@ -338,7 +347,8 @@ typedef enum { ...@@ -338,7 +347,8 @@ typedef enum {
} utf8proc_bidi_class_t; } utf8proc_bidi_class_t;
/** Decomposition type. */ /** Decomposition type. */
typedef enum { typedef enum
{
UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */ UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */
UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */ UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */
UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */ UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */
...@@ -358,7 +368,8 @@ typedef enum { ...@@ -358,7 +368,8 @@ typedef enum {
} utf8proc_decomp_type_t; } utf8proc_decomp_type_t;
/** Boundclass property. (TR29) */ /** Boundclass property. (TR29) */
typedef enum { typedef enum
{
UTF8PROC_BOUNDCLASS_START = 0, /**< Start */ UTF8PROC_BOUNDCLASS_START = 0, /**< Start */
UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */ UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */
UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */ UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */
...@@ -393,7 +404,7 @@ typedef enum { ...@@ -393,7 +404,7 @@ typedef enum {
* @ref utf8proc_decompose_custom, which is used to specify a user-defined * @ref utf8proc_decompose_custom, which is used to specify a user-defined
* mapping of codepoints to be applied in conjunction with other mappings. * mapping of codepoints to be applied in conjunction with other mappings.
*/ */
typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void *data); typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void* data);
/** /**
* Array containing the byte lengths of a UTF-8 encoded codepoint based * Array containing the byte lengths of a UTF-8 encoded codepoint based
...@@ -406,18 +417,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256]; ...@@ -406,18 +417,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256];
* (http://semver.org format), possibly with a "-dev" suffix for * (http://semver.org format), possibly with a "-dev" suffix for
* development versions. * development versions.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_version(void);
/** /**
* Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH. * Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void);
/** /**
* Returns an informative error string for the given utf8proc error code * Returns an informative error string for the given utf8proc error code
* (e.g. the error codes returned by @ref utf8proc_map). * (e.g. the error codes returned by @ref utf8proc_map).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode);
/** /**
* Reads a single codepoint from the UTF-8 sequence being pointed to by `str`. * Reads a single codepoint from the UTF-8 sequence being pointed to by `str`.
...@@ -429,7 +440,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); ...@@ -429,7 +440,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode);
* In case of success, the number of bytes read is returned; otherwise, a * In case of success, the number of bytes read is returned; otherwise, a
* negative error code is returned. * negative error code is returned.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *codepoint_ref); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
utf8proc_ssize_t strlen,
utf8proc_int32_t* codepoint_ref);
/** /**
* Check if a codepoint is valid (regardless of whether it has been * Check if a codepoint is valid (regardless of whether it has been
...@@ -448,7 +461,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep ...@@ -448,7 +461,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep
* *
* This function does not check whether `codepoint` is valid Unicode. * This function does not check whether `codepoint` is valid Unicode.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint, utf8proc_uint8_t *dst); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint,
utf8proc_uint8_t* dst);
/** /**
* Look up the properties for a given codepoint. * Look up the properties for a given codepoint.
...@@ -462,7 +476,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo ...@@ -462,7 +476,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo
* If the codepoint is unassigned or invalid, a pointer to a special struct is * If the codepoint is unassigned or invalid, a pointer to a special struct is
* returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN). * returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN).
*/ */
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t codepoint);
/** Decompose a codepoint into an array of codepoints. /** Decompose a codepoint into an array of codepoints.
* *
...@@ -492,10 +506,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -492,10 +506,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t codepoint,
utf8proc_int32_t codepoint, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_int32_t* dst,
utf8proc_option_t options, int *last_boundclass utf8proc_ssize_t bufsize,
); utf8proc_option_t options,
int* last_boundclass);
/** /**
* The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8 * The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8
...@@ -514,10 +529,11 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( ...@@ -514,10 +529,11 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options utf8proc_int32_t* buffer,
); utf8proc_ssize_t bufsize,
utf8proc_option_t options);
/** /**
* The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping function * The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping function
...@@ -525,11 +541,13 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( ...@@ -525,11 +541,13 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(
* (along with a `custom_data` pointer that is passed through to `custom_func`). * (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`. See also @ref utf8proc_map_custom. * The `custom_func` argument is ignored if it is `NULL`. See also @ref utf8proc_map_custom.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_int32_t* buffer,
utf8proc_custom_func custom_func, void *custom_data utf8proc_ssize_t bufsize,
); utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** /**
* Normalizes the sequence of `length` codepoints pointed to by `buffer` * Normalizes the sequence of `length` codepoints pointed to by `buffer`
...@@ -554,7 +572,9 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( ...@@ -554,7 +572,9 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
* @warning The entries of the array pointed to by `str` have to be in the * @warning The entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Reencodes the sequence of `length` codepoints pointed to by `buffer` * Reencodes the sequence of `length` codepoints pointed to by `buffer`
...@@ -584,7 +604,9 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -584,7 +604,9 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
* entries of the array pointed to by `str` have to be in the * entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Given a pair of consecutive codepoints, return whether a grapheme break is * Given a pair of consecutive codepoints, return whether a grapheme break is
...@@ -603,16 +625,16 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, ...@@ -603,16 +625,16 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer,
* be called IN ORDER on ALL potential breaks in a string. However, it * be called IN ORDER on ALL potential breaks in a string. However, it
* is safe to reset the state to zero after a grapheme break. * is safe to reset the state to zero after a grapheme break.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2, utf8proc_int32_t *state); utf8proc_int32_t codepoint2,
utf8proc_int32_t* state);
/** /**
* Same as @ref utf8proc_grapheme_break_stateful, except without support for the * Same as @ref utf8proc_grapheme_break_stateful, except without support for the
* Unicode 9 additions to the algorithm. Supported for legacy reasons. * Unicode 9 additions to the algorithm. Supported for legacy reasons.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2); utf8proc_int32_t codepoint2);
/** /**
* Given a codepoint `c`, return the codepoint of the corresponding * Given a codepoint `c`, return the codepoint of the corresponding
...@@ -667,7 +689,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo ...@@ -667,7 +689,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo
* Return the two-letter (nul-terminated) Unicode category string for * Return the two-letter (nul-terminated) Unicode category string for
* the codepoint (e.g. `"Lu"` or `"Co"`). * the codepoint (e.g. `"Lu"` or `"Co"`).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t codepoint);
/** /**
* Maps the given UTF-8 string pointed to by `str` to a new UTF-8 * Maps the given UTF-8 string pointed to by `str` to a new UTF-8
...@@ -688,9 +710,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi ...@@ -688,9 +710,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi
* @note The memory of the new UTF-8 string will have been allocated * @note The memory of the new UTF-8 string will have been allocated
* with `malloc`, and should therefore be deallocated with `free`. * with `malloc`, and should therefore be deallocated with `free`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
); utf8proc_uint8_t** dstptr,
utf8proc_option_t options);
/** /**
* Like @ref utf8proc_map, but also takes a `custom_func` mapping function * Like @ref utf8proc_map, but also takes a `custom_func` mapping function
...@@ -698,10 +721,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( ...@@ -698,10 +721,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(
* (along with a `custom_data` pointer that is passed through to `custom_func`). * (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`. * The `custom_func` argument is ignored if it is `NULL`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
); utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** @name Unicode normalization /** @name Unicode normalization
* *
...@@ -712,18 +737,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( ...@@ -712,18 +737,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(
*/ */
/** @{ */ /** @{ */
/** NFD normalization (@ref UTF8PROC_DECOMPOSE). */ /** NFD normalization (@ref UTF8PROC_DECOMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str);
/** NFC normalization (@ref UTF8PROC_COMPOSE). */ /** NFC normalization (@ref UTF8PROC_COMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str);
/** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str);
/** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str);
/** /**
* NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT * NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT
* and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE). * and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE).
**/ **/
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str);
/** @} */ /** @} */
#ifdef __cplusplus #ifdef __cplusplus
......
...@@ -12,7 +12,7 @@ int main() ...@@ -12,7 +12,7 @@ int main()
// 加载GPT2模型 // 加载GPT2模型
migraphxSamples::GPT2 gpt2; migraphxSamples::GPT2 gpt2;
migraphxSamples::ErrorCode errorCode = gpt2.Initialize(); migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
if (errorCode != migraphxSamples::SUCCESS) if(errorCode != migraphxSamples::SUCCESS)
{ {
LOG_ERROR(stdout, "fail to initialize GPT2!\n"); LOG_ERROR(stdout, "fail to initialize GPT2!\n");
exit(-1); exit(-1);
...@@ -25,7 +25,7 @@ int main() ...@@ -25,7 +25,7 @@ int main()
std::string buf; std::string buf;
std::vector<std::string> output; std::vector<std::string> output;
infile.open("../Resource/vocab_shici.txt"); infile.open("../Resource/vocab_shici.txt");
while (std::getline(infile,buf)) while(std::getline(infile, buf))
{ {
output.push_back(buf); output.push_back(buf);
} }
...@@ -37,7 +37,7 @@ int main() ...@@ -37,7 +37,7 @@ int main()
std::vector<std::string> result; std::vector<std::string> result;
std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl; std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
while (true) while(true)
{ {
// 数据预处理 // 数据预处理
std::cout << "question: "; std::cout << "question: ";
...@@ -45,7 +45,7 @@ int main() ...@@ -45,7 +45,7 @@ int main()
gpt2.Preprocessing(tokenizer, question, input_id); gpt2.Preprocessing(tokenizer, question, input_id);
// 推理 // 推理
for(int i=0;i<50;++i) for(int i = 0; i < 50; ++i)
{ {
long unsigned int outputs = gpt2.Inference(input_id); long unsigned int outputs = gpt2.Inference(input_id);
if(outputs == 102) if(outputs == 102)
...@@ -57,7 +57,7 @@ int main() ...@@ -57,7 +57,7 @@ int main()
} }
// 将数值映射为字符 // 将数值映射为字符
for(int i=0;i<score.size();++i) for(int i = 0; i < score.size(); ++i)
{ {
result.push_back(output[score[i]]); result.push_back(output[score[i]]);
} }
...@@ -65,7 +65,7 @@ int main() ...@@ -65,7 +65,7 @@ int main()
// 打印结果 // 打印结果
std::cout << "chatbot: "; std::cout << "chatbot: ";
std::cout << question; std::cout << question;
for(int j=0; j<result.size();++j) for(int j = 0; j < result.size(); ++j)
{ {
std::cout << result[j]; std::cout << result[j];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment