Commit aec6280a authored by liucong's avatar liucong
Browse files

对C++代码进行格式化

parent ab78f8ec
#include <Bert.h> #include <Bert.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <Filesystem.h> #include <Filesystem.h>
#include <SimpleLog.h> #include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h> #include <tokenization.h>
namespace migraphxSamples #include <algorithm>
{ #include <migraphx/gpu/target.hpp>
#include <migraphx/onnx.hpp>
Bert::Bert() #include <stdexcept>
{
} namespace migraphxSamples {
Bert::~Bert() Bert::Bert() {}
{
} Bert::~Bert() {}
ErrorCode Bert::Initialize() ErrorCode Bert::Initialize()
{ {
// 获取模型文件 // 获取模型文件
std::string modelPath="../Resource/bertsquad-10.onnx"; std::string modelPath = "../Resource/bertsquad-10.onnx";
// 设置最大输入shape // 设置最大输入shape
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["unique_ids_raw_output___9:0"]={1}; onnx_options.map_input_dims["unique_ids_raw_output___9:0"] = {1};
onnx_options.map_input_dims["input_ids:0"]={1,256}; onnx_options.map_input_dims["input_ids:0"] = {1, 256};
onnx_options.map_input_dims["input_mask:0"]={1,256}; onnx_options.map_input_dims["input_mask:0"] = {1, 256};
onnx_options.map_input_dims["segment_ids:0"]={1,256}; onnx_options.map_input_dims["segment_ids:0"] = {1, 256};
// 加载模型 // 加载模型
if(!Exists(modelPath)) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
inputName1 = "unique_ids_raw_output___9:0"; inputName1 = "unique_ids_raw_output___9:0";
inputShape1 = inputs.at(inputName1); inputShape1 = inputs.at(inputName1);
...@@ -65,27 +56,27 @@ ErrorCode Bert::Initialize() ...@@ -65,27 +56,27 @@ ErrorCode Bert::Initialize()
// 编译模型 // 编译模型
migraphx::compile_options options; migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备 options.device_id = 0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; options.offload_copy = true;
net.compile(gpuTarget,options); net.compile(gpuTarget, options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());
// warm up // warm up
std::unordered_map<std::string, migraphx::argument> inputData; std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName1]=migraphx::argument(inputShape1); inputData[inputName1] = migraphx::argument(inputShape1);
inputData[inputName2]=migraphx::argument(inputShape2); inputData[inputName2] = migraphx::argument(inputShape2);
inputData[inputName3]=migraphx::argument(inputShape3); inputData[inputName3] = migraphx::argument(inputShape3);
inputData[inputName4]=migraphx::argument(inputShape4); inputData[inputName4] = migraphx::argument(inputShape4);
net.eval(inputData); net.eval(inputData);
return SUCCESS; return SUCCESS;
} }
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>>& input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids, const std::vector<std::vector<long unsigned int>>& segment_ids,
std::vector<float> &start_position, std::vector<float>& start_position,
std::vector<float> &end_position) std::vector<float>& end_position)
{ {
// 保存预处理后的数据 // 保存预处理后的数据
int num = input_ids.size(); int num = input_ids.size();
...@@ -93,9 +84,9 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -93,9 +84,9 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
long unsigned int input_mask[num][256]; long unsigned int input_mask[num][256];
long unsigned int segment_id[num][256]; long unsigned int segment_id[num][256];
long unsigned int position_id[num][1]; long unsigned int position_id[num][1];
for(int i=0;i<input_ids.size();++i) for(int i = 0; i < input_ids.size(); ++i)
{ {
for(int j=0;j<input_ids[0].size();++j) for(int j = 0; j < input_ids[0].size(); ++j)
{ {
input_id[i][j] = input_ids[i][j]; input_id[i][j] = input_ids[i][j];
segment_id[i][j] = segment_ids[i][j]; segment_id[i][j] = segment_ids[i][j];
...@@ -111,25 +102,25 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -111,25 +102,25 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
float* start_data; float* start_data;
float* end_data; float* end_data;
for(int i=0;i<input_ids.size();++i) for(int i = 0; i < input_ids.size(); ++i)
{ {
// 创建输入数据 // 创建输入数据
inputData[inputName1]=migraphx::argument{inputShape1, (long unsigned int*)position_id[i]}; inputData[inputName1] = migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
inputData[inputName2]=migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]}; inputData[inputName2] = migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
inputData[inputName3]=migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]}; inputData[inputName3] = migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
inputData[inputName4]=migraphx::argument{inputShape4, (long unsigned int*)input_id[i]}; inputData[inputName4] = migraphx::argument{inputShape4, (long unsigned int*)input_id[i]};
// 推理 // 推理
results = net.eval(inputData); results = net.eval(inputData);
// 获取输出节点的属性 // 获取输出节点的属性
start_prediction = results[1]; // 答案的开始位置 start_prediction = results[1]; // 答案的开始位置
start_data = (float *)start_prediction.data(); // 开始位置的数据指针 start_data = (float*)start_prediction.data(); // 开始位置的数据指针
end_prediction = results[0]; // 答案的结束位置 end_prediction = results[0]; // 答案的结束位置
end_data = (float *)end_prediction.data(); // 结束位置的数据指针 end_data = (float*)end_prediction.data(); // 结束位置的数据指针
// 保存推理结果 // 保存推理结果
for(int i=0;i<256;++i) for(int i = 0; i < 256; ++i)
{ {
start_position.push_back(start_data[i]); start_position.push_back(start_data[i]);
end_position.push_back(end_data[i]); end_position.push_back(end_data[i]);
...@@ -142,11 +133,11 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -142,11 +133,11 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size, int batch_size,
int max_seq_length, int max_seq_length,
const char *text, const char* text,
char *question, char* question,
std::vector<std::vector<long unsigned int>> &input_ids, std::vector<std::vector<long unsigned int>>& input_ids,
std::vector<std::vector<long unsigned int>> &input_masks, std::vector<std::vector<long unsigned int>>& input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids) std::vector<std::vector<long unsigned int>>& segment_ids)
{ {
std::vector<long unsigned int> input_id(max_seq_length); std::vector<long unsigned int> input_id(max_seq_length);
std::vector<long unsigned int> input_mask(max_seq_length); std::vector<long unsigned int> input_mask(max_seq_length);
...@@ -161,18 +152,18 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -161,18 +152,18 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
// 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作 // 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作
if(tokens_text.size() + tokens_question.size() > max_seq_length - 5) if(tokens_text.size() + tokens_question.size() > max_seq_length - 5)
{ {
int windows_len = max_seq_length - 5 -tokens_question.size(); int windows_len = max_seq_length - 5 - tokens_question.size();
std::vector<std::string> tokens_text_window(windows_len); std::vector<std::string> tokens_text_window(windows_len);
std::vector<std::vector<std::string>> tokens_text_windows; std::vector<std::vector<std::string>> tokens_text_windows;
int start_offset = 0; int start_offset = 0;
int position = 0; int position = 0;
int n; int n;
while (start_offset < tokens_text.size()) while(start_offset < tokens_text.size())
{ {
n = 0; n = 0;
if(start_offset+windows_len>tokens_text.size()) if(start_offset + windows_len > tokens_text.size())
{ {
for(int i=start_offset;i<tokens_text.size();++i) for(int i = start_offset; i < tokens_text.size(); ++i)
{ {
tokens_text_window[n] = tokens_text[i]; tokens_text_window[n] = tokens_text[i];
++n; ++n;
...@@ -180,7 +171,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -180,7 +171,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
} }
else else
{ {
for(int i=start_offset;i<start_offset+windows_len;++i) for(int i = start_offset; i < start_offset + windows_len; ++i)
{ {
tokens_text_window[n] = tokens_text[i]; tokens_text_window[n] = tokens_text[i];
++n; ++n;
...@@ -191,7 +182,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -191,7 +182,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
++position; ++position;
} }
for(int i=0;i<position;++i) for(int i = 0; i < position; ++i)
{ {
input_id[0] = tokenizer.convert_token_to_id("[CLS]"); input_id[0] = tokenizer.convert_token_to_id("[CLS]");
segment_id[0] = 0; segment_id[0] = 0;
...@@ -199,7 +190,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -199,7 +190,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
input_id[1] = tokenizer.convert_token_to_id("[CLS]"); input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0; segment_id[1] = 0;
for (int j=0;j<tokens_question.size();++j) for(int j = 0; j < tokens_question.size(); ++j)
{ {
input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]); input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]);
segment_id[j + 2] = 0; segment_id[j + 2] = 0;
...@@ -211,13 +202,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -211,13 +202,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0; segment_id[tokens_question.size() + 3] = 0;
for (int j=0;j<tokens_question.size();++j) for(int j = 0; j < tokens_question.size(); ++j)
{ {
input_id[j + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id(tokens_text_windows[i][j]); input_id[j + tokens_text_windows[i].size() + 4] =
tokenizer.convert_token_to_id(tokens_text_windows[i][j]);
segment_id[j + tokens_text_windows[i].size() + 4] = 1; segment_id[j + tokens_text_windows[i].size() + 4] = 1;
} }
input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] =
tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1; segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。 // 掩码为1的表示为真实标记,0表示为填充标记。
...@@ -240,7 +233,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -240,7 +233,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
input_id[1] = tokenizer.convert_token_to_id("[CLS]"); input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0; segment_id[1] = 0;
for (int i=0;i<tokens_question.size();++i) for(int i = 0; i < tokens_question.size(); ++i)
{ {
input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]); input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]);
segment_id[i + 2] = 0; segment_id[i + 2] = 0;
...@@ -252,13 +245,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -252,13 +245,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0; segment_id[tokens_question.size() + 3] = 0;
for (int i=0;i<tokens_text.size();++i) for(int i = 0; i < tokens_text.size(); ++i)
{ {
input_id[i + tokens_question.size() + 4] = tokenizer.convert_token_to_id(tokens_text[i]); input_id[i + tokens_question.size() + 4] =
tokenizer.convert_token_to_id(tokens_text[i]);
segment_id[i + tokens_question.size() + 4] = 1; segment_id[i + tokens_question.size() + 4] = 1;
} }
input_id[tokens_question.size() + tokens_text.size() + 4] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + tokens_text.size() + 4] =
tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text.size() + 4] = 1; segment_id[tokens_question.size() + tokens_text.size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。 // 掩码为1的表示为真实标记,0表示为填充标记。
...@@ -275,27 +270,25 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -275,27 +270,25 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
return SUCCESS; return SUCCESS;
} }
static bool Compare(Sort_st a, Sort_st b) static bool Compare(Sort_st a, Sort_st b) { return a.value > b.value; }
{
return a.value > b.value;
}
static bool CompareM(ResultOfPredictions a, ResultOfPredictions b) static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{ {
return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue; return a.start_predictionvalue + a.end_predictionvalue >
b.start_predictionvalue + b.end_predictionvalue;
} }
ErrorCode Bert::Postprocessing(int n_best_size, ErrorCode Bert::Postprocessing(int n_best_size,
int max_answer_length, int max_answer_length,
const std::vector<float> &start_position, const std::vector<float>& start_position,
const std::vector<float> &end_position, const std::vector<float>& end_position,
std::string &answer) std::string& answer)
{ {
// 取前n_best_size个最大概率值的索引 // 取前n_best_size个最大概率值的索引
std::vector<Sort_st> start_array(start_position.size()); std::vector<Sort_st> start_array(start_position.size());
std::vector<Sort_st> end_array(end_position.size()); std::vector<Sort_st> end_array(end_position.size());
for (int i=0;i<start_position.size();++i) for(int i = 0; i < start_position.size(); ++i)
{ {
start_array[i].index = i; start_array[i].index = i;
start_array[i].value = start_position.at(i); start_array[i].value = start_position.at(i);
...@@ -309,10 +302,10 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -309,10 +302,10 @@ ErrorCode Bert::Postprocessing(int n_best_size,
std::vector<ResultOfPredictions> resultsOfPredictions(400); std::vector<ResultOfPredictions> resultsOfPredictions(400);
int num = start_position.size() / 256; int num = start_position.size() / 256;
bool flag; bool flag;
int n=0; int n = 0;
for(int i=0;i<n_best_size;++i) for(int i = 0; i < n_best_size; ++i)
{ {
for(int j=0;j<n_best_size;++j) for(int j = 0; j < n_best_size; ++j)
{ {
flag = false; flag = false;
if(start_array[i].index > start_position.size()) if(start_array[i].index > start_position.size())
...@@ -325,15 +318,17 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -325,15 +318,17 @@ ErrorCode Bert::Postprocessing(int n_best_size,
continue; continue;
} }
for(int t=0;t<num;++t) for(int t = 0; t < num; ++t)
{ {
if(start_array[i].index > t*256 && start_array[i].index < tokens_question.size()+4+t*256) if(start_array[i].index > t * 256 &&
start_array[i].index < tokens_question.size() + 4 + t * 256)
{ {
flag = true; flag = true;
break; break;
} }
if(end_array[j].index > t*256 && end_array[j].index < tokens_question.size()+4+t*256) if(end_array[j].index > t * 256 &&
end_array[j].index < tokens_question.size() + 4 + t * 256)
{ {
flag = true; flag = true;
break; break;
...@@ -368,9 +363,10 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -368,9 +363,10 @@ ErrorCode Bert::Postprocessing(int n_best_size,
int start_index = 0; int start_index = 0;
int end_index = 0; int end_index = 0;
for(int i=0;i<400;++i) for(int i = 0; i < 400; ++i)
{ {
if(resultsOfPredictions[i].start_predictionvalue==0 && resultsOfPredictions[i].end_predictionvalue==0) if(resultsOfPredictions[i].start_predictionvalue == 0 &&
resultsOfPredictions[i].end_predictionvalue == 0)
{ {
continue; continue;
} }
...@@ -380,24 +376,24 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -380,24 +376,24 @@ ErrorCode Bert::Postprocessing(int n_best_size,
} }
// 映射回上下文文本的索引,(当前的索引值-问题的长度-4) // 映射回上下文文本的索引,(当前的索引值-问题的长度-4)
int answer_start_index = start_index - tokens_question.size()- 4; int answer_start_index = start_index - tokens_question.size() - 4;
int answer_end_index = end_index - tokens_question.size() - 4 + 1; int answer_end_index = end_index - tokens_question.size() - 4 + 1;
// 根据开始索引和结束索引,获取区间内的数据 // 根据开始索引和结束索引,获取区间内的数据
int j=0; int j = 0;
for(int i=answer_start_index;i<answer_end_index;++i) for(int i = answer_start_index; i < answer_end_index; ++i)
{ {
if(tokens_text[i].find('#') != -1) if(tokens_text[i].find('#') != -1)
{ {
j=i-1; j = i - 1;
break; break;
} }
} }
for(int i=answer_start_index;i<answer_end_index;++i) for(int i = answer_start_index; i < answer_end_index; ++i)
{ {
answer += tokens_text[i]; answer += tokens_text[i];
if(tokens_text[i].find('#') != -1 || i==j) if(tokens_text[i].find('#') != -1 || i == j)
{ {
continue; continue;
} }
...@@ -405,9 +401,9 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -405,9 +401,9 @@ ErrorCode Bert::Postprocessing(int n_best_size,
} }
int index = 0; int index = 0;
while( (index = answer.find('#',index)) != string::npos) while((index = answer.find('#', index)) != string::npos)
{ {
answer.erase(index,1); answer.erase(index, 1);
} }
tokens_text.clear(); tokens_text.clear();
tokens_question.clear(); tokens_question.clear();
...@@ -415,5 +411,4 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -415,5 +411,4 @@ ErrorCode Bert::Postprocessing(int n_best_size,
return SUCCESS; return SUCCESS;
} }
} } // namespace migraphxSamples
#ifndef __BERT_H__ #ifndef __BERT_H__
#define __BERT_H__ #define __BERT_H__
#include <tokenization.h>
#include <cstdint> #include <cstdint>
#include <string>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <tokenization.h> #include <string>
namespace migraphxSamples namespace migraphxSamples {
// Status codes returned by the Bert pipeline stages
// (Initialize / Inference / Preprocessing / Postprocessing).
typedef enum _ErrorCode
{
    SUCCESS = 0,              // operation completed normally
    MODEL_NOT_EXIST,          // ONNX model file not found on disk
    CONFIG_FILE_NOT_EXIST,    // configuration file not found
    FAIL_TO_LOAD_MODEL,       // model parsing/loading failed
    FAIL_TO_OPEN_CONFIG_FILE, // configuration file could not be opened
} ErrorCode;
// Pairs a token position with its prediction score so results can be
// sorted by score while remembering where each score came from.
typedef struct _Sort_st
{
    int index;   // original position in the prediction array
    float value; // model score at that position
} Sort_st;
// One candidate answer span: the token indices of its start and end
// positions together with the model score for each endpoint.
typedef struct _ResultOfPredictions
{
    int start_index;            // token index where the answer starts
    int end_index;              // token index where the answer ends
    float start_predictionvalue; // score of the start position
    float end_predictionvalue;   // score of the end position
} ResultOfPredictions;
class Bert class Bert
{ {
public: public:
Bert(); Bert();
~Bert(); ~Bert();
ErrorCode Initialize(); ErrorCode Initialize();
ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>>& input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids, const std::vector<std::vector<long unsigned int>>& segment_ids,
std::vector<float> &start_position, std::vector<float>& start_position,
std::vector<float> &end_position); std::vector<float>& end_position);
ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size, int batch_size,
int max_seq_length, int max_seq_length,
const char *text, const char* text,
char *question, char* question,
std::vector<std::vector<long unsigned int>> &input_ids, std::vector<std::vector<long unsigned int>>& input_ids,
std::vector<std::vector<long unsigned int>> &input_masks, std::vector<std::vector<long unsigned int>>& input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids); std::vector<std::vector<long unsigned int>>& segment_ids);
ErrorCode Postprocessing(int n_best_size, ErrorCode Postprocessing(int n_best_size,
int max_answer_length, int max_answer_length,
const std::vector<float> &start_position, const std::vector<float>& start_position,
const std::vector<float> &end_position, const std::vector<float>& end_position,
std::string &answer); std::string& answer);
private: private:
std::vector<std::string> tokens_text; std::vector<std::string> tokens_text;
std::vector<std::string> tokens_question; std::vector<std::string> tokens_question;
...@@ -74,9 +74,8 @@ private: ...@@ -74,9 +74,8 @@ private:
migraphx::shape inputShape2; migraphx::shape inputShape2;
migraphx::shape inputShape3; migraphx::shape inputShape3;
migraphx::shape inputShape4; migraphx::shape inputShape4;
}; };
} } // namespace migraphxSamples
#endif #endif
\ No newline at end of file
#include <Filesystem.h> #include <Filesystem.h>
#include <algorithm>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
#include <algorithm>
#include <fstream> #include <fstream>
#ifdef _WIN32 #ifdef _WIN32
#include <io.h>
#include <direct.h>
#include <Windows.h> #include <Windows.h>
#include <direct.h>
#include <io.h>
#else #else
#include <unistd.h>
#include <dirent.h> #include <dirent.h>
#include <unistd.h>
#endif #endif
// 路径分隔符(Linux:‘/’,Windows:’\\’) // 路径分隔符(Linux:‘/’,Windows:’\\’)
...@@ -21,39 +25,38 @@ ...@@ -21,39 +25,38 @@
using namespace std; using namespace std;
namespace migraphxSamples namespace migraphxSamples {
{
// Splits `str` on every occurrence of `separator` and returns the pieces
// in order. Consecutive separators yield empty tokens, matching the
// previous behaviour (e.g. "a,,b" with "," -> {"a", "", "b"}).
// Rewritten to compare against std::string::npos with size_type throughout,
// instead of the original `pos < size` signed/unsigned comparison and the
// `int size = str.size()` narrowing; also advances directly past each
// separator rather than stepping one character at a time.
// NOTE: an empty separator would never advance; callers in this file pass
// non-empty separators such as ",".
static std::vector<std::string> SplitString(std::string str, std::string separator)
{
    std::vector<std::string> result;
    str += separator; // append one trailing separator so the final token is captured
    std::string::size_type start = 0;
    std::string::size_type pos;
    while((pos = str.find(separator, start)) != std::string::npos)
    {
        result.push_back(str.substr(start, pos - start));
        start = pos + separator.size();
    }
    return result;
}
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
const char dir_separators[] = "/\\"; const char dir_separators[] = "/\\";
// Minimal POSIX-style dirent emulation for the Windows branch: only the
// entry name is exposed, which is all the directory-glob code consumes.
struct dirent
{
    const char* d_name;
};
struct DIR struct DIR
{ {
#ifdef WINRT #ifdef WINRT
WIN32_FIND_DATAW data; WIN32_FIND_DATAW data;
#else #else
...@@ -62,17 +65,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -62,17 +65,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
HANDLE handle; HANDLE handle;
dirent ent; dirent ent;
#ifdef WINRT #ifdef WINRT
DIR() { } DIR() {}
~DIR() ~DIR()
{ {
if (ent.d_name) if(ent.d_name)
delete[] ent.d_name; delete[] ent.d_name;
} }
#endif #endif
}; };
DIR* opendir(const char* path) DIR* opendir(const char* path)
{ {
DIR* dir = new DIR; DIR* dir = new DIR;
dir->ent.d_name = 0; dir->ent.d_name = 0;
#ifdef WINRT #ifdef WINRT
...@@ -80,27 +83,31 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -80,27 +83,31 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
wchar_t wfull_path[MAX_PATH]; wchar_t wfull_path[MAX_PATH];
size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH); size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
dir->handle = ::FindFirstFileExW(wfull_path, FindExInfoStandard, dir->handle = ::FindFirstFileExW(
&dir->data, FindExSearchNameMatch, NULL, 0); wfull_path, FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0);
#else #else
dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(), dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(),
FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0); FindExInfoStandard,
&dir->data,
FindExSearchNameMatch,
NULL,
0);
#endif #endif
if (dir->handle == INVALID_HANDLE_VALUE) if(dir->handle == INVALID_HANDLE_VALUE)
{ {
/*closedir will do all cleanup*/ /*closedir will do all cleanup*/
delete dir; delete dir;
return 0; return 0;
} }
return dir; return dir;
} }
dirent* readdir(DIR* dir) dirent* readdir(DIR* dir)
{ {
#ifdef WINRT #ifdef WINRT
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileW(dir->handle, &dir->data) != TRUE) if(::FindNextFileW(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
size_t asize = wcstombs(NULL, dir->data.cFileName, 0); size_t asize = wcstombs(NULL, dir->data.cFileName, 0);
...@@ -110,33 +117,33 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -110,33 +117,33 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
wcstombs(aname, dir->data.cFileName, asize); wcstombs(aname, dir->data.cFileName, asize);
dir->ent.d_name = aname; dir->ent.d_name = aname;
#else #else
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileA(dir->handle, &dir->data) != TRUE) if(::FindNextFileA(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
dir->ent.d_name = dir->data.cFileName; dir->ent.d_name = dir->data.cFileName;
#endif #endif
return &dir->ent; return &dir->ent;
} }
// Closes a directory opened with the Windows opendir() emulation:
// releases the Win32 find handle and frees the DIR object that
// opendir() allocated with new.
void closedir(DIR* dir)
{
    ::FindClose(dir->handle);
    delete dir;
}
#else #else
# include <dirent.h> #include <dirent.h>
# include <sys/stat.h> #include <sys/stat.h>
const char dir_separators[] = "/"; const char dir_separators[] = "/";
#endif #endif
static bool isDir(const string &path, DIR* dir) static bool isDir(const string& path, DIR* dir)
{ {
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
DWORD attributes; DWORD attributes;
BOOL status = TRUE; BOOL status = TRUE;
if (dir) if(dir)
attributes = dir->data.dwFileAttributes; attributes = dir->data.dwFileAttributes;
else else
{ {
...@@ -156,21 +163,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -156,21 +163,17 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
(void)dir; (void)dir;
struct stat stat_buf; struct stat stat_buf;
if (0 != stat(path.c_str(), &stat_buf)) if(0 != stat(path.c_str(), &stat_buf))
return false; return false;
int is_dir = S_ISDIR(stat_buf.st_mode); int is_dir = S_ISDIR(stat_buf.st_mode);
return is_dir != 0; return is_dir != 0;
#endif #endif
} }
bool IsDirectory(const string &path) bool IsDirectory(const string& path) { return isDir(path, NULL); }
{
return isDir(path, NULL);
}
bool Exists(const string& path)
{
bool Exists(const string& path)
{
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
BOOL status = TRUE; BOOL status = TRUE;
{ {
...@@ -190,28 +193,25 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -190,28 +193,25 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
struct stat stat_buf; struct stat stat_buf;
return (0 == stat(path.c_str(), &stat_buf)); return (0 == stat(path.c_str(), &stat_buf));
#endif #endif
} }
// True when `c` is a path separator on either platform
// ('/' for POSIX, '\\' for Windows).
bool IsPathSeparator(char c)
{
    switch(c)
    {
        case '/':
        case '\\':
            return true;
        default:
            return false;
    }
}
string JoinPath(const string& base, const string& path) string JoinPath(const string& base, const string& path)
{ {
if (base.empty()) if(base.empty())
return path; return path;
if (path.empty()) if(path.empty())
return base; return base;
bool baseSep = IsPathSeparator(base[base.size() - 1]); bool baseSep = IsPathSeparator(base[base.size() - 1]);
bool pathSep = IsPathSeparator(path[0]); bool pathSep = IsPathSeparator(path[0]);
string result; string result;
if (baseSep && pathSep) if(baseSep && pathSep)
{ {
result = base + path.substr(1); result = base + path.substr(1);
} }
else if (!baseSep && !pathSep) else if(!baseSep && !pathSep)
{ {
result = base + PATH_SEPARATOR + path; result = base + PATH_SEPARATOR + path;
} }
...@@ -220,15 +220,15 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -220,15 +220,15 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
result = base + path; result = base + path;
} }
return result; return result;
} }
static bool wildcmp(const char *string, const char *wild) static bool wildcmp(const char* string, const char* wild)
{ {
const char *cp = 0, *mp = 0; const char *cp = 0, *mp = 0;
while ((*string) && (*wild != '*')) while((*string) && (*wild != '*'))
{ {
if ((*wild != *string) && (*wild != '?')) if((*wild != *string) && (*wild != '?'))
{ {
return false; return false;
} }
...@@ -237,11 +237,11 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -237,11 +237,11 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
string++; string++;
} }
while (*string) while(*string)
{ {
if (*wild == '*') if(*wild == '*')
{ {
if (!*++wild) if(!*++wild)
{ {
return true; return true;
} }
...@@ -249,7 +249,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -249,7 +249,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
mp = wild; mp = wild;
cp = string + 1; cp = string + 1;
} }
else if ((*wild == *string) || (*wild == '?')) else if((*wild == *string) || (*wild == '?'))
{ {
wild++; wild++;
string++; string++;
...@@ -261,47 +261,52 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -261,47 +261,52 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
while (*wild == '*') while(*wild == '*')
{ {
wild++; wild++;
} }
return *wild == 0; return *wild == 0;
} }
static void glob_rec(const string &directory, const string& wildchart, std::vector<string>& result, static void glob_rec(const string& directory,
bool recursive, bool includeDirectories, const string& pathPrefix) const string& wildchart,
{ std::vector<string>& result,
DIR *dir; bool recursive,
bool includeDirectories,
const string& pathPrefix)
{
DIR* dir;
if ((dir = opendir(directory.c_str())) != 0) if((dir = opendir(directory.c_str())) != 0)
{ {
/* find all the files and directories within directory */ /* find all the files and directories within directory */
try try
{ {
struct dirent *ent; struct dirent* ent;
while ((ent = readdir(dir)) != 0) while((ent = readdir(dir)) != 0)
{ {
const char* name = ent->d_name; const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0)) if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
(name[0] == '.' && name[1] == '.' && name[2] == 0))
continue; continue;
string path = JoinPath(directory, name); string path = JoinPath(directory, name);
string entry = JoinPath(pathPrefix, name); string entry = JoinPath(pathPrefix, name);
if (isDir(path, dir)) if(isDir(path, dir))
{ {
if (recursive) if(recursive)
glob_rec(path, wildchart, result, recursive, includeDirectories, entry); glob_rec(path, wildchart, result, recursive, includeDirectories, entry);
if (!includeDirectories) if(!includeDirectories)
continue; continue;
} }
if (wildchart.empty() || wildcmp(name, wildchart.c_str())) if(wildchart.empty() || wildcmp(name, wildchart.c_str()))
result.push_back(entry); result.push_back(entry);
} }
} }
catch (...) catch(...)
{ {
closedir(dir); closedir(dir);
throw; throw;
...@@ -312,23 +317,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -312,23 +317,27 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
{ {
printf("could not open directory: %s", directory.c_str()); printf("could not open directory: %s", directory.c_str());
} }
} }
void GetFileNameList(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath) void GetFileNameList(const string& directory,
{ const string& pattern,
std::vector<string>& result,
bool recursive,
bool addPath)
{
// split pattern // split pattern
vector<string> patterns=SplitString(pattern,","); vector<string> patterns = SplitString(pattern, ",");
result.clear(); result.clear();
for(int i=0;i<patterns.size();++i) for(int i = 0; i < patterns.size(); ++i)
{ {
string eachPattern=patterns[i]; string eachPattern = patterns[i];
std::vector<string> eachResult; std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory); glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j=0;j<eachResult.size();++j) for(int j = 0; j < eachResult.size(); ++j)
{ {
if (IsDirectory(eachResult[j])) if(IsDirectory(eachResult[j]))
continue; continue;
if(addPath) if(addPath)
{ {
...@@ -341,41 +350,45 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -341,41 +350,45 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
std::sort(result.begin(), result.end()); std::sort(result.begin(), result.end());
} }
void GetFileNameList2(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath) void GetFileNameList2(const string& directory,
{ const string& pattern,
std::vector<string>& result,
bool recursive,
bool addPath)
{
// split pattern // split pattern
vector<string> patterns = SplitString(pattern, ","); vector<string> patterns = SplitString(pattern, ",");
result.clear(); result.clear();
for (int i = 0; i<patterns.size(); ++i) for(int i = 0; i < patterns.size(); ++i)
{ {
string eachPattern = patterns[i]; string eachPattern = patterns[i];
std::vector<string> eachResult; std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory); glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for (int j = 0; j<eachResult.size(); ++j) for(int j = 0; j < eachResult.size(); ++j)
{ {
string filePath = eachResult[j]; string filePath = eachResult[j];
if (IsDirectory(filePath)) if(IsDirectory(filePath))
{ {
filePath = filePath + "/"; filePath = filePath + "/";
for (int k = 0; k < filePath.size(); ++k) for(int k = 0; k < filePath.size(); ++k)
{ {
if (IsPathSeparator(filePath[k])) if(IsPathSeparator(filePath[k]))
{ {
filePath[k] = '/'; filePath[k] = '/';
} }
} }
} }
if (addPath) if(addPath)
{ {
result.push_back(filePath); result.push_back(filePath);
} }
else else
{ {
if (!IsDirectory(filePath)) if(!IsDirectory(filePath))
{ {
result.push_back(GetFileName(filePath)); result.push_back(GetFileName(filePath));
} }
...@@ -383,19 +396,18 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -383,19 +396,18 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
} }
std::sort(result.begin(), result.end()); std::sort(result.begin(), result.end());
} }
void RemoveAll(const string& path)
{
if (!Exists(path)) void RemoveAll(const string& path)
{
if(!Exists(path))
return; return;
if (IsDirectory(path)) if(IsDirectory(path))
{ {
std::vector<string> entries; std::vector<string> entries;
GetFileNameList2(path, string(), entries, false, true); GetFileNameList2(path, string(), entries, false, true);
for (size_t i = 0; i < entries.size(); i++) for(size_t i = 0; i < entries.size(); i++)
{ {
const string& e = entries[i]; const string& e = entries[i];
RemoveAll(e); RemoveAll(e);
...@@ -405,7 +417,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -405,7 +417,7 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
bool result = rmdir(path.c_str()) == 0; bool result = rmdir(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{ {
printf("can't remove directory: %s\n", path.c_str()); printf("can't remove directory: %s\n", path.c_str());
} }
...@@ -417,50 +429,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -417,50 +429,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
#else #else
bool result = unlink(path.c_str()) == 0; bool result = unlink(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{ {
printf("can't remove file: %s\n", path.c_str()); printf("can't remove file: %s\n", path.c_str());
} }
} }
} }
void Remove(const string &directory, const string &extension)
{
DIR *dir; void Remove(const string& directory, const string& extension)
{
DIR* dir;
static int numberOfFiles = 0; static int numberOfFiles = 0;
if ((dir = opendir(directory.c_str())) != 0) if((dir = opendir(directory.c_str())) != 0)
{ {
/* find all the files and directories within directory */ /* find all the files and directories within directory */
try try
{ {
struct dirent *ent; struct dirent* ent;
while ((ent = readdir(dir)) != 0) while((ent = readdir(dir)) != 0)
{ {
const char* name = ent->d_name; const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0)) if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
(name[0] == '.' && name[1] == '.' && name[2] == 0))
continue; continue;
string path = JoinPath(directory, name); string path = JoinPath(directory, name);
if (isDir(path, dir)) if(isDir(path, dir))
{ {
Remove(path, extension); Remove(path, extension);
} }
// �ж���չ�� // �ж���չ��
if (extension.empty() || wildcmp(name, extension.c_str())) if(extension.empty() || wildcmp(name, extension.c_str()))
{ {
RemoveAll(path); RemoveAll(path);
++numberOfFiles; ++numberOfFiles;
printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles); printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
} }
} }
} }
catch (...) catch(...)
{ {
closedir(dir); closedir(dir);
throw; throw;
...@@ -474,49 +485,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -474,49 +485,49 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
// ����RemoveAllɾ��Ŀ¼ // ����RemoveAllɾ��Ŀ¼
RemoveAll(directory); RemoveAll(directory);
} }
string GetFileName(const string &path) string GetFileName(const string& path)
{ {
string fileName; string fileName;
int indexOfPathSeparator = -1; int indexOfPathSeparator = -1;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (IsPathSeparator(path[i])) if(IsPathSeparator(path[i]))
{ {
fileName = path.substr(i + 1, path.size() - i - 1); fileName = path.substr(i + 1, path.size() - i - 1);
indexOfPathSeparator = i; indexOfPathSeparator = i;
break; break;
} }
} }
if (indexOfPathSeparator == -1) if(indexOfPathSeparator == -1)
{ {
fileName = path; fileName = path;
} }
return fileName; return fileName;
} }
string GetFileName_NoExtension(const string &path) string GetFileName_NoExtension(const string& path)
{ {
string fileName=GetFileName(path); string fileName = GetFileName(path);
string fileName_NoExtension; string fileName_NoExtension;
for(int i=fileName.size()-1;i>0;--i) for(int i = fileName.size() - 1; i > 0; --i)
{ {
if(fileName[i]=='.') if(fileName[i] == '.')
{ {
fileName_NoExtension=fileName.substr(0,i); fileName_NoExtension = fileName.substr(0, i);
break; break;
} }
} }
return fileName_NoExtension; return fileName_NoExtension;
} }
string GetExtension(const string &path) string GetExtension(const string& path)
{ {
string fileName; string fileName;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (path[i]=='.') if(path[i] == '.')
{ {
fileName = path.substr(i, path.size() - i); fileName = path.substr(i, path.size() - i);
break; break;
...@@ -524,56 +535,55 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -524,56 +535,55 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
} }
return fileName; return fileName;
}
} string GetParentPath(const string& path)
{
string GetParentPath(const string &path)
{
string fileName; string fileName;
for (int i = path.size() - 1; i >= 0; --i) for(int i = path.size() - 1; i >= 0; --i)
{ {
if (IsPathSeparator(path[i])) if(IsPathSeparator(path[i]))
{ {
fileName = path.substr(0, i+1); fileName = path.substr(0, i + 1);
break; break;
} }
} }
return fileName; return fileName;
} }
static bool CreateDirectory(const string &path) static bool CreateDirectory(const string& path)
{ {
#if defined WIN32 || defined _WIN32 || defined WINCE #if defined WIN32 || defined _WIN32 || defined WINCE
#ifdef WINRT #ifdef WINRT
wchar_t wpath[MAX_PATH]; wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH); size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
int result = CreateDirectoryA(wpath, NULL) ? 0 : -1; int result = CreateDirectoryA(wpath, NULL) ? 0 : -1;
#else #else
int result = _mkdir(path.c_str()); int result = _mkdir(path.c_str());
#endif #endif
#elif defined __linux__ || defined __APPLE__ #elif defined __linux__ || defined __APPLE__
int result = mkdir(path.c_str(), 0777); int result = mkdir(path.c_str(), 0777);
#else #else
int result = -1; int result = -1;
#endif #endif
if (result == -1) if(result == -1)
{ {
return IsDirectory(path); return IsDirectory(path);
} }
return true; return true;
} }
bool CreateDirectories(const string &directoryPath) bool CreateDirectories(const string& directoryPath)
{ {
string path = directoryPath; string path = directoryPath;
for (;;) for(;;)
{ {
char last_char = path.empty() ? 0 : path[path.length() - 1]; char last_char = path.empty() ? 0 : path[path.length() - 1];
if (IsPathSeparator(last_char)) if(IsPathSeparator(last_char))
{ {
path = path.substr(0, path.length() - 1); path = path.substr(0, path.length() - 1);
continue; continue;
...@@ -581,35 +591,35 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -581,35 +591,35 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
break; break;
} }
if (path.empty() || path == "./" || path == ".\\" || path == ".") if(path.empty() || path == "./" || path == ".\\" || path == ".")
return true; return true;
if (IsDirectory(path)) if(IsDirectory(path))
return true; return true;
size_t pos = path.rfind('/'); size_t pos = path.rfind('/');
if (pos == string::npos) if(pos == string::npos)
pos = path.rfind('\\'); pos = path.rfind('\\');
if (pos != string::npos) if(pos != string::npos)
{ {
string parent_directory = path.substr(0, pos); string parent_directory = path.substr(0, pos);
if (!parent_directory.empty()) if(!parent_directory.empty())
{ {
if (!CreateDirectories(parent_directory)) if(!CreateDirectories(parent_directory))
return false; return false;
} }
} }
return CreateDirectory(path); return CreateDirectory(path);
} }
bool CopyFile(const string srcPath, const string dstPath) bool CopyFile(const string srcPath, const string dstPath)
{ {
std::ifstream srcFile(srcPath,ios::binary); std::ifstream srcFile(srcPath, ios::binary);
std::ofstream dstFile(dstPath,ios::binary); std::ofstream dstFile(dstPath, ios::binary);
if(!srcFile.is_open()) if(!srcFile.is_open())
{ {
printf("can not open %s\n",srcPath.c_str()); printf("can not open %s\n", srcPath.c_str());
return false; return false;
} }
if(!dstFile.is_open()) if(!dstFile.is_open())
...@@ -617,77 +627,72 @@ static std::vector<std::string> SplitString(std::string str, std::string separat ...@@ -617,77 +627,72 @@ static std::vector<std::string> SplitString(std::string str, std::string separat
printf("can not open %s\n", dstPath.c_str()); printf("can not open %s\n", dstPath.c_str());
return false; return false;
} }
if(srcPath==dstPath) if(srcPath == dstPath)
{ {
printf("src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
} }
char buffer[2048]; char buffer[2048];
unsigned int numberOfBytes=0; unsigned int numberOfBytes = 0;
while(srcFile) while(srcFile)
{ {
srcFile.read(buffer,2048); srcFile.read(buffer, 2048);
dstFile.write(buffer,srcFile.gcount()); dstFile.write(buffer, srcFile.gcount());
numberOfBytes+=srcFile.gcount(); numberOfBytes += srcFile.gcount();
} }
srcFile.close(); srcFile.close();
dstFile.close(); dstFile.close();
return true; return true;
} }
bool CopyDirectories(string srcPath, const string dstPath) bool CopyDirectories(string srcPath, const string dstPath)
{ {
if(srcPath==dstPath) if(srcPath == dstPath)
{ {
printf("src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
} }
// ȥ������·���ָ���
srcPath = srcPath.substr(0, srcPath.size() - 1); srcPath = srcPath.substr(0, srcPath.size() - 1);
vector<string> fileNameList; vector<string> fileNameList;
GetFileNameList2(srcPath, "", fileNameList, true, true); GetFileNameList2(srcPath, "", fileNameList, true, true);
string parentPathOfSrc=GetParentPath(srcPath); string parentPathOfSrc = GetParentPath(srcPath);
int length=parentPathOfSrc.size(); int length = parentPathOfSrc.size();
// create all directories // create all directories
for(int i=0;i<fileNameList.size();++i) for(int i = 0; i < fileNameList.size(); ++i)
{ {
// create directory // create directory
string srcFilePath=fileNameList[i]; string srcFilePath = fileNameList[i];
string subStr=srcFilePath.substr(length,srcFilePath.size()-length); string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath=dstPath+subStr; string dstFilePath = dstPath + subStr;
string parentPathOfDst=GetParentPath(dstFilePath); string parentPathOfDst = GetParentPath(dstFilePath);
CreateDirectories(parentPathOfDst); CreateDirectories(parentPathOfDst);
} }
// copy file // copy file
for(int i=0;i<fileNameList.size();++i) for(int i = 0; i < fileNameList.size(); ++i)
{ {
string srcFilePath=fileNameList[i]; string srcFilePath = fileNameList[i];
if (IsDirectory(srcFilePath)) if(IsDirectory(srcFilePath))
{ {
continue; continue;
} }
string subStr=srcFilePath.substr(length,srcFilePath.size()-length); string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath=dstPath+subStr; string dstFilePath = dstPath + subStr;
// copy file // copy file
CopyFile(srcFilePath,dstFilePath); CopyFile(srcFilePath, dstFilePath);
// process // process
double process = (1.0*(i + 1) / fileNameList.size()) * 100; double process = (1.0 * (i + 1) / fileNameList.size()) * 100;
printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process); printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
} }
printf("all done!(the number of files:%d)\n", fileNameList.size()); printf("all done!(the number of files:%d)\n", fileNameList.size());
return true; return true;
}
} }
} // namespace migraphxSamples
...@@ -6,23 +6,22 @@ ...@@ -6,23 +6,22 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace migraphxSamples namespace migraphxSamples {
{
// 路径是否存在 // 路径是否存在
bool Exists(const std::string &path); bool Exists(const std::string& path);
// 路径是否为目录 // 路径是否为目录
bool IsDirectory(const std::string &path); bool IsDirectory(const std::string& path);
// 是否是路径分隔符(Linux:‘/’,Windows:’\\’) // 是否是路径分隔符(Linux:‘/’,Windows:’\\’)
bool IsPathSeparator(char c); bool IsPathSeparator(char c);
// 路径拼接 // 路径拼接
std::string JoinPath(const std::string &base, const std::string &path); std::string JoinPath(const std::string& base, const std::string& path);
// 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的 // 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的
bool CreateDirectories(const std::string &directoryPath); bool CreateDirectories(const std::string& directoryPath);
/** 生成符合指定模式的文件名列表(支持递归遍历) /** 生成符合指定模式的文件名列表(支持递归遍历)
* *
...@@ -30,31 +29,40 @@ bool CreateDirectories(const std::string &directoryPath); ...@@ -30,31 +29,40 @@ bool CreateDirectories(const std::string &directoryPath);
* addPath:是否包含父路径 * addPath:是否包含父路径
* 注意: * 注意:
1. 多个模式使用","分割,比如"*.jpg,*.png" 1. 多个模式使用","分割,比如"*.jpg,*.png"
2. 支持通配符'*','?' ,比如第一个字符是7的所有文件名:"7*.*", 以512结尾的所有jpg文件名:"*512.jpg" 2. 支持通配符'*','?' ,比如第一个字符是7的所有文件名:"7*.*",
以512结尾的所有jpg文件名:"*512.jpg"
3. 使用"*.jpg",而不是".jpg" 3. 使用"*.jpg",而不是".jpg"
4. 空string表示返回所有结果 4. 空string表示返回所有结果
5. 不能返回子目录名 5. 不能返回子目录名
* *
*/ */
void GetFileNameList(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/") // 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/")
void GetFileNameList2(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList2(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 删除文件或者目录,支持递归删除 // 删除文件或者目录,支持递归删除
void Remove(const std::string &directory, const std::string &extension=""); void Remove(const std::string& directory, const std::string& extension = "");
/** 获取路径的文件名和扩展名 /** 获取路径的文件名和扩展名
* *
* 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/ * 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/
*/ */
std::string GetFileName(const std::string &path); std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string &path); std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string &path); std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string &path); std::string GetParentPath(const std::string& path);
// 拷贝文件 // 拷贝文件
bool CopyFile(const std::string srcPath,const std::string dstPath); bool CopyFile(const std::string srcPath, const std::string dstPath);
/** 拷贝目录 /** 拷贝目录
* *
...@@ -63,8 +71,8 @@ bool CopyFile(const std::string srcPath,const std::string dstPath); ...@@ -63,8 +71,8 @@ bool CopyFile(const std::string srcPath,const std::string dstPath);
1.第一个参数的最后不能加”/” 1.第一个参数的最后不能加”/”
2.不能拷贝隐藏文件 2.不能拷贝隐藏文件
*/ */
bool CopyDirectories(std::string srcPath,const std::string dstPath); bool CopyDirectories(std::string srcPath, const std::string dstPath);
} } // namespace migraphxSamples
#endif #endif
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
#define __SIMPLE_LOG_H__ #define __SIMPLE_LOG_H__
#include <time.h> #include <time.h>
#include <string>
#include <map> #include <map>
#include <thread>
#include <mutex> #include <mutex>
#if (defined WIN32 || defined _WIN32) #include <string>
#include <thread>
#if(defined WIN32 || defined _WIN32)
#include <Windows.h> #include <Windows.h>
#else #else
#include <sys/time.h> #include <sys/time.h>
...@@ -16,20 +18,22 @@ ...@@ -16,20 +18,22 @@
using namespace std; using namespace std;
/** 简易日志 /** 简易日志
* *
* 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。 * 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
* *
* 示例1: * 示例1:
// 初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败) //
初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败)
LogManager::GetInstance()->Initialize("./Log/","log1"); LogManager::GetInstance()->Initialize("./Log/","log1");
LogManager::GetInstance()->Initialize("./Log/","log2"); LogManager::GetInstance()->Initialize("./Log/","log2");
// 写日志 // 写日志
string log = "Hello World"; string log = "Hello World";
LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", log.c_str()); // 写入log1.log LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n",
LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n", log.c_str()); // 写入log2.log log.c_str()); // 写入log1.log
LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n",
log.c_str()); // 写入log2.log
// 关闭日志 // 关闭日志
LogManager::GetInstance()->Close("log1"); LogManager::GetInstance()->Close("log1");
...@@ -50,44 +54,43 @@ using namespace std; ...@@ -50,44 +54,43 @@ using namespace std;
class LogManager class LogManager
{ {
private: private:
LogManager(){} LogManager() {}
public: public:
~LogManager(){} ~LogManager() {}
inline void Initialize(const string &parentPath,const string &logName) inline void Initialize(const string& parentPath, const string& logName)
{ {
// 日志名为空表示输出到控制台 // 日志名为空表示输出到控制台
if(logName.size()==0) if(logName.size() == 0)
return; return;
// 查找该日志文件,如果没有则创建 // 查找该日志文件,如果没有则创建
std::map<string, FILE*>::const_iterator iter = logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if (iter == logMap.end()) if(iter == logMap.end())
{ {
string pathOfLog = parentPath+ logName + ".log"; string pathOfLog = parentPath + logName + ".log";
FILE *logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加 FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加
if(logFile!=NULL) if(logFile != NULL)
{ {
logMap.insert(std::make_pair(logName, logFile)); logMap.insert(std::make_pair(logName, logFile));
} }
} }
} }
inline FILE* GetLogFile(const string &logName) inline FILE* GetLogFile(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return NULL; return NULL;
} }
return (*iter).second; return (*iter).second;
} }
inline void Close(const string &logName) inline void Close(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return; return;
} }
...@@ -95,10 +98,7 @@ public: ...@@ -95,10 +98,7 @@ public:
fclose((*iter).second); fclose((*iter).second);
logMap.erase(iter); logMap.erase(iter);
} }
inline std::mutex &GetLogMutex() inline std::mutex& GetLogMutex() { return logMutex; }
{
return logMutex;
}
// Singleton // Singleton
static LogManager* GetInstance() static LogManager* GetInstance()
...@@ -106,17 +106,18 @@ public: ...@@ -106,17 +106,18 @@ public:
static LogManager logManager; static LogManager logManager;
return &logManager; return &logManager;
} }
private:
private:
std::map<string, FILE*> logMap; std::map<string, FILE*> logMap;
std::mutex logMutex; std::mutex logMutex;
}; };
#ifdef LOG_MUTEX #ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock() #define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock() #define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else #else
#define LOCK #define LOCK
#define UNLOCK #define UNLOCK
#endif #endif
// log time // log time
...@@ -131,53 +132,53 @@ typedef struct _LogTime ...@@ -131,53 +132,53 @@ typedef struct _LogTime
string millisecond; // ms string millisecond; // ms
string microsecond; // us string microsecond; // us
string weekDay; string weekDay;
}LogTime; } LogTime;
inline LogTime GetTime() inline LogTime GetTime()
{ {
LogTime currentTime; LogTime currentTime;
#if (defined WIN32 || defined _WIN32) #if(defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime; SYSTEMTIME systemTime;
GetLocalTime(&systemTime); GetLocalTime(&systemTime);
char temp[8] = { 0 }; char temp[8] = {0};
sprintf(temp, "%04d", systemTime.wYear); sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp, "%02d", systemTime.wMonth); sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp, "%02d", systemTime.wDay); sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp, "%02d", systemTime.wHour); sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp, "%02d", systemTime.wMinute); sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp, "%02d", systemTime.wSecond); sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds); sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek); sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp); currentTime.weekDay = string(temp);
#else #else
struct timeval tv; struct timeval tv;
struct tm *p; struct tm* p;
gettimeofday(&tv, NULL); gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec); p = localtime(&tv.tv_sec);
char temp[8]={0}; char temp[8] = {0};
sprintf(temp,"%04d",1900+p->tm_year); sprintf(temp, "%04d", 1900 + p->tm_year);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp,"%02d",1+p->tm_mon); sprintf(temp, "%02d", 1 + p->tm_mon);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp,"%02d",p->tm_mday); sprintf(temp, "%02d", p->tm_mday);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp,"%02d",p->tm_hour); sprintf(temp, "%02d", p->tm_hour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp,"%02d",p->tm_min); sprintf(temp, "%02d", p->tm_min);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp,"%02d",p->tm_sec); sprintf(temp, "%02d", p->tm_sec);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp,"%03d",(int)(tv.tv_usec/1000)); sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
currentTime.millisecond = string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec % 1000)); sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
currentTime.microsecond = string(temp); currentTime.microsecond = string(temp);
...@@ -188,60 +189,82 @@ inline LogTime GetTime() ...@@ -188,60 +189,82 @@ inline LogTime GetTime()
} }
#define LOG_TIME(logFile) \ #define LOG_TIME(logFile) \
do\ do \
{\ { \
LogTime currentTime=GetTime(); \ LogTime currentTime = GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), "%s-%s-%s %s:%s:%s.%s\t",currentTime.year.c_str(),currentTime.month.c_str(),currentTime.day.c_str(),currentTime.hour.c_str(),currentTime.minute.c_str(),currentTime.second.c_str(),currentTime.millisecond.c_str()); \ fprintf(((logFile == NULL) ? stdout : logFile), \
}while (0) "%s-%s-%s %s:%s:%s.%s\t", \
currentTime.year.c_str(), \
currentTime.month.c_str(), \
#define LOG_INFO(logFile,logInfo, ...) \ currentTime.day.c_str(), \
do\ currentTime.hour.c_str(), \
{\ currentTime.minute.c_str(), \
currentTime.second.c_str(), \
currentTime.millisecond.c_str()); \
} while(0)
#define LOG_INFO(logFile, logInfo, ...) \
do \
{ \
LOCK; \ LOCK; \
LOG_TIME(logFile); \ LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_DEBUG(logFile,logInfo, ...) \ #define LOG_DEBUG(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "DEBUG\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_ERROR(logFile,logInfo, ...) \ #define LOG_ERROR(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "ERROR\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#define LOG_WARN(logFile,logInfo, ...) \ #define LOG_WARN(logFile, logInfo, ...) \
do\ do \
{\ { \
LOCK; \ LOCK; \
LOG_TIME(logFile);\ LOG_TIME(logFile); \
fprintf(((logFile==NULL)?stdout:logFile), "WARN\t"); \ fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \ fprintf(((logFile == NULL) ? stdout : logFile), \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \ "[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \ fflush(logFile); \
UNLOCK; \ UNLOCK; \
} while (0) } while(0)
#endif // __SIMPLE_LOG_H__ #endif // __SIMPLE_LOG_H__
#include <stdexcept>
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
#include "utf8proc.h" #include <stdexcept>
#include "./tokenization.h" #include "./tokenization.h"
#include "utf8proc.h"
namespace cuBERT { namespace cuBERT {
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) { void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
for (int i = 0; i < tokens.size(); ++i) { {
for(int i = 0; i < tokens.size(); ++i)
{
ids[i] = convert_token_to_id(tokens[i]); ids[i] = convert_token_to_id(tokens[i]);
} }
} }
// trim from start (in place).
// FIX: the predicate now takes `unsigned char` — calling std::isspace with
// a negative value (possible for high-bit UTF-8 bytes when char is signed)
// is undefined behavior per the C/C++ standard.
static inline void ltrim(std::string& s)
{
    s.erase(s.begin(),
            std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); }));
}
// trim from end (in place).
// FIX: the predicate now takes `unsigned char` — calling std::isspace with
// a negative value (possible for high-bit UTF-8 bytes when char is signed)
// is undefined behavior per the C/C++ standard.
static inline void rtrim(std::string& s)
{
    s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); })
                .base(),
            s.end());
}
// trim from both ends (in place) // trim from both ends (in place)
static inline void trim(std::string &s) { static inline void trim(std::string& s)
{
ltrim(s); ltrim(s);
rtrim(s); rtrim(s);
} }
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) { void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
std::ifstream file(vocab_file); std::ifstream file(vocab_file);
if (!file) { if(!file)
{
throw std::invalid_argument("Unable to open vocab file"); throw std::invalid_argument("Unable to open vocab file");
} }
unsigned int index = 0; unsigned int index = 0;
std::string line; std::string line;
while (std::getline(file, line)) { while(std::getline(file, line))
{
trim(line); trim(line);
(*vocab)[line] = index; (*vocab)[line] = index;
index++; index++;
} }
file.close(); file.close();
} }
// Returns true when code point `c` is whitespace, given its Unicode category
// string `cat` (e.g. "Zs"). Tab, LF, CR and space count as whitespace no
// matter what category utf8proc assigns them.
inline bool _is_whitespace(int c, const char* cat)
{
    switch(c)
    {
    case ' ':
    case '\t':
    case '\n':
    case '\r': return true;
    default: return cat[0] == 'Z' && cat[1] == 's';
    }
}
// Returns true when code point `c` is a control character, given its Unicode
// category string. Tab, LF and CR are technically controls but are treated
// as whitespace here, matching the reference BERT tokenizer.
inline bool _is_control(int c, const char* cat)
{
    const bool treated_as_whitespace = (c == '\t' || c == '\n' || c == '\r');
    if(treated_as_whitespace)
    {
        return false;
    }
    // Every control-like category ("Cc", "Cf", ...) starts with 'C'.
    return *cat == 'C';
}
// Returns true when code point `cp` is punctuation, given its Unicode
// category string. All non-letter/number ASCII is treated as punctuation
// even when Unicode classifies it otherwise ("^", "$", "`", ...), for
// consistency with the reference BERT tokenizer.
inline bool _is_punctuation(int cp, const char* cat)
{
    const bool ascii_symbol = (cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) ||
                              (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126);
    return ascii_symbol || *cat == 'P';
}
// Single-argument convenience overloads: fetch the Unicode category string
// from utf8proc, then delegate to the two-argument versions above.
bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }

bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }

bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }
bool BasicTokenizer::_is_chinese_char(int cp)
{
    // A "chinese character" is anything inside a CJK Unified Ideographs
    // Unicode block:
    // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    //
    // Note that this is NOT all Japanese and Korean text, despite the block
    // name: Hangul, Hiragana and Katakana live in other blocks and are
    // written as space-separated words, so they need no special handling and
    // are processed like all other languages.
    static const int kCjkRanges[][2] = {
        {0x4E00, 0x9FFF},   // CJK Unified Ideographs
        {0x3400, 0x4DBF},   // Extension A
        {0x20000, 0x2A6DF}, // Extension B
        {0x2A700, 0x2B73F}, // Extension C
        {0x2B740, 0x2B81F}, // Extension D
        {0x2B820, 0x2CEAF}, // Extension E
        {0xF900, 0xFAFF},   // CJK Compatibility Ideographs
        {0x2F800, 0x2FA1F}, // Compatibility Ideographs Supplement
    };
    for(const auto& range : kCjkRanges)
    {
        if(cp >= range[0] && cp <= range[1])
        {
            return true;
        }
    }
    return false;
}
void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { void BasicTokenizer::tokenize(const char* text,
// This was added on November 1st, 2018 for the multilingual and Chinese std::vector<std::string>* output_tokens,
// models. This is also applied to the English models now, but it doesn't size_t max_length)
// matter since the English models were not trained on any Chinese data {
// and generally don't have any Chinese data in them (there are Chinese // This was added on November 1st, 2018 for the multilingual and Chinese
// characters in the vocabulary because Wikipedia does have some Chinese // models. This is also applied to the English models now, but it doesn't
// words in the English Wikipedia.). // matter since the English models were not trained on any Chinese data
if (do_lower_case) { // and generally don't have any Chinese data in them (there are Chinese
text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text); // characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if(do_lower_case)
{
text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
} }
size_t word_bytes = std::strlen(text); size_t word_bytes = std::strlen(text);
...@@ -127,39 +133,56 @@ namespace cuBERT { ...@@ -127,39 +133,56 @@ namespace cuBERT {
int cp; int cp;
char dst[4]; char dst[4];
while (word_bytes > 0) { while(word_bytes > 0)
int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp); {
if (len < 0) { int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
if(len < 0)
{
std::cerr << "UTF-8 decode error: " << text << std::endl; std::cerr << "UTF-8 decode error: " << text << std::endl;
break; break;
} }
if (do_lower_case) { if(do_lower_case)
{
cp = utf8proc_tolower(cp); cp = utf8proc_tolower(cp);
} }
const char *cat = utf8proc_category_string(cp); const char* cat = utf8proc_category_string(cp);
if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) { if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
{
// pass // pass
} else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') { }
else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
{
// pass // pass
} else if (_is_whitespace(cp, cat)) { }
else if(_is_whitespace(cp, cat))
{
new_token = true; new_token = true;
} else { }
else
{
size_t dst_len = len; size_t dst_len = len;
const char *dst_ptr = text + subpos; const char* dst_ptr = text + subpos;
if (do_lower_case) { if(do_lower_case)
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst); {
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
dst_ptr = dst; dst_ptr = dst;
} }
if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) { if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
{
output_tokens->emplace_back(dst_ptr, dst_len); output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true; new_token = true;
} else { }
if (new_token) { else
{
if(new_token)
{
output_tokens->emplace_back(dst_ptr, dst_len); output_tokens->emplace_back(dst_ptr, dst_len);
new_token = false; new_token = false;
} else { }
else
{
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len); output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
} }
} }
...@@ -169,33 +192,38 @@ namespace cuBERT { ...@@ -169,33 +192,38 @@ namespace cuBERT {
subpos = subpos + len; subpos = subpos + len;
// early terminate // early terminate
if (output_tokens->size() >= max_length) { if(output_tokens->size() >= max_length)
{
break; break;
} }
} }
if (do_lower_case) { if(do_lower_case)
free((void *) text); {
free((void*)text);
} }
} }
void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) { void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
if (token.size() > max_input_chars_per_word) { // FIXME: slightly different {
if(token.size() > max_input_chars_per_word)
{ // FIXME: slightly different
output_tokens->push_back(unk_token); output_tokens->push_back(unk_token);
return; return;
} }
size_t output_tokens_len = output_tokens->size(); size_t output_tokens_len = output_tokens->size();
for (size_t start = 0; start < token.size();) { for(size_t start = 0; start < token.size();)
{
bool is_bad = true; bool is_bad = true;
// TODO: can be optimized by prefix-tree // TODO: can be optimized by prefix-tree
for (size_t end = token.size(); start < end; --end) { // FIXME: slightly different for(size_t end = token.size(); start < end; --end)
std::string substr = start > 0 { // FIXME: slightly different
? "##" + token.substr(start, end - start) std::string substr = start > 0 ? "##" + token.substr(start, end - start)
: token.substr(start, end - start); : token.substr(start, end - start);
if (vocab->count(substr)) { if(vocab->count(substr))
{
is_bad = false; is_bad = false;
output_tokens->push_back(substr); output_tokens->push_back(substr);
start = end; start = end;
...@@ -203,27 +231,32 @@ namespace cuBERT { ...@@ -203,27 +231,32 @@ namespace cuBERT {
} }
} }
if (is_bad) { if(is_bad)
{
output_tokens->resize(output_tokens_len); output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token); output_tokens->push_back(unk_token);
return; return;
} }
} }
} }
void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { void FullTokenizer::tokenize(const char* text,
std::vector<std::string>* output_tokens,
size_t max_length)
{
std::vector<std::string> tokens; std::vector<std::string> tokens;
tokens.reserve(max_length); tokens.reserve(max_length);
basic_tokenizer->tokenize(text, &tokens, max_length); basic_tokenizer->tokenize(text, &tokens, max_length);
for (const auto &token : tokens) { for(const auto& token : tokens)
{
wordpiece_tokenizer->tokenize(token, output_tokens); wordpiece_tokenizer->tokenize(token, output_tokens);
// early terminate // early terminate
if (output_tokens->size() >= max_length) { if(output_tokens->size() >= max_length)
{
break; break;
} }
} }
}
} }
} // namespace cuBERT
#ifndef CUBERT_TOKENIZATION_H #ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H #define CUBERT_TOKENIZATION_H
#include <iostream>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <iostream> #include <vector>
namespace cuBERT { namespace cuBERT {
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab); void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
/** /**
* Checks whether `chars` is a whitespace character. * Checks whether `chars` is a whitespace character.
* @param c * @param c
* @return * @return
*/ */
bool _is_whitespace(int c); bool _is_whitespace(int c);
/** /**
* Checks whether `chars` is a control character. * Checks whether `chars` is a control character.
* @param c * @param c
* @return * @return
*/ */
bool _is_control(int c); bool _is_control(int c);
/** /**
* Checks whether `chars` is a punctuation character. * Checks whether `chars` is a punctuation character.
* @param cp * @param cp
* @return * @return
*/ */
bool _is_punctuation(int cp); bool _is_punctuation(int cp);
/** /**
* Runs basic tokenization (punctuation splitting, lower casing, etc.). * Runs basic tokenization (punctuation splitting, lower casing, etc.).
*/ */
class BasicTokenizer { class BasicTokenizer
{
public: public:
/** /**
* Constructs a BasicTokenizer. * Constructs a BasicTokenizer.
...@@ -42,7 +43,7 @@ namespace cuBERT { ...@@ -42,7 +43,7 @@ namespace cuBERT {
*/ */
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {} explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer &other) = delete; BasicTokenizer(const BasicTokenizer& other) = delete;
virtual ~BasicTokenizer() = default; virtual ~BasicTokenizer() = default;
...@@ -51,15 +52,16 @@ namespace cuBERT { ...@@ -51,15 +52,16 @@ namespace cuBERT {
* *
* to_lower * to_lower
* _run_strip_accents Strips accents from a piece of text. * _run_strip_accents Strips accents from a piece of text.
* _clean_text Performs invalid character removal and whitespace cleanup on text. * _clean_text Performs invalid character removal and whitespace cleanup on
* _tokenize_chinese_chars Adds whitespace around any CJK character. * text. _tokenize_chinese_chars Adds whitespace around any CJK character.
* _run_split_on_punc Splits punctuation on a piece of text. * _run_split_on_punc Splits punctuation on a piece of text.
* whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text. * whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece
* of text.
* *
* @param text * @param text
* @param output_tokens * @param output_tokens
*/ */
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
private: private:
const bool do_lower_case; const bool do_lower_case;
...@@ -70,20 +72,22 @@ namespace cuBERT { ...@@ -70,20 +72,22 @@ namespace cuBERT {
* @return * @return
*/ */
inline static bool _is_chinese_char(int cp); inline static bool _is_chinese_char(int cp);
}; };
/** /**
* Runs WordPiece tokenziation. * Runs WordPiece tokenziation.
*/ */
class WordpieceTokenizer { class WordpieceTokenizer
{
public: public:
explicit WordpieceTokenizer( explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
std::unordered_map<std::string, uint64_t> *vocab,
std::string unk_token = "[UNK]", std::string unk_token = "[UNK]",
int max_input_chars_per_word = 200 int max_input_chars_per_word = 200)
) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {} : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
{
}
WordpieceTokenizer(const WordpieceTokenizer &other) = delete; WordpieceTokenizer(const WordpieceTokenizer& other) = delete;
virtual ~WordpieceTokenizer() = default; virtual ~WordpieceTokenizer() = default;
...@@ -97,67 +101,77 @@ namespace cuBERT { ...@@ -97,67 +101,77 @@ namespace cuBERT {
* input = "unaffable" * input = "unaffable"
* output = ["un", "##aff", "##able"] * output = ["un", "##aff", "##able"]
* *
* @param text A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer. * @param text A single token or whitespace separated tokens. This should have
* already been passed through `BasicTokenizer.
* @param output_tokens A list of wordpiece tokens. * @param output_tokens A list of wordpiece tokens.
*/ */
void tokenize(const std::string &text, std::vector<std::string> *output_tokens); void tokenize(const std::string& text, std::vector<std::string>* output_tokens);
private: private:
const std::unordered_map<std::string, uint64_t> *vocab; const std::unordered_map<std::string, uint64_t>* vocab;
const std::string unk_token; const std::string unk_token;
const int max_input_chars_per_word; const int max_input_chars_per_word;
}; };
/** /**
* Runs end-to-end tokenziation. * Runs end-to-end tokenziation.
*/ */
class FullTokenizer { class FullTokenizer
{
public: public:
FullTokenizer(const char *vocab_file, bool do_lower_case = true) { FullTokenizer(const char* vocab_file, bool do_lower_case = true)
{
vocab = new std::unordered_map<std::string, uint64_t>(); vocab = new std::unordered_map<std::string, uint64_t>();
load_vocab(vocab_file, vocab); load_vocab(vocab_file, vocab);
basic_tokenizer = new BasicTokenizer(do_lower_case); basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab); wordpiece_tokenizer = new WordpieceTokenizer(vocab);
} }
~FullTokenizer() { ~FullTokenizer()
if (wordpiece_tokenizer != NULL){ {
if(wordpiece_tokenizer != NULL)
{
wordpiece_tokenizer = NULL; wordpiece_tokenizer = NULL;
} }
delete wordpiece_tokenizer; delete wordpiece_tokenizer;
if (basic_tokenizer != NULL){ if(basic_tokenizer != NULL)
{
basic_tokenizer = NULL; basic_tokenizer = NULL;
} }
delete basic_tokenizer; delete basic_tokenizer;
if (vocab != NULL){ if(vocab != NULL)
{
vocab = NULL; vocab = NULL;
} }
delete vocab; delete vocab;
} }
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
inline uint64_t convert_token_to_id(const std::string &token) { inline uint64_t convert_token_to_id(const std::string& token)
{
auto item = vocab->find(token); auto item = vocab->find(token);
if (item == vocab->end()) { if(item == vocab->end())
{
std::cerr << "vocab missing key: " << token << std::endl; std::cerr << "vocab missing key: " << token << std::endl;
return 0; return 0;
} else { }
else
{
return item->second; return item->second;
} }
} }
void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids); void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);
private: private:
std::unordered_map<std::string, uint64_t> *vocab; std::unordered_map<std::string, uint64_t>* vocab;
BasicTokenizer *basic_tokenizer; BasicTokenizer* basic_tokenizer;
WordpieceTokenizer *wordpiece_tokenizer; WordpieceTokenizer* wordpiece_tokenizer;
}; };
} } // namespace cuBERT
#endif //CUBERT_TOKENIZATION_H #endif // CUBERT_TOKENIZATION_H
/* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */ /* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Kelman, Scott P. Jones, and other contributors. Copyright (c) 2009 Public
* Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -32,7 +33,6 @@ ...@@ -32,7 +33,6 @@
* Please notice the copyright statement in the file "utf8proc_data.c". * Please notice the copyright statement in the file "utf8proc_data.c".
*/ */
/* /*
* File name: utf8proc.c * File name: utf8proc.c
* *
...@@ -40,36 +40,26 @@ ...@@ -40,36 +40,26 @@
* Implementation of libutf8proc. * Implementation of libutf8proc.
*/ */
#include "utf8proc.h" #include "utf8proc.h"
#ifndef SSIZE_MAX #ifndef SSIZE_MAX
#define SSIZE_MAX ((size_t)SIZE_MAX/2) #define SSIZE_MAX ((size_t)SIZE_MAX / 2)
#endif #endif
#ifndef UINT16_MAX #ifndef UINT16_MAX
# define UINT16_MAX 65535U #define UINT16_MAX 65535U
#endif #endif
#include "utf8proc_data.c" #include "utf8proc_data.c"
/* Sequence-length lookup table indexed by the lead byte of a UTF-8
 * sequence. 0 marks bytes that cannot start a sequence: continuation
 * bytes 0x80-0xBF and the invalid lead bytes 0xF8-0xFF. */
UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
    /* 0x00-0x7F: ASCII, single-byte sequences */
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    /* 0x80-0xBF: continuation bytes, invalid as lead byte */
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    /* 0xC0-0xDF: lead bytes of 2-byte sequences */
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
    /* 0xE0-0xEF: lead bytes of 3-byte sequences */
    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
    /* 0xF0-0xF7: lead bytes of 4-byte sequences; 0xF8-0xFF invalid */
    4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0};
#define UTF8PROC_HANGUL_SBASE 0xAC00 #define UTF8PROC_HANGUL_SBASE 0xAC00
#define UTF8PROC_HANGUL_LBASE 0x1100 #define UTF8PROC_HANGUL_LBASE 0x1100
...@@ -96,150 +86,182 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = { ...@@ -96,150 +86,182 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
be different, being based on ABI compatibility.): */ be different, being based on ABI compatibility.): */
#define STRINGIZEx(x) #x #define STRINGIZEx(x) #x
#define STRINGIZE(x) STRINGIZEx(x) #define STRINGIZE(x) STRINGIZEx(x)
UTF8PROC_DLLEXPORT const char *utf8proc_version(void) { UTF8PROC_DLLEXPORT const char* utf8proc_version(void)
{
return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) ""; return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) "";
} }
/* Version of the Unicode character database these property tables were
 * generated from. */
UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void) { return "15.0.0"; }
/* Returns a human-readable message for a utf8proc error code (one of the
 * negative UTF8PROC_ERROR_* values returned by the API). */
UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode)
{
    switch(errcode)
    {
    case UTF8PROC_ERROR_NOMEM:
        return "Memory for processing UTF-8 data could not be allocated.";
    case UTF8PROC_ERROR_OVERFLOW:
        return "UTF-8 string is too long to be processed.";
    case UTF8PROC_ERROR_INVALIDUTF8:
        return "Invalid UTF-8 string";
    case UTF8PROC_ERROR_NOTASSIGNED:
        return "Unassigned Unicode code point found in UTF-8 string.";
    case UTF8PROC_ERROR_INVALIDOPTS:
        return "Invalid options for UTF-8 processing chosen.";
    default:
        return "An unknown error occurred while processing UTF-8 data.";
    }
}
#define utf_cont(ch) (((ch) & 0xc0) == 0x80) #define utf_cont(ch) (((ch) & 0xc0) == 0x80)
/* Decodes the first code point of `str` into *dst and returns the number
 * of bytes consumed (1-4), 0 when strlen == 0, or
 * UTF8PROC_ERROR_INVALIDUTF8 on malformed input (overlong forms,
 * surrogates, truncated sequences). A negative `strlen` means "NUL
 * terminated"; up to 4 bytes may then be examined. On error *dst is -1. */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
                                                     utf8proc_ssize_t strlen,
                                                     utf8proc_int32_t* dst)
{
    utf8proc_int32_t uc;
    const utf8proc_uint8_t* end;
    *dst = -1;
    if(!strlen)
        return 0;
    end = str + ((strlen < 0) ? 4 : strlen);
    uc = *str++;
    /* Fast path: 7-bit ASCII. */
    if(uc < 0x80)
    {
        *dst = uc;
        return 1;
    }
    /* A valid lead byte lies in [0xc2, 0xf4]; the unsigned subtraction
     * rejects both continuation bytes and overlong/invalid leads at once. */
    if((utf8proc_uint32_t)(uc - 0xc2) > (0xf4 - 0xc2))
        return UTF8PROC_ERROR_INVALIDUTF8;
    if(uc < 0xe0)
    { /* 2-byte sequence: needs one valid continuation byte. */
        if(str >= end || !utf_cont(*str))
            return UTF8PROC_ERROR_INVALIDUTF8;
        *dst = ((uc & 0x1f) << 6) | (*str & 0x3f);
        return 2;
    }
    if(uc < 0xf0)
    { /* 3-byte sequence: needs two valid continuation bytes. */
        if((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1]))
            return UTF8PROC_ERROR_INVALIDUTF8;
        /* Reject UTF-16 surrogates (0xD800-0xDFFF, lead 0xed). */
        if(uc == 0xed && *str > 0x9f)
            return UTF8PROC_ERROR_INVALIDUTF8;
        uc = ((uc & 0xf) << 12) | ((*str & 0x3f) << 6) | (str[1] & 0x3f);
        /* Reject overlong encodings of code points below 0x800. */
        if(uc < 0x800)
            return UTF8PROC_ERROR_INVALIDUTF8;
        *dst = uc;
        return 3;
    }
    /* 4-byte sequence: needs three valid continuation bytes. */
    if((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2]))
        return UTF8PROC_ERROR_INVALIDUTF8;
    /* Constrain the result to [0x10000, 0x10ffff]: lead 0xf0 forbids an
     * overlong second byte, lead 0xf4 forbids exceeding 0x10ffff. */
    if(uc == 0xf0)
    {
        if(*str < 0x90)
            return UTF8PROC_ERROR_INVALIDUTF8;
    }
    else if(uc == 0xf4)
    {
        if(*str > 0x8f)
            return UTF8PROC_ERROR_INVALIDUTF8;
    }
    *dst = ((uc & 7) << 18) | ((*str & 0x3f) << 12) | ((str[1] & 0x3f) << 6) | (str[2] & 0x3f);
    return 4;
}
/* A code point is valid when it is below 0x110000 and is not a UTF-16
 * surrogate. The unsigned subtraction folds the surrogate-range test
 * (0xD800-0xDFFF) into a single comparison. */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc)
{
    const utf8proc_uint32_t u = (utf8proc_uint32_t)uc;
    return (u - 0xd800 > 0x07ff) && (u < 0x110000);
}
/* Encodes code point `uc` as UTF-8 into `dst` (which must have room for up
 * to 4 bytes) and returns the number of bytes written, or 0 when `uc` is
 * negative or above 0x10FFFF. Surrogates 0xD800-0xDFFF are deliberately
 * still encodable so the public API behavior stays unchanged, even though
 * they are invalid in UTF-8. Early returns replace the old else-if chain. */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
{
    if(uc < 0x00)
    {
        return 0;
    }
    if(uc < 0x80)
    {
        dst[0] = (utf8proc_uint8_t)uc;
        return 1;
    }
    if(uc < 0x800)
    {
        dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
        dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 2;
    }
    if(uc < 0x10000)
    {
        dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
        dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
        dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 3;
    }
    if(uc < 0x110000)
    {
        dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
        dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
        dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
        dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 4;
    }
    return 0;
}
/* Internal variant of utf8proc_encode_char used when inserting 0xFF marker
 * bytes between graphemes: the sentinel code point -1 encodes to a single
 * 0xFF byte, all other inputs behave exactly like utf8proc_encode_char. */
static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
{
    if(uc < 0x00)
    {
        if(uc == -1)
        {
            /* internal value used for grapheme breaks */
            dst[0] = (utf8proc_uint8_t)0xFF;
            return 1;
        }
        return 0;
    }
    if(uc < 0x80)
    {
        dst[0] = (utf8proc_uint8_t)uc;
        return 1;
    }
    if(uc < 0x800)
    {
        dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
        dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 2;
    }
    if(uc < 0x10000)
    {
        dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
        dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
        dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 3;
    }
    if(uc < 0x110000)
    {
        dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
        dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
        dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
        dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
        return 4;
    }
    return 0;
}
/* Internal "unsafe" property lookup: a two-stage table walk (stage1 picks a
 * 256-entry page, stage2 picks the property index within it).
 * Precondition (NOT checked): 0 <= uc < 0x110000. */
static const utf8proc_property_t* unsafe_get_property(utf8proc_int32_t uc)
{
    return utf8proc_properties +
           (utf8proc_stage2table[utf8proc_stage1table[uc >> 8] + (uc & 0xFF)]);
}
UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t uc)
{
    /* Out-of-range codepoints fall back to the first property entry. */
    if(uc < 0 || uc >= 0x110000)
        return utf8proc_properties;
    return unsafe_get_property(uc);
}
...@@ -250,49 +272,72 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -250,49 +272,72 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
  http://www.unicode.org/reports/tr29/tr29-29.html

  CAVEATS:
  Please note that evaluation of GB10 (grapheme breaks between emoji zwj
  sequences) and GB 12/13 (regional indicator code points) require knowledge of
  previous characters and are thus not handled by this function. This may result
  in an incorrect break before an E_Modifier class codepoint and an incorrectly
  missing break between two REGIONAL_INDICATOR class code points if such support
  does not exist in the caller.

  See the special support in grapheme_break_extended, for required bookkeeping
  by the caller.
*/
static utf8proc_bool grapheme_break_simple(int lbc, int tbc) { static utf8proc_bool grapheme_break_simple(int lbc, int tbc)
return {
(lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1 return (lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1
(lbc == UTF8PROC_BOUNDCLASS_CR && // GB3 (lbc == UTF8PROC_BOUNDCLASS_CR && // GB3
tbc == UTF8PROC_BOUNDCLASS_LF) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LF)
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB4 ? false
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB5 : // ---
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
: // GB4
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
: // GB5
(lbc == UTF8PROC_BOUNDCLASS_L && // GB6 (lbc == UTF8PROC_BOUNDCLASS_L && // GB6
(tbc == UTF8PROC_BOUNDCLASS_L || // --- (tbc == UTF8PROC_BOUNDCLASS_L || // ---
tbc == UTF8PROC_BOUNDCLASS_V || // --- tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_LV || // --- tbc == UTF8PROC_BOUNDCLASS_LV || // ---
tbc == UTF8PROC_BOUNDCLASS_LVT)) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LVT))
? false
: // ---
((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7 ((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7
lbc == UTF8PROC_BOUNDCLASS_V) && // --- lbc == UTF8PROC_BOUNDCLASS_V) && // ---
(tbc == UTF8PROC_BOUNDCLASS_V || // --- (tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_T)) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_T))
? false
: // ---
((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8 ((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8
lbc == UTF8PROC_BOUNDCLASS_T) && // --- lbc == UTF8PROC_BOUNDCLASS_T) && // ---
tbc == UTF8PROC_BOUNDCLASS_T) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_T)
? false
: // ---
(tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9 (tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9
tbc == UTF8PROC_BOUNDCLASS_ZWJ || // --- tbc == UTF8PROC_BOUNDCLASS_ZWJ || // ---
tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a
lbc == UTF8PROC_BOUNDCLASS_PREPEND) ? false : // GB9b lbc == UTF8PROC_BOUNDCLASS_PREPEND)
(lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional handling below) ? false
tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) ? false : // ---- : // GB9b
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13 (requires additional handling below) (lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) ? false : // ---- // handling below)
tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
? false
: // ----
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13
// (requires
// additional
// handling below)
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
? false
: // ----
true; // GB999 true; // GB999
} }
static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t *state) static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t* state)
{ {
if (state) { if(state)
{
int lbc_override; int lbc_override;
if (*state == UTF8PROC_BOUNDCLASS_START) if(*state == UTF8PROC_BOUNDCLASS_START)
*state = lbc_override = lbc; *state = lbc_override = lbc;
else else
lbc_override = *state; lbc_override = *state;
...@@ -303,13 +348,14 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t ...@@ -303,13 +348,14 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t
// second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break // second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break
// after that character according to GB999 (unless of course such a break is // after that character according to GB999 (unless of course such a break is
// forbidden by a different rule such as GB9). // forbidden by a different rule such as GB9).
if (*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) if(*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
*state = UTF8PROC_BOUNDCLASS_OTHER; *state = UTF8PROC_BOUNDCLASS_OTHER;
// Special support for GB11 (emoji extend* zwj / emoji) // Special support for GB11 (emoji extend* zwj / emoji)
else if (*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) { else if(*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
if (tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji {
if(tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji
*state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC; *state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC;
else if (tbc == UTF8PROC_BOUNDCLASS_ZWJ) else if(tbc == UTF8PROC_BOUNDCLASS_ZWJ)
*state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo *state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo
else else
*state = tbc; *state = tbc;
...@@ -323,24 +369,24 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t ...@@ -323,24 +369,24 @@ static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t
return grapheme_break_simple(lbc, tbc); return grapheme_break_simple(lbc, tbc);
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t c1,
                                                                  utf8proc_int32_t c2,
                                                                  utf8proc_int32_t* state)
{
    /* Look up the boundclass of each codepoint, then delegate to the
       state-aware rule evaluator. */
    const int lbc = utf8proc_get_property(c1)->boundclass;
    const int tbc = utf8proc_get_property(c2)->boundclass;
    return grapheme_break_extended(lbc, tbc, state);
}
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t c1, utf8proc_int32_t c2)
{
    /* Stateless convenience wrapper: no cross-call bookkeeping is kept, so
       the state-dependent rules degrade as described in the CAVEATS above. */
    utf8proc_int32_t* no_state = NULL;
    return utf8proc_grapheme_break_stateful(c1, c2, no_state);
}
static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry) static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t** entry)
{ {
utf8proc_int32_t entry_cp = **entry; utf8proc_int32_t entry_cp = **entry;
if ((entry_cp & 0xF800) == 0xD800) { if((entry_cp & 0xF800) == 0xD800)
{
*entry = *entry + 1; *entry = *entry + 1;
entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF); entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF);
entry_cp += 0x10000; entry_cp += 0x10000;
...@@ -350,25 +396,35 @@ static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry) ...@@ -350,25 +396,35 @@ static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry)
static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex) static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex)
{ {
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex]; const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex];
return seqindex_decode_entry(&entry); return seqindex_decode_entry(&entry);
} }
/* Write the decomposition sequence identified by `seqindex` into `dst`
   (at most `bufsize` codepoints), recursing through
   utf8proc_decompose_char for each element.  Returns the total number of
   codepoints produced, or UTF8PROC_ERROR_OVERFLOW on overflow. */
static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex,
                                                       utf8proc_int32_t* dst,
                                                       utf8proc_ssize_t bufsize,
                                                       utf8proc_option_t options,
                                                       int* last_boundclass)
{
    utf8proc_ssize_t written = 0;
    /* Low 14 bits index the sequence table; the top 2 bits hold a short
       length code.  A code of 3 means the real length is stored in the
       first table entry instead. */
    const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex & 0x3FFF];
    int len = seqindex >> 14;
    if(len >= 3)
    {
        len = *entry;
        entry++;
    }
    /* Inclusive countdown: the loop body runs len + 1 times. */
    while(len >= 0)
    {
        const utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry);
        const utf8proc_ssize_t room = (bufsize > written) ? (bufsize - written) : 0;
        written += utf8proc_decompose_char(entry_cp, dst + written, room, options, last_boundclass);
        if(written < 0)
            return UTF8PROC_ERROR_OVERFLOW;
        entry++;
        len--;
    }
    return written;
}
...@@ -393,190 +449,254 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c) ...@@ -393,190 +449,254 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c)
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c)
{
    /* Lowercase: the codepoint has distinct case mappings and no lowercase
       mapping of its own (UINT16_MAX marks an absent mapping elsewhere in
       this file). */
    const utf8proc_property_t* p = utf8proc_get_property(c);
    if(p->lowercase_seqindex == p->uppercase_seqindex)
        return 0;
    return p->lowercase_seqindex == UINT16_MAX;
}
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c)
{
    /* Uppercase: distinct case mappings, no uppercase mapping of its own,
       and not a titlecase letter (category Lt). */
    const utf8proc_property_t* p = utf8proc_get_property(c);
    if(p->lowercase_seqindex == p->uppercase_seqindex)
        return 0;
    if(p->uppercase_seqindex != UINT16_MAX)
        return 0;
    return p->category != UTF8PROC_CATEGORY_LT;
}
/* return a character width analogous to wcwidth (except portable and /* return a character width analogous to wcwidth (except portable and
hopefully less buggy than most system wcwidth functions). */ hopefully less buggy than most system wcwidth functions). */
UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c)
{
return utf8proc_get_property(c)->charwidth; return utf8proc_get_property(c)->charwidth;
} }
/* Return the Unicode general category of codepoint `c`. */
UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c)
{
    const utf8proc_property_t* p = utf8proc_get_property(c);
    return (utf8proc_category_t)p->category;
}
/* Return the two-letter general-category abbreviation for codepoint `c`.
   The table order matches the utf8proc_category_t enumeration. */
UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t c)
{
    static const char s[][3] = {"Cn", "Lu", "Ll", "Lt", "Lm", "Lo", "Mn", "Mc", "Me", "Nd",
                                "Nl", "No", "Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po", "Sm",
                                "Sc", "Sk", "So", "Zs", "Zl", "Zp", "Cc", "Cf", "Cs", "Co"};
    const utf8proc_category_t cat = utf8proc_category(c);
    return s[cat];
}
/* Replace the current codepoint with `replacement_uc` by re-running
   utf8proc_decompose_char on it with the LUMP option cleared (so the
   replacement is not lumped again).  NOTE: this macro expands to a `return`
   statement and thus exits the enclosing function; it relies on the locals
   dst, bufsize, options and last_boundclass being in scope at the call site. */
#define utf8proc_decompose_lump(replacement_uc) \
    return utf8proc_decompose_char( \
        (replacement_uc), dst, bufsize, options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass)
/* Transform a single codepoint `uc` according to `options`, writing the
   result into `dst` (at most `bufsize` codepoints).  Returns the number of
   codepoints the full result requires — which may exceed `bufsize` — or a
   negative UTF8PROC_ERROR_* value.  `last_boundclass` carries grapheme-break
   state across calls when UTF8PROC_CHARBOUND is set.
   The option checks below are order-dependent; do not reorder them. */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc,
                                                            utf8proc_int32_t* dst,
                                                            utf8proc_ssize_t bufsize,
                                                            utf8proc_option_t options,
                                                            int* last_boundclass)
{
    const utf8proc_property_t* property;
    utf8proc_propval_t category;
    utf8proc_int32_t hangul_sindex;
    if(uc < 0 || uc >= 0x110000)
        return UTF8PROC_ERROR_NOTASSIGNED;
    property = unsafe_get_property(uc);
    category = property->category;
    hangul_sindex = uc - UTF8PROC_HANGUL_SBASE;
    /* Hangul syllables decompose algorithmically into 2 or 3 jamo rather
       than via the sequence table. */
    if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
    {
        if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT)
        {
            utf8proc_int32_t hangul_tindex;
            if(bufsize >= 1)
            {
                dst[0] = UTF8PROC_HANGUL_LBASE + hangul_sindex / UTF8PROC_HANGUL_NCOUNT;
                if(bufsize >= 2)
                    dst[1] = UTF8PROC_HANGUL_VBASE +
                             (hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT;
            }
            hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT;
            if(!hangul_tindex) /* no trailing consonant: LV pair only */
                return 2;
            if(bufsize >= 3)
                dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex;
            return 3;
        }
    }
    /* Unassigned codepoints (category 0): reject, or strip, per options. */
    if(options & UTF8PROC_REJECTNA)
    {
        if(!category)
            return UTF8PROC_ERROR_NOTASSIGNED;
    }
    if(options & UTF8PROC_IGNORE)
    {
        if(property->ignorable)
            return 0;
    }
    if(options & UTF8PROC_STRIPNA)
    {
        if(!category)
            return 0;
    }
    /* LUMP: fold visually similar punctuation/spacing codepoints onto ASCII
       stand-ins.  Each utf8proc_decompose_lump() expands to a `return`, so a
       matching condition exits the function immediately. */
    if(options & UTF8PROC_LUMP)
    {
        if(category == UTF8PROC_CATEGORY_ZS)
            utf8proc_decompose_lump(0x0020);
        if(uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8)
            utf8proc_decompose_lump(0x0027);
        if(category == UTF8PROC_CATEGORY_PD || uc == 0x2212)
            utf8proc_decompose_lump(0x002D);
        if(uc == 0x2044 || uc == 0x2215)
            utf8proc_decompose_lump(0x002F);
        if(uc == 0x2236)
            utf8proc_decompose_lump(0x003A);
        if(uc == 0x2039 || uc == 0x2329 || uc == 0x3008)
            utf8proc_decompose_lump(0x003C);
        if(uc == 0x203A || uc == 0x232A || uc == 0x3009)
            utf8proc_decompose_lump(0x003E);
        if(uc == 0x2216)
            utf8proc_decompose_lump(0x005C);
        if(uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303)
            utf8proc_decompose_lump(0x005E);
        if(category == UTF8PROC_CATEGORY_PC || uc == 0x02CD)
            utf8proc_decompose_lump(0x005F);
        if(uc == 0x02CB)
            utf8proc_decompose_lump(0x0060);
        if(uc == 0x2223)
            utf8proc_decompose_lump(0x007C);
        if(uc == 0x223C)
            utf8proc_decompose_lump(0x007E);
        if((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS))
        {
            if(category == UTF8PROC_CATEGORY_ZL || category == UTF8PROC_CATEGORY_ZP)
                utf8proc_decompose_lump(0x000A);
        }
    }
    /* STRIPMARK: drop nonspacing/spacing/enclosing combining marks. */
    if(options & UTF8PROC_STRIPMARK)
    {
        if(category == UTF8PROC_CATEGORY_MN || category == UTF8PROC_CATEGORY_MC ||
           category == UTF8PROC_CATEGORY_ME)
            return 0;
    }
    /* Case folding, when a folding sequence exists (UINT16_MAX = none). */
    if(options & UTF8PROC_CASEFOLD)
    {
        if(property->casefold_seqindex != UINT16_MAX)
        {
            return seqindex_write_char_decomposed(
                property->casefold_seqindex, dst, bufsize, options, last_boundclass);
        }
    }
    /* Canonical decomposition always; compatibility decompositions only when
       UTF8PROC_COMPAT is also set. */
    if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
    {
        if(property->decomp_seqindex != UINT16_MAX &&
           (!property->decomp_type || (options & UTF8PROC_COMPAT)))
        {
            return seqindex_write_char_decomposed(
                property->decomp_seqindex, dst, bufsize, options, last_boundclass);
        }
    }
    /* CHARBOUND: emit a -1 sentinel before each grapheme-cluster start; the
       sentinel becomes a 0xFF byte in charbound_encode_char. */
    if(options & UTF8PROC_CHARBOUND)
    {
        utf8proc_bool boundary;
        int tbc = property->boundclass;
        boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass);
        if(boundary)
        {
            if(bufsize >= 1)
                dst[0] = -1; /* sentinel value for grapheme break */
            if(bufsize >= 2)
                dst[1] = uc;
            return 2;
        }
    }
    /* Default: the codepoint passes through unchanged. */
    if(bufsize >= 1)
        *dst = uc;
    return 1;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
                                                       utf8proc_ssize_t strlen,
                                                       utf8proc_int32_t* buffer,
                                                       utf8proc_ssize_t bufsize,
                                                       utf8proc_option_t options)
{
    /* Plain decomposition is the custom variant with no user mapping
       callback and no callback data. */
    const utf8proc_custom_func no_custom = NULL;
    return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, no_custom, NULL);
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_int32_t* buffer,
utf8proc_custom_func custom_func, void *custom_data utf8proc_ssize_t bufsize,
) { utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data)
{
/* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */ /* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
if ((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE)) if((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS; return UTF8PROC_ERROR_INVALIDOPTS;
if ((options & UTF8PROC_STRIPMARK) && if((options & UTF8PROC_STRIPMARK) && !(options & UTF8PROC_COMPOSE) &&
!(options & UTF8PROC_COMPOSE) && !(options & UTF8PROC_DECOMPOSE)) !(options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS; return UTF8PROC_ERROR_INVALIDOPTS;
{ {
utf8proc_int32_t uc; utf8proc_int32_t uc;
utf8proc_ssize_t rpos = 0; utf8proc_ssize_t rpos = 0;
utf8proc_ssize_t decomp_result; utf8proc_ssize_t decomp_result;
int boundclass = UTF8PROC_BOUNDCLASS_START; int boundclass = UTF8PROC_BOUNDCLASS_START;
while (1) { while(1)
if (options & UTF8PROC_NULLTERM) { {
if(options & UTF8PROC_NULLTERM)
{
rpos += utf8proc_iterate(str + rpos, -1, &uc); rpos += utf8proc_iterate(str + rpos, -1, &uc);
/* checking of return value is not necessary, /* checking of return value is not necessary,
as 'uc' is < 0 in case of error */ as 'uc' is < 0 in case of error */
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8; if(uc < 0)
if (rpos < 0) return UTF8PROC_ERROR_OVERFLOW; return UTF8PROC_ERROR_INVALIDUTF8;
if (uc == 0) break; if(rpos < 0)
} else { return UTF8PROC_ERROR_OVERFLOW;
if (rpos >= strlen) break; if(uc == 0)
break;
}
else
{
if(rpos >= strlen)
break;
rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc); rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc);
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8; if(uc < 0)
return UTF8PROC_ERROR_INVALIDUTF8;
} }
if (custom_func != NULL) { if(custom_func != NULL)
{
uc = custom_func(uc, custom_data); /* user-specified custom mapping */ uc = custom_func(uc, custom_data); /* user-specified custom mapping */
} }
decomp_result = utf8proc_decompose_char( decomp_result = utf8proc_decompose_char(
uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options, uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options, &boundclass);
&boundclass if(decomp_result < 0)
); return decomp_result;
if (decomp_result < 0) return decomp_result;
wpos += decomp_result; wpos += decomp_result;
/* prohibiting integer overflows due to too long strings: */ /* prohibiting integer overflows due to too long strings: */
if (wpos < 0 || if(wpos < 0 || wpos > (utf8proc_ssize_t)(SSIZE_MAX / sizeof(utf8proc_int32_t) / 2))
wpos > (utf8proc_ssize_t)(SSIZE_MAX/sizeof(utf8proc_int32_t)/2))
return UTF8PROC_ERROR_OVERFLOW; return UTF8PROC_ERROR_OVERFLOW;
} }
} }
if ((options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) && bufsize >= wpos) { if((options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE)) && bufsize >= wpos)
{
utf8proc_ssize_t pos = 0; utf8proc_ssize_t pos = 0;
while (pos < wpos-1) { while(pos < wpos - 1)
{
utf8proc_int32_t uc1, uc2; utf8proc_int32_t uc1, uc2;
const utf8proc_property_t *property1, *property2; const utf8proc_property_t *property1, *property2;
uc1 = buffer[pos]; uc1 = buffer[pos];
uc2 = buffer[pos+1]; uc2 = buffer[pos + 1];
property1 = unsafe_get_property(uc1); property1 = unsafe_get_property(uc1);
property2 = unsafe_get_property(uc2); property2 = unsafe_get_property(uc2);
if (property1->combining_class > property2->combining_class && if(property1->combining_class > property2->combining_class &&
property2->combining_class > 0) { property2->combining_class > 0)
{
buffer[pos] = uc2; buffer[pos] = uc2;
buffer[pos+1] = uc1; buffer[pos + 1] = uc1;
if (pos > 0) pos--; else pos++; if(pos > 0)
} else { pos--;
else
pos++;
}
else
{
pos++; pos++;
} }
} }
...@@ -584,59 +704,84 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( ...@@ -584,59 +704,84 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
return wpos; return wpos;
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options)
{
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */ /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */
if (options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC)) { if(options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC))
{
utf8proc_ssize_t rpos; utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
utf8proc_int32_t uc; utf8proc_int32_t uc;
for (rpos = 0; rpos < length; rpos++) { for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
if (uc == 0x000D && rpos < length-1 && buffer[rpos+1] == 0x000A) rpos++; if(uc == 0x000D && rpos < length - 1 && buffer[rpos + 1] == 0x000A)
if (uc == 0x000A || uc == 0x000D || uc == 0x0085 || rpos++;
((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C))) { if(uc == 0x000A || uc == 0x000D || uc == 0x0085 ||
if (options & UTF8PROC_NLF2LS) { ((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C)))
if (options & UTF8PROC_NLF2PS) { {
if(options & UTF8PROC_NLF2LS)
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x000A; buffer[wpos++] = 0x000A;
} else { }
else
{
buffer[wpos++] = 0x2028; buffer[wpos++] = 0x2028;
} }
} else { }
if (options & UTF8PROC_NLF2PS) { else
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x2029; buffer[wpos++] = 0x2029;
} else { }
else
{
buffer[wpos++] = 0x0020; buffer[wpos++] = 0x0020;
} }
} }
} else if ((options & UTF8PROC_STRIPCC) && }
(uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0))) { else if((options & UTF8PROC_STRIPCC) && (uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0)))
if (uc == 0x0009) buffer[wpos++] = 0x0020; {
} else { if(uc == 0x0009)
buffer[wpos++] = 0x0020;
}
else
{
buffer[wpos++] = uc; buffer[wpos++] = uc;
} }
} }
length = wpos; length = wpos;
} }
if (options & UTF8PROC_COMPOSE) { if(options & UTF8PROC_COMPOSE)
utf8proc_int32_t *starter = NULL; {
utf8proc_int32_t* starter = NULL;
utf8proc_int32_t current_char; utf8proc_int32_t current_char;
const utf8proc_property_t *starter_property = NULL, *current_property; const utf8proc_property_t *starter_property = NULL, *current_property;
utf8proc_propval_t max_combining_class = -1; utf8proc_propval_t max_combining_class = -1;
utf8proc_ssize_t rpos; utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
utf8proc_int32_t composition; utf8proc_int32_t composition;
for (rpos = 0; rpos < length; rpos++) { for(rpos = 0; rpos < length; rpos++)
{
current_char = buffer[rpos]; current_char = buffer[rpos];
current_property = unsafe_get_property(current_char); current_property = unsafe_get_property(current_char);
if (starter && current_property->combining_class > max_combining_class) { if(starter && current_property->combining_class > max_combining_class)
{
/* combination perhaps possible */ /* combination perhaps possible */
utf8proc_int32_t hangul_lindex; utf8proc_int32_t hangul_lindex;
utf8proc_int32_t hangul_sindex; utf8proc_int32_t hangul_sindex;
hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE; hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE;
if (hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT) { if(hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT)
{
utf8proc_int32_t hangul_vindex; utf8proc_int32_t hangul_vindex;
hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE; hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE;
if (hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT) { if(hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT)
{
*starter = UTF8PROC_HANGUL_SBASE + *starter = UTF8PROC_HANGUL_SBASE +
(hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) * (hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) *
UTF8PROC_HANGUL_TCOUNT; UTF8PROC_HANGUL_TCOUNT;
...@@ -645,33 +790,42 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -645,33 +790,42 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
} }
} }
hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE; hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE;
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT && if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT &&
(hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0) { (hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0)
{
utf8proc_int32_t hangul_tindex; utf8proc_int32_t hangul_tindex;
hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE; hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE;
if (hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT) { if(hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT)
{
*starter += hangul_tindex; *starter += hangul_tindex;
starter_property = NULL; starter_property = NULL;
continue; continue;
} }
} }
if (!starter_property) { if(!starter_property)
{
starter_property = unsafe_get_property(*starter); starter_property = unsafe_get_property(*starter);
} }
if (starter_property->comb_index < 0x8000 && if(starter_property->comb_index < 0x8000 &&
current_property->comb_index != UINT16_MAX && current_property->comb_index != UINT16_MAX &&
current_property->comb_index >= 0x8000) { current_property->comb_index >= 0x8000)
{
int sidx = starter_property->comb_index; int sidx = starter_property->comb_index;
int idx = current_property->comb_index & 0x3FFF; int idx = current_property->comb_index & 0x3FFF;
if (idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1] ) { if(idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1])
{
idx += sidx + 2 - utf8proc_combinations[sidx]; idx += sidx + 2 - utf8proc_combinations[sidx];
if (current_property->comb_index & 0x4000) { if(current_property->comb_index & 0x4000)
composition = (utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx+1]; {
} else composition =
(utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx + 1];
}
else
composition = utf8proc_combinations[idx]; composition = utf8proc_combinations[idx];
if (composition > 0 && (!(options & UTF8PROC_STABLE) || if(composition > 0 && (!(options & UTF8PROC_STABLE) ||
!(unsafe_get_property(composition)->comp_exclusion))) { !(unsafe_get_property(composition)->comp_exclusion)))
{
*starter = composition; *starter = composition;
starter_property = NULL; starter_property = NULL;
continue; continue;
...@@ -680,11 +834,15 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -680,11 +834,15 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
} }
} }
buffer[wpos] = current_char; buffer[wpos] = current_char;
if (current_property->combining_class) { if(current_property->combining_class)
if (current_property->combining_class > max_combining_class) { {
if(current_property->combining_class > max_combining_class)
{
max_combining_class = current_property->combining_class; max_combining_class = current_property->combining_class;
} }
} else { }
else
{
starter = buffer + wpos; starter = buffer + wpos;
starter_property = NULL; starter_property = NULL;
max_combining_class = -1; max_combining_class = -1;
...@@ -696,97 +854,125 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -696,97 +854,125 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
return length; return length;
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options)
{
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored
ASSERT: 'buffer' has one spare byte of free space at the end! */ ASSERT: 'buffer' has one spare byte of free space at the end! */
length = utf8proc_normalize_utf32(buffer, length, options); length = utf8proc_normalize_utf32(buffer, length, options);
if (length < 0) return length; if(length < 0)
return length;
{ {
utf8proc_ssize_t rpos, wpos = 0; utf8proc_ssize_t rpos, wpos = 0;
utf8proc_int32_t uc; utf8proc_int32_t uc;
if (options & UTF8PROC_CHARBOUND) { if(options & UTF8PROC_CHARBOUND)
for (rpos = 0; rpos < length; rpos++) { {
for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
wpos += charbound_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); wpos += charbound_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
}
} }
} else { else
for (rpos = 0; rpos < length; rpos++) { {
for(rpos = 0; rpos < length; rpos++)
{
uc = buffer[rpos]; uc = buffer[rpos];
wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
} }
} }
((utf8proc_uint8_t *)buffer)[wpos] = 0; ((utf8proc_uint8_t*)buffer)[wpos] = 0;
return wpos; return wpos;
} }
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
) { utf8proc_uint8_t** dstptr,
utf8proc_option_t options)
{
return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL); return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
) { utf8proc_option_t options,
utf8proc_int32_t *buffer; utf8proc_custom_func custom_func,
void* custom_data)
{
utf8proc_int32_t* buffer;
utf8proc_ssize_t result; utf8proc_ssize_t result;
*dstptr = NULL; *dstptr = NULL;
result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data); result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data);
if (result < 0) return result; if(result < 0)
buffer = (utf8proc_int32_t *) malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1); return result;
if (!buffer) return UTF8PROC_ERROR_NOMEM; buffer = (utf8proc_int32_t*)malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1);
result = utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data); if(!buffer)
if (result < 0) { return UTF8PROC_ERROR_NOMEM;
result =
utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data);
if(result < 0)
{
free(buffer); free(buffer);
return result; return result;
} }
result = utf8proc_reencode(buffer, result, options); result = utf8proc_reencode(buffer, result, options);
if (result < 0) { if(result < 0)
{
free(buffer); free(buffer);
return result; return result;
} }
{ {
utf8proc_int32_t *newptr; utf8proc_int32_t* newptr;
newptr = (utf8proc_int32_t *) realloc(buffer, (size_t)result+1); newptr = (utf8proc_int32_t*)realloc(buffer, (size_t)result + 1);
if (newptr) buffer = newptr; if(newptr)
buffer = newptr;
} }
*dstptr = (utf8proc_uint8_t *)buffer; *dstptr = (utf8proc_uint8_t*)buffer;
return result; return result;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT); utf8proc_map(str,
0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT); utf8proc_map(
str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT | UTF8PROC_CASEFOLD | UTF8PROC_IGNORE); utf8proc_map(str,
0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT |
UTF8PROC_CASEFOLD | UTF8PROC_IGNORE);
return retval; return retval;
} }
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Kelman, Scott P. Jones, and other contributors. Copyright (c) 2009 Public
* Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -21,7 +22,6 @@ ...@@ -21,7 +22,6 @@
* DEALINGS IN THE SOFTWARE. * DEALINGS IN THE SOFTWARE.
*/ */
/** /**
* @mainpage * @mainpage
* *
...@@ -37,15 +37,23 @@ ...@@ -37,15 +37,23 @@
* The features of utf8proc include: * The features of utf8proc include:
* *
* - Transformation of strings (@ref utf8proc_map) to: * - Transformation of strings (@ref utf8proc_map) to:
* - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE) Unicode combining characters (http://en.wikipedia.org/wiki/Combining_character) * - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE)
* Unicode combining characters
* (http://en.wikipedia.org/wiki/Combining_character)
* - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT) * - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT)
* - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref UTF8PROC_STRIPMARK) * - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters
* (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref
* UTF8PROC_STRIPMARK)
* - case-folding (@ref UTF8PROC_CASEFOLD) * - case-folding (@ref UTF8PROC_CASEFOLD)
* - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref utf8proc_NFKD, @ref utf8proc_NFKC * - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref UTF8PROC_CHARBOUND) * utf8proc_NFKD, @ref utf8proc_NFKC
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref
* UTF8PROC_CHARBOUND)
* - Character-width computation: @ref utf8proc_charwidth * - Character-width computation: @ref utf8proc_charwidth
* - Classification of characters by Unicode category: @ref utf8proc_category and @ref utf8proc_category_string * - Classification of characters by Unicode category: @ref utf8proc_category
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate) Unicode codepoints to/from UTF-8. * and @ref utf8proc_category_string
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate)
* Unicode codepoints to/from UTF-8.
*/ */
/** @file */ /** @file */
...@@ -68,9 +76,11 @@ ...@@ -68,9 +76,11 @@
* being based on ABI compatibility rather than API compatibility. * being based on ABI compatibility rather than API compatibility.
*/ */
/** @{ */ /** @{ */
/** The MAJOR version number (increased when backwards API compatibility is broken). */ /** The MAJOR version number (increased when backwards API compatibility is
* broken). */
#define UTF8PROC_VERSION_MAJOR 2 #define UTF8PROC_VERSION_MAJOR 2
/** The MINOR version number (increased when new functionality is added in a backwards-compatible manner). */ /** The MINOR version number (increased when new functionality is added in a
* backwards-compatible manner). */
#define UTF8PROC_VERSION_MINOR 8 #define UTF8PROC_VERSION_MINOR 8
/** The PATCH version (increased for fixes that do not change the API). */ /** The PATCH version (increased for fixes that do not change the API). */
#define UTF8PROC_VERSION_PATCH 0 #define UTF8PROC_VERSION_PATCH 0
...@@ -86,28 +96,29 @@ typedef short utf8proc_int16_t; ...@@ -86,28 +96,29 @@ typedef short utf8proc_int16_t;
typedef unsigned short utf8proc_uint16_t; typedef unsigned short utf8proc_uint16_t;
typedef int utf8proc_int32_t; typedef int utf8proc_int32_t;
typedef unsigned int utf8proc_uint32_t; typedef unsigned int utf8proc_uint32_t;
# ifdef _WIN64 #ifdef _WIN64
typedef __int64 utf8proc_ssize_t; typedef __int64 utf8proc_ssize_t;
typedef unsigned __int64 utf8proc_size_t; typedef unsigned __int64 utf8proc_size_t;
# else #else
typedef int utf8proc_ssize_t; typedef int utf8proc_ssize_t;
typedef unsigned int utf8proc_size_t; typedef unsigned int utf8proc_size_t;
# endif #endif
# ifndef __cplusplus #ifndef __cplusplus
// emulate C99 bool // emulate C99 bool
typedef unsigned char utf8proc_bool; typedef unsigned char utf8proc_bool;
# ifndef __bool_true_false_are_defined #ifndef __bool_true_false_are_defined
# define false 0 #define false 0
# define true 1 #define true 1
# define __bool_true_false_are_defined 1 #define __bool_true_false_are_defined 1
# endif #endif
# else #else
typedef bool utf8proc_bool; typedef bool utf8proc_bool;
# endif #endif
#else #else
# include <stddef.h> #include <inttypes.h>
# include <stdbool.h> #include <stdbool.h>
# include <inttypes.h> #include <stddef.h>
typedef int8_t utf8proc_int8_t; typedef int8_t utf8proc_int8_t;
typedef uint8_t utf8proc_uint8_t; typedef uint8_t utf8proc_uint8_t;
typedef int16_t utf8proc_int16_t; typedef int16_t utf8proc_int16_t;
...@@ -121,19 +132,19 @@ typedef bool utf8proc_bool; ...@@ -121,19 +132,19 @@ typedef bool utf8proc_bool;
#include <limits.h> #include <limits.h>
#ifdef UTF8PROC_STATIC #ifdef UTF8PROC_STATIC
# define UTF8PROC_DLLEXPORT #define UTF8PROC_DLLEXPORT
#else
#ifdef _WIN32
#ifdef UTF8PROC_EXPORTS
#define UTF8PROC_DLLEXPORT __declspec(dllexport)
#else #else
# ifdef _WIN32 #define UTF8PROC_DLLEXPORT __declspec(dllimport)
# ifdef UTF8PROC_EXPORTS #endif
# define UTF8PROC_DLLEXPORT __declspec(dllexport) #elif __GNUC__ >= 4
# else #define UTF8PROC_DLLEXPORT __attribute__((visibility("default")))
# define UTF8PROC_DLLEXPORT __declspec(dllimport) #else
# endif #define UTF8PROC_DLLEXPORT
# elif __GNUC__ >= 4 #endif
# define UTF8PROC_DLLEXPORT __attribute__ ((visibility("default")))
# else
# define UTF8PROC_DLLEXPORT
# endif
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
...@@ -143,33 +154,35 @@ extern "C" { ...@@ -143,33 +154,35 @@ extern "C" {
/** /**
* Option flags used by several functions in the library. * Option flags used by several functions in the library.
*/ */
typedef enum { typedef enum
{
/** The given UTF-8 input is NULL terminated. */ /** The given UTF-8 input is NULL terminated. */
UTF8PROC_NULLTERM = (1<<0), UTF8PROC_NULLTERM = (1 << 0),
/** Unicode Versioning Stability has to be respected. */ /** Unicode Versioning Stability has to be respected. */
UTF8PROC_STABLE = (1<<1), UTF8PROC_STABLE = (1 << 1),
/** Compatibility decomposition (i.e. formatting information is lost). */ /** Compatibility decomposition (i.e. formatting information is lost). */
UTF8PROC_COMPAT = (1<<2), UTF8PROC_COMPAT = (1 << 2),
/** Return a result with decomposed characters. */ /** Return a result with decomposed characters. */
UTF8PROC_COMPOSE = (1<<3), UTF8PROC_COMPOSE = (1 << 3),
/** Return a result with decomposed characters. */ /** Return a result with decomposed characters. */
UTF8PROC_DECOMPOSE = (1<<4), UTF8PROC_DECOMPOSE = (1 << 4),
/** Strip "default ignorable characters" such as SOFT-HYPHEN or ZERO-WIDTH-SPACE. */ /** Strip "default ignorable characters" such as SOFT-HYPHEN or
UTF8PROC_IGNORE = (1<<5), ZERO-WIDTH-SPACE. */
UTF8PROC_IGNORE = (1 << 5),
/** Return an error, if the input contains unassigned codepoints. */ /** Return an error, if the input contains unassigned codepoints. */
UTF8PROC_REJECTNA = (1<<6), UTF8PROC_REJECTNA = (1 << 6),
/** /**
* Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a * Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a
* line break, and should be converted to the codepoint for line * line break, and should be converted to the codepoint for line
* separation (LS). * separation (LS).
*/ */
UTF8PROC_NLF2LS = (1<<7), UTF8PROC_NLF2LS = (1 << 7),
/** /**
* Indicating that NLF-sequences are representing a paragraph break, and * Indicating that NLF-sequences are representing a paragraph break, and
* should be converted to the codepoint for paragraph separation * should be converted to the codepoint for paragraph separation
* (PS). * (PS).
*/ */
UTF8PROC_NLF2PS = (1<<8), UTF8PROC_NLF2PS = (1 << 8),
/** Indicating that the meaning of NLF-sequences is unknown. */ /** Indicating that the meaning of NLF-sequences is unknown. */
UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS), UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS),
/** Strips and/or convers control characters. /** Strips and/or convers control characters.
...@@ -179,17 +192,17 @@ typedef enum { ...@@ -179,17 +192,17 @@ typedef enum {
* are treated as a NLF-sequence in this case. All other control * are treated as a NLF-sequence in this case. All other control
* characters are simply removed. * characters are simply removed.
*/ */
UTF8PROC_STRIPCC = (1<<9), UTF8PROC_STRIPCC = (1 << 9),
/** /**
* Performs unicode case folding, to be able to do a case-insensitive * Performs unicode case folding, to be able to do a case-insensitive
* string comparison. * string comparison.
*/ */
UTF8PROC_CASEFOLD = (1<<10), UTF8PROC_CASEFOLD = (1 << 10),
/** /**
* Inserts 0xFF bytes at the beginning of each sequence which is * Inserts 0xFF bytes at the beginning of each sequence which is
* representing a single grapheme cluster (see UAX#29). * representing a single grapheme cluster (see UAX#29).
*/ */
UTF8PROC_CHARBOUND = (1<<11), UTF8PROC_CHARBOUND = (1 << 11),
/** Lumps certain characters together. /** Lumps certain characters together.
* *
* E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details. * E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details.
...@@ -197,18 +210,18 @@ typedef enum { ...@@ -197,18 +210,18 @@ typedef enum {
* If NLF2LF is set, this includes a transformation of paragraph and * If NLF2LF is set, this includes a transformation of paragraph and
* line separators to ASCII line-feed (LF). * line separators to ASCII line-feed (LF).
*/ */
UTF8PROC_LUMP = (1<<12), UTF8PROC_LUMP = (1 << 12),
/** Strips all character markings. /** Strips all character markings.
* *
* This includes non-spacing, spacing and enclosing (i.e. accents). * This includes non-spacing, spacing and enclosing (i.e. accents).
* @note This option works only with @ref UTF8PROC_COMPOSE or * @note This option works only with @ref UTF8PROC_COMPOSE or
* @ref UTF8PROC_DECOMPOSE * @ref UTF8PROC_DECOMPOSE
*/ */
UTF8PROC_STRIPMARK = (1<<13), UTF8PROC_STRIPMARK = (1 << 13),
/** /**
* Strip unassigned codepoints. * Strip unassigned codepoints.
*/ */
UTF8PROC_STRIPNA = (1<<14), UTF8PROC_STRIPNA = (1 << 14),
} utf8proc_option_t; } utf8proc_option_t;
/** @name Error codes /** @name Error codes
...@@ -221,7 +234,8 @@ typedef enum { ...@@ -221,7 +234,8 @@ typedef enum {
#define UTF8PROC_ERROR_OVERFLOW -2 #define UTF8PROC_ERROR_OVERFLOW -2
/** The given string is not a legal UTF-8 string. */ /** The given string is not a legal UTF-8 string. */
#define UTF8PROC_ERROR_INVALIDUTF8 -3 #define UTF8PROC_ERROR_INVALIDUTF8 -3
/** The @ref UTF8PROC_REJECTNA flag was set and an unassigned codepoint was found. */ /** The @ref UTF8PROC_REJECTNA flag was set and an unassigned codepoint was
* found. */
#define UTF8PROC_ERROR_NOTASSIGNED -4 #define UTF8PROC_ERROR_NOTASSIGNED -4
/** Invalid options have been used. */ /** Invalid options have been used. */
#define UTF8PROC_ERROR_INVALIDOPTS -5 #define UTF8PROC_ERROR_INVALIDOPTS -5
...@@ -233,7 +247,8 @@ typedef enum { ...@@ -233,7 +247,8 @@ typedef enum {
typedef utf8proc_int16_t utf8proc_propval_t; typedef utf8proc_int16_t utf8proc_propval_t;
/** Struct containing information about a codepoint. */ /** Struct containing information about a codepoint. */
typedef struct utf8proc_property_struct { typedef struct utf8proc_property_struct
{
/** /**
* Unicode category. * Unicode category.
* @see utf8proc_category_t. * @see utf8proc_category_t.
...@@ -256,28 +271,29 @@ typedef struct utf8proc_property_struct { ...@@ -256,28 +271,29 @@ typedef struct utf8proc_property_struct {
utf8proc_uint16_t lowercase_seqindex; utf8proc_uint16_t lowercase_seqindex;
utf8proc_uint16_t titlecase_seqindex; utf8proc_uint16_t titlecase_seqindex;
utf8proc_uint16_t comb_index; utf8proc_uint16_t comb_index;
unsigned bidi_mirrored:1; unsigned bidi_mirrored : 1;
unsigned comp_exclusion:1; unsigned comp_exclusion : 1;
/** /**
* Can this codepoint be ignored? * Can this codepoint be ignored?
* *
* Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is * Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is
* passed as an option. * passed as an option.
*/ */
unsigned ignorable:1; unsigned ignorable : 1;
unsigned control_boundary:1; unsigned control_boundary : 1;
/** The width of the codepoint. */ /** The width of the codepoint. */
unsigned charwidth:2; unsigned charwidth : 2;
unsigned pad:2; unsigned pad : 2;
/** /**
* Boundclass. * Boundclass.
* @see utf8proc_boundclass_t. * @see utf8proc_boundclass_t.
*/ */
unsigned boundclass:8; unsigned boundclass : 8;
} utf8proc_property_t; } utf8proc_property_t;
/** Unicode categories. */ /** Unicode categories. */
typedef enum { typedef enum
{
UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */ UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */
UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */ UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */
UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */ UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */
...@@ -311,7 +327,8 @@ typedef enum { ...@@ -311,7 +327,8 @@ typedef enum {
} utf8proc_category_t; } utf8proc_category_t;
/** Bidirectional character classes. */ /** Bidirectional character classes. */
typedef enum { typedef enum
{
UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */ UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */
UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */ UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */
UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */ UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */
...@@ -338,7 +355,8 @@ typedef enum { ...@@ -338,7 +355,8 @@ typedef enum {
} utf8proc_bidi_class_t; } utf8proc_bidi_class_t;
/** Decomposition type. */ /** Decomposition type. */
typedef enum { typedef enum
{
UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */ UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */
UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */ UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */
UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */ UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */
...@@ -358,7 +376,8 @@ typedef enum { ...@@ -358,7 +376,8 @@ typedef enum {
} utf8proc_decomp_type_t; } utf8proc_decomp_type_t;
/** Boundclass property. (TR29) */ /** Boundclass property. (TR29) */
typedef enum { typedef enum
{
UTF8PROC_BOUNDCLASS_START = 0, /**< Start */ UTF8PROC_BOUNDCLASS_START = 0, /**< Start */
UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */ UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */
UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */ UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */
...@@ -393,7 +412,7 @@ typedef enum { ...@@ -393,7 +412,7 @@ typedef enum {
* @ref utf8proc_decompose_custom, which is used to specify a user-defined * @ref utf8proc_decompose_custom, which is used to specify a user-defined
* mapping of codepoints to be applied in conjunction with other mappings. * mapping of codepoints to be applied in conjunction with other mappings.
*/ */
typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void *data); typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void* data);
/** /**
* Array containing the byte lengths of a UTF-8 encoded codepoint based * Array containing the byte lengths of a UTF-8 encoded codepoint based
...@@ -406,18 +425,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256]; ...@@ -406,18 +425,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256];
* (http://semver.org format), possibly with a "-dev" suffix for * (http://semver.org format), possibly with a "-dev" suffix for
* development versions. * development versions.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_version(void);
/** /**
* Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH. * Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void);
/** /**
* Returns an informative error string for the given utf8proc error code * Returns an informative error string for the given utf8proc error code
* (e.g. the error codes returned by @ref utf8proc_map). * (e.g. the error codes returned by @ref utf8proc_map).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode);
/** /**
* Reads a single codepoint from the UTF-8 sequence being pointed to by `str`. * Reads a single codepoint from the UTF-8 sequence being pointed to by `str`.
...@@ -429,7 +448,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); ...@@ -429,7 +448,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode);
* In case of success, the number of bytes read is returned; otherwise, a * In case of success, the number of bytes read is returned; otherwise, a
* negative error code is returned. * negative error code is returned.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *codepoint_ref); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
utf8proc_ssize_t strlen,
utf8proc_int32_t* codepoint_ref);
/** /**
* Check if a codepoint is valid (regardless of whether it has been * Check if a codepoint is valid (regardless of whether it has been
...@@ -448,7 +469,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep ...@@ -448,7 +469,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep
* *
* This function does not check whether `codepoint` is valid Unicode. * This function does not check whether `codepoint` is valid Unicode.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint, utf8proc_uint8_t *dst); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint,
utf8proc_uint8_t* dst);
/** /**
* Look up the properties for a given codepoint. * Look up the properties for a given codepoint.
...@@ -462,7 +484,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo ...@@ -462,7 +484,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo
* If the codepoint is unassigned or invalid, a pointer to a special struct is * If the codepoint is unassigned or invalid, a pointer to a special struct is
* returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN). * returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN).
*/ */
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t codepoint);
/** Decompose a codepoint into an array of codepoints. /** Decompose a codepoint into an array of codepoints.
* *
...@@ -492,10 +514,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -492,10 +514,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t codepoint,
utf8proc_int32_t codepoint, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_int32_t* dst,
utf8proc_option_t options, int *last_boundclass utf8proc_ssize_t bufsize,
); utf8proc_option_t options,
int* last_boundclass);
/** /**
* The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8 * The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8
...@@ -514,22 +537,26 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( ...@@ -514,22 +537,26 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options utf8proc_int32_t* buffer,
); utf8proc_ssize_t bufsize,
utf8proc_option_t options);
/** /**
* The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping function * The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping
* that is called on each codepoint in `str` before any other transformations * function that is called on each codepoint in `str` before any other
* (along with a `custom_data` pointer that is passed through to `custom_func`). * transformations (along with a `custom_data` pointer that is passed through to
* The `custom_func` argument is ignored if it is `NULL`. See also @ref utf8proc_map_custom. * `custom_func`). The `custom_func` argument is ignored if it is `NULL`. See
*/ * also @ref utf8proc_map_custom.
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( */
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_int32_t* buffer,
); utf8proc_ssize_t bufsize,
utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** /**
* Normalizes the sequence of `length` codepoints pointed to by `buffer` * Normalizes the sequence of `length` codepoints pointed to by `buffer`
...@@ -541,20 +568,24 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( ...@@ -541,20 +568,24 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS * - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS * - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF * - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters * - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control
* characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite * - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints * codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate * - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
* the unicode versioning stability * the unicode versioning stability
* *
* @return * @return
* In case of success, the length (in codepoints) of the normalized UTF-32 string is * In case of success, the length (in codepoints) of the normalized UTF-32
* returned; otherwise, a negative error code is returned (@ref utf8proc_errmsg). * string is returned; otherwise, a negative error code is returned (@ref
* utf8proc_errmsg).
* *
* @warning The entries of the array pointed to by `str` have to be in the * @warning The entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Reencodes the sequence of `length` codepoints pointed to by `buffer` * Reencodes the sequence of `length` codepoints pointed to by `buffer`
...@@ -567,7 +598,8 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -567,7 +598,8 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS * - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS * - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF * - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters * - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control
* characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite * - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints * codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate * - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
...@@ -584,35 +616,39 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -584,35 +616,39 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
* entries of the array pointed to by `str` have to be in the * entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Given a pair of consecutive codepoints, return whether a grapheme break is * Given a pair of consecutive codepoints, return whether a grapheme break is
* permitted between them (as defined by the extended grapheme clusters in UAX#29). * permitted between them (as defined by the extended grapheme clusters in
* UAX#29).
* *
* @param codepoint1 The first codepoint. * @param codepoint1 The first codepoint.
* @param codepoint2 The second codepoint, occurring consecutively after `codepoint1`. * @param codepoint2 The second codepoint, occurring consecutively after
* @param state Beginning with Version 29 (Unicode 9.0.0), this algorithm requires * `codepoint1`.
* state to break graphemes. This state can be passed in as a pointer * @param state Beginning with Version 29 (Unicode 9.0.0), this algorithm
* requires state to break graphemes. This state can be passed in as a pointer
* in the `state` argument and should initially be set to 0. If the * in the `state` argument and should initially be set to 0. If the
* state is not passed in (i.e. a null pointer is passed), UAX#29 rules * state is not passed in (i.e. a null pointer is passed), UAX#29
* GB10/12/13 which require this state will not be applied, essentially * rules GB10/12/13 which require this state will not be applied, essentially
* matching the rules in Unicode 8.0.0. * matching the rules in Unicode 8.0.0.
* *
* @warning If the state parameter is used, `utf8proc_grapheme_break_stateful` must * @warning If the state parameter is used, `utf8proc_grapheme_break_stateful`
* be called IN ORDER on ALL potential breaks in a string. However, it * must be called IN ORDER on ALL potential breaks in a string. However, it is
* is safe to reset the state to zero after a grapheme break. * safe to reset the state to zero after a grapheme break.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2, utf8proc_int32_t *state); utf8proc_int32_t codepoint2,
utf8proc_int32_t* state);
/** /**
* Same as @ref utf8proc_grapheme_break_stateful, except without support for the * Same as @ref utf8proc_grapheme_break_stateful, except without support for the
* Unicode 9 additions to the algorithm. Supported for legacy reasons. * Unicode 9 additions to the algorithm. Supported for legacy reasons.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2); utf8proc_int32_t codepoint2);
/** /**
* Given a codepoint `c`, return the codepoint of the corresponding * Given a codepoint `c`, return the codepoint of the corresponding
...@@ -636,21 +672,21 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c); ...@@ -636,21 +672,21 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c);
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c); UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c);
/** /**
* Given a codepoint `c`, return `1` if the codepoint corresponds to a lower-case character * Given a codepoint `c`, return `1` if the codepoint corresponds to a
* and `0` otherwise. * lower-case character and `0` otherwise.
*/ */
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c); UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c);
/** /**
* Given a codepoint `c`, return `1` if the codepoint corresponds to an upper-case character * Given a codepoint `c`, return `1` if the codepoint corresponds to an
* and `0` otherwise. * upper-case character and `0` otherwise.
*/ */
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c); UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c);
/** /**
* Given a codepoint, return a character width analogous to `wcwidth(codepoint)`, * Given a codepoint, return a character width analogous to
* except that a width of 0 is returned for non-printable codepoints * `wcwidth(codepoint)`, except that a width of 0 is returned for non-printable
* instead of -1 as in `wcwidth`. * codepoints instead of -1 as in `wcwidth`.
* *
* @note * @note
* If you want to check for particular types of non-printable characters, * If you want to check for particular types of non-printable characters,
...@@ -667,7 +703,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo ...@@ -667,7 +703,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo
* Return the two-letter (nul-terminated) Unicode category string for * Return the two-letter (nul-terminated) Unicode category string for
* the codepoint (e.g. `"Lu"` or `"Co"`). * the codepoint (e.g. `"Lu"` or `"Co"`).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t codepoint);
/** /**
* Maps the given UTF-8 string pointed to by `str` to a new UTF-8 * Maps the given UTF-8 string pointed to by `str` to a new UTF-8
...@@ -688,9 +724,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi ...@@ -688,9 +724,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi
* @note The memory of the new UTF-8 string will have been allocated * @note The memory of the new UTF-8 string will have been allocated
* with `malloc`, and should therefore be deallocated with `free`. * with `malloc`, and should therefore be deallocated with `free`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
); utf8proc_uint8_t** dstptr,
utf8proc_option_t options);
/** /**
* Like @ref utf8proc_map, but also takes a `custom_func` mapping function * Like @ref utf8proc_map, but also takes a `custom_func` mapping function
...@@ -698,10 +735,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( ...@@ -698,10 +735,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(
* (along with a `custom_data` pointer that is passed through to `custom_func`). * (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`. * The `custom_func` argument is ignored if it is `NULL`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
); utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** @name Unicode normalization /** @name Unicode normalization
* *
...@@ -712,18 +751,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( ...@@ -712,18 +751,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(
*/ */
/** @{ */ /** @{ */
/** NFD normalization (@ref UTF8PROC_DECOMPOSE). */ /** NFD normalization (@ref UTF8PROC_DECOMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str);
/** NFC normalization (@ref UTF8PROC_COMPOSE). */ /** NFC normalization (@ref UTF8PROC_COMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str);
/** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str);
/** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str);
/** /**
* NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT * NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT
* and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE). * and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE).
**/ **/
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str);
/** @} */ /** @} */
#ifdef __cplusplus #ifdef __cplusplus
......
#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <Bert.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h> #include <tokenization.h>
int main(int argc, char *argv[]) int main(int argc, char* argv[])
{ {
// 加载Bert模型 // 加载Bert模型
migraphxSamples::Bert bert; migraphxSamples::Bert bert;
migraphxSamples::ErrorCode errorCode = bert.Initialize(); migraphxSamples::ErrorCode errorCode = bert.Initialize();
if (errorCode != migraphxSamples::SUCCESS) if(errorCode != migraphxSamples::SUCCESS)
{ {
LOG_ERROR(stdout, "fail to initialize Bert!\n"); LOG_ERROR(stdout, "fail to initialize Bert!\n");
exit(-1); exit(-1);
...@@ -25,7 +25,18 @@ int main(int argc, char *argv[]) ...@@ -25,7 +25,18 @@ int main(int argc, char *argv[])
int max_answer_length = 30; // 答案的最大长度 int max_answer_length = 30; // 答案的最大长度
// 上下文文本数据 // 上下文文本数据
const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." }; const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
u8"computing that’s also programming-language independent. It brings a "
u8"philosophy of choice, minimalism and modular software development to "
u8"GPU computing. You are free to choose or even develop tools and a "
u8"language run time for your application. ROCm is built for scale, it "
u8"supports multi-GPU computing and has a rich system run time with the "
u8"critical features that large-scale application, compiler and "
u8"language-run-time development requires. Since the ROCm ecosystem is "
u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
u8"platform is continually optimized for performance and extensibility."};
char question[100]; char question[100];
std::vector<std::vector<long unsigned int>> input_ids; std::vector<std::vector<long unsigned int>> input_ids;
...@@ -35,14 +46,22 @@ int main(int argc, char *argv[]) ...@@ -35,14 +46,22 @@ int main(int argc, char *argv[])
std::vector<float> end_position; std::vector<float> end_position;
std::string answer = {}; std::string answer = {};
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具 cuBERT::FullTokenizer tokenizer =
cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
while (true) while(true)
{ {
// 数据前处理 // 数据前处理
std::cout << "question: "; std::cout << "question: ";
cin.getline(question, 100); cin.getline(question, 100);
bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids); bert.Preprocessing(tokenizer,
batch_size,
max_seq_length,
text,
question,
input_ids,
input_masks,
segment_ids);
// 推理 // 推理
bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position); bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment