Commit aec6280a authored by liucong's avatar liucong
Browse files

对C++代码通过格式化

parent ab78f8ec
#include <Bert.h> #include <Bert.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <Filesystem.h> #include <Filesystem.h>
#include <SimpleLog.h> #include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h> #include <tokenization.h>
namespace migraphxSamples #include <algorithm>
{ #include <migraphx/gpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <stdexcept>
Bert::Bert() namespace migraphxSamples {
{
} Bert::Bert() {}
Bert::~Bert() Bert::~Bert() {}
{
}
ErrorCode Bert::Initialize() ErrorCode Bert::Initialize()
{ {
// 获取模型文件 // 获取模型文件
std::string modelPath="../Resource/bertsquad-10.onnx"; std::string modelPath = "../Resource/bertsquad-10.onnx";
// 设置最大输入shape // 设置最大输入shape
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["unique_ids_raw_output___9:0"]={1}; onnx_options.map_input_dims["unique_ids_raw_output___9:0"] = {1};
onnx_options.map_input_dims["input_ids:0"]={1,256}; onnx_options.map_input_dims["input_ids:0"] = {1, 256};
onnx_options.map_input_dims["input_mask:0"]={1,256}; onnx_options.map_input_dims["input_mask:0"] = {1, 256};
onnx_options.map_input_dims["segment_ids:0"]={1,256}; onnx_options.map_input_dims["segment_ids:0"] = {1, 256};
// 加载模型 // 加载模型
if(!Exists(modelPath)) if(!Exists(modelPath))
{ {
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
net = migraphx::parse_onnx(modelPath, onnx_options); net = migraphx::parse_onnx(modelPath, onnx_options);
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息 // 获取模型输入/输出节点信息
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs(); std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs(); std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
inputName1 = "unique_ids_raw_output___9:0"; inputName1 = "unique_ids_raw_output___9:0";
inputShape1 = inputs.at(inputName1); inputShape1 = inputs.at(inputName1);
inputName2 = "segment_ids:0"; inputName2 = "segment_ids:0";
inputShape2 = inputs.at(inputName2); inputShape2 = inputs.at(inputName2);
inputName3 = "input_mask:0"; inputName3 = "input_mask:0";
inputShape3 = inputs.at(inputName3); inputShape3 = inputs.at(inputName3);
inputName4 = "input_ids:0"; inputName4 = "input_ids:0";
inputShape4 = inputs.at(inputName4); inputShape4 = inputs.at(inputName4);
// 设置模型为GPU模式 // 设置模型为GPU模式
...@@ -65,27 +56,27 @@ ErrorCode Bert::Initialize() ...@@ -65,27 +56,27 @@ ErrorCode Bert::Initialize()
// 编译模型 // 编译模型
migraphx::compile_options options; migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备 options.device_id = 0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; options.offload_copy = true;
net.compile(gpuTarget,options); net.compile(gpuTarget, options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());
// warm up // warm up
std::unordered_map<std::string, migraphx::argument> inputData; std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName1]=migraphx::argument(inputShape1); inputData[inputName1] = migraphx::argument(inputShape1);
inputData[inputName2]=migraphx::argument(inputShape2); inputData[inputName2] = migraphx::argument(inputShape2);
inputData[inputName3]=migraphx::argument(inputShape3); inputData[inputName3] = migraphx::argument(inputShape3);
inputData[inputName4]=migraphx::argument(inputShape4); inputData[inputName4] = migraphx::argument(inputShape4);
net.eval(inputData); net.eval(inputData);
return SUCCESS; return SUCCESS;
} }
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>>& input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids, const std::vector<std::vector<long unsigned int>>& segment_ids,
std::vector<float> &start_position, std::vector<float>& start_position,
std::vector<float> &end_position) std::vector<float>& end_position)
{ {
// 保存预处理后的数据 // 保存预处理后的数据
int num = input_ids.size(); int num = input_ids.size();
...@@ -93,13 +84,13 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -93,13 +84,13 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
long unsigned int input_mask[num][256]; long unsigned int input_mask[num][256];
long unsigned int segment_id[num][256]; long unsigned int segment_id[num][256];
long unsigned int position_id[num][1]; long unsigned int position_id[num][1];
for(int i=0;i<input_ids.size();++i) for(int i = 0; i < input_ids.size(); ++i)
{ {
for(int j=0;j<input_ids[0].size();++j) for(int j = 0; j < input_ids[0].size(); ++j)
{ {
input_id[i][j] = input_ids[i][j]; input_id[i][j] = input_ids[i][j];
segment_id[i][j] = segment_ids[i][j]; segment_id[i][j] = segment_ids[i][j];
input_mask[i][j] = input_masks[i][j]; input_mask[i][j] = input_masks[i][j];
position_id[i][0] = 1; position_id[i][0] = 1;
} }
} }
...@@ -111,25 +102,25 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -111,25 +102,25 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
float* start_data; float* start_data;
float* end_data; float* end_data;
for(int i=0;i<input_ids.size();++i) for(int i = 0; i < input_ids.size(); ++i)
{ {
// 创建输入数据 // 创建输入数据
inputData[inputName1]=migraphx::argument{inputShape1, (long unsigned int*)position_id[i]}; inputData[inputName1] = migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
inputData[inputName2]=migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]}; inputData[inputName2] = migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
inputData[inputName3]=migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]}; inputData[inputName3] = migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
inputData[inputName4]=migraphx::argument{inputShape4, (long unsigned int*)input_id[i]}; inputData[inputName4] = migraphx::argument{inputShape4, (long unsigned int*)input_id[i]};
// 推理 // 推理
results = net.eval(inputData); results = net.eval(inputData);
// 获取输出节点的属性 // 获取输出节点的属性
start_prediction = results[1]; // 答案的开始位置 start_prediction = results[1]; // 答案的开始位置
start_data = (float *)start_prediction.data(); // 开始位置的数据指针 start_data = (float*)start_prediction.data(); // 开始位置的数据指针
end_prediction = results[0]; // 答案的结束位置 end_prediction = results[0]; // 答案的结束位置
end_data = (float *)end_prediction.data(); // 结束位置的数据指针 end_data = (float*)end_prediction.data(); // 结束位置的数据指针
// 保存推理结果 // 保存推理结果
for(int i=0;i<256;++i) for(int i = 0; i < 256; ++i)
{ {
start_position.push_back(start_data[i]); start_position.push_back(start_data[i]);
end_position.push_back(end_data[i]); end_position.push_back(end_data[i]);
...@@ -140,39 +131,39 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -140,39 +131,39 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
} }
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size, int batch_size,
int max_seq_length, int max_seq_length,
const char *text, const char* text,
char *question, char* question,
std::vector<std::vector<long unsigned int>> &input_ids, std::vector<std::vector<long unsigned int>>& input_ids,
std::vector<std::vector<long unsigned int>> &input_masks, std::vector<std::vector<long unsigned int>>& input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids) std::vector<std::vector<long unsigned int>>& segment_ids)
{ {
std::vector<long unsigned int> input_id(max_seq_length); std::vector<long unsigned int> input_id(max_seq_length);
std::vector<long unsigned int> input_mask(max_seq_length); std::vector<long unsigned int> input_mask(max_seq_length);
std::vector<long unsigned int> segment_id(max_seq_length); std::vector<long unsigned int> segment_id(max_seq_length);
// 对上下文文本和问题进行分词操作 // 对上下文文本和问题进行分词操作
tokens_text.reserve(max_seq_length); tokens_text.reserve(max_seq_length);
tokens_question.reserve(max_seq_length); tokens_question.reserve(max_seq_length);
tokenizer.tokenize(text, &tokens_text, max_seq_length); tokenizer.tokenize(text, &tokens_text, max_seq_length);
tokenizer.tokenize(question, &tokens_question, max_seq_length); tokenizer.tokenize(question, &tokens_question, max_seq_length);
// 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作 // 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作
if(tokens_text.size() + tokens_question.size() > max_seq_length - 5) if(tokens_text.size() + tokens_question.size() > max_seq_length - 5)
{ {
int windows_len = max_seq_length - 5 -tokens_question.size(); int windows_len = max_seq_length - 5 - tokens_question.size();
std::vector<std::string> tokens_text_window(windows_len); std::vector<std::string> tokens_text_window(windows_len);
std::vector<std::vector<std::string>> tokens_text_windows; std::vector<std::vector<std::string>> tokens_text_windows;
int start_offset = 0; int start_offset = 0;
int position = 0; int position = 0;
int n; int n;
while (start_offset < tokens_text.size()) while(start_offset < tokens_text.size())
{ {
n = 0; n = 0;
if(start_offset+windows_len>tokens_text.size()) if(start_offset + windows_len > tokens_text.size())
{ {
for(int i=start_offset;i<tokens_text.size();++i) for(int i = start_offset; i < tokens_text.size(); ++i)
{ {
tokens_text_window[n] = tokens_text[i]; tokens_text_window[n] = tokens_text[i];
++n; ++n;
...@@ -180,44 +171,46 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -180,44 +171,46 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
} }
else else
{ {
for(int i=start_offset;i<start_offset+windows_len;++i) for(int i = start_offset; i < start_offset + windows_len; ++i)
{ {
tokens_text_window[n] = tokens_text[i]; tokens_text_window[n] = tokens_text[i];
++n; ++n;
} }
} }
tokens_text_windows.push_back(tokens_text_window); tokens_text_windows.push_back(tokens_text_window);
start_offset += 256; start_offset += 256;
++position; ++position;
} }
for(int i=0;i<position;++i) for(int i = 0; i < position; ++i)
{ {
input_id[0] = tokenizer.convert_token_to_id("[CLS]"); input_id[0] = tokenizer.convert_token_to_id("[CLS]");
segment_id[0] = 0; segment_id[0] = 0;
input_id[1] = tokenizer.convert_token_to_id("[CLS]"); input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0; segment_id[1] = 0;
for (int j=0;j<tokens_question.size();++j) for(int j = 0; j < tokens_question.size(); ++j)
{ {
input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]); input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]);
segment_id[j + 2] = 0; segment_id[j + 2] = 0;
} }
input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 2] = 0; segment_id[tokens_question.size() + 2] = 0;
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0; segment_id[tokens_question.size() + 3] = 0;
for (int j=0;j<tokens_question.size();++j) for(int j = 0; j < tokens_question.size(); ++j)
{ {
input_id[j + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id(tokens_text_windows[i][j]); input_id[j + tokens_text_windows[i].size() + 4] =
tokenizer.convert_token_to_id(tokens_text_windows[i][j]);
segment_id[j + tokens_text_windows[i].size() + 4] = 1; segment_id[j + tokens_text_windows[i].size() + 4] = 1;
} }
input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] =
tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1; segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。 // 掩码为1的表示为真实标记,0表示为填充标记。
...@@ -234,31 +227,33 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -234,31 +227,33 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
else else
{ {
// 当上下文文本加问题文本的长度小于等于规定的最大长度,直接拼接处理 // 当上下文文本加问题文本的长度小于等于规定的最大长度,直接拼接处理
input_id[0] = tokenizer.convert_token_to_id("[CLS]"); input_id[0] = tokenizer.convert_token_to_id("[CLS]");
segment_id[0] = 0; segment_id[0] = 0;
input_id[1] = tokenizer.convert_token_to_id("[CLS]"); input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0; segment_id[1] = 0;
for (int i=0;i<tokens_question.size();++i) for(int i = 0; i < tokens_question.size(); ++i)
{ {
input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]); input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]);
segment_id[i + 2] = 0; segment_id[i + 2] = 0;
} }
input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 2] = 0; segment_id[tokens_question.size() + 2] = 0;
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0; segment_id[tokens_question.size() + 3] = 0;
for (int i=0;i<tokens_text.size();++i) for(int i = 0; i < tokens_text.size(); ++i)
{ {
input_id[i + tokens_question.size() + 4] = tokenizer.convert_token_to_id(tokens_text[i]); input_id[i + tokens_question.size() + 4] =
tokenizer.convert_token_to_id(tokens_text[i]);
segment_id[i + tokens_question.size() + 4] = 1; segment_id[i + tokens_question.size() + 4] = 1;
} }
input_id[tokens_question.size() + tokens_text.size() + 4] = tokenizer.convert_token_to_id("[SEP]"); input_id[tokens_question.size() + tokens_text.size() + 4] =
tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text.size() + 4] = 1; segment_id[tokens_question.size() + tokens_text.size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。 // 掩码为1的表示为真实标记,0表示为填充标记。
...@@ -271,48 +266,46 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, ...@@ -271,48 +266,46 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
input_masks.push_back(input_mask); input_masks.push_back(input_mask);
segment_ids.push_back(segment_id); segment_ids.push_back(segment_id);
} }
return SUCCESS; return SUCCESS;
} }
static bool Compare(Sort_st a, Sort_st b) static bool Compare(Sort_st a, Sort_st b) { return a.value > b.value; }
{
return a.value > b.value;
}
static bool CompareM(ResultOfPredictions a, ResultOfPredictions b) static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{ {
return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue; return a.start_predictionvalue + a.end_predictionvalue >
b.start_predictionvalue + b.end_predictionvalue;
} }
ErrorCode Bert::Postprocessing(int n_best_size, ErrorCode Bert::Postprocessing(int n_best_size,
int max_answer_length, int max_answer_length,
const std::vector<float> &start_position, const std::vector<float>& start_position,
const std::vector<float> &end_position, const std::vector<float>& end_position,
std::string &answer) std::string& answer)
{ {
// 取前n_best_size个最大概率值的索引 // 取前n_best_size个最大概率值的索引
std::vector<Sort_st> start_array(start_position.size()); std::vector<Sort_st> start_array(start_position.size());
std::vector<Sort_st> end_array(end_position.size()); std::vector<Sort_st> end_array(end_position.size());
for (int i=0;i<start_position.size();++i) for(int i = 0; i < start_position.size(); ++i)
{ {
start_array[i].index = i; start_array[i].index = i;
start_array[i].value = start_position.at(i); start_array[i].value = start_position.at(i);
end_array[i].index = i; end_array[i].index = i;
end_array[i].value = end_position.at(i); end_array[i].value = end_position.at(i);
} }
std::sort(start_array.begin(), start_array.end(), Compare); std::sort(start_array.begin(), start_array.end(), Compare);
std::sort(end_array.begin(), end_array.end(), Compare); std::sort(end_array.begin(), end_array.end(), Compare);
// 过滤和筛选,筛选掉不符合的索引 // 过滤和筛选,筛选掉不符合的索引
std::vector<ResultOfPredictions> resultsOfPredictions(400); std::vector<ResultOfPredictions> resultsOfPredictions(400);
int num = start_position.size() / 256; int num = start_position.size() / 256;
bool flag; bool flag;
int n=0; int n = 0;
for(int i=0;i<n_best_size;++i) for(int i = 0; i < n_best_size; ++i)
{ {
for(int j=0;j<n_best_size;++j) for(int j = 0; j < n_best_size; ++j)
{ {
flag = false; flag = false;
if(start_array[i].index > start_position.size()) if(start_array[i].index > start_position.size())
...@@ -325,15 +318,17 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -325,15 +318,17 @@ ErrorCode Bert::Postprocessing(int n_best_size,
continue; continue;
} }
for(int t=0;t<num;++t) for(int t = 0; t < num; ++t)
{ {
if(start_array[i].index > t*256 && start_array[i].index < tokens_question.size()+4+t*256) if(start_array[i].index > t * 256 &&
start_array[i].index < tokens_question.size() + 4 + t * 256)
{ {
flag = true; flag = true;
break; break;
} }
if(end_array[j].index > t*256 && end_array[j].index < tokens_question.size()+4+t*256) if(end_array[j].index > t * 256 &&
end_array[j].index < tokens_question.size() + 4 + t * 256)
{ {
flag = true; flag = true;
break; break;
...@@ -349,71 +344,71 @@ ErrorCode Bert::Postprocessing(int n_best_size, ...@@ -349,71 +344,71 @@ ErrorCode Bert::Postprocessing(int n_best_size,
continue; continue;
} }
int length = end_array[j].index - start_array[i].index + 1; int length = end_array[j].index - start_array[i].index + 1;
if(length > max_answer_length) if(length > max_answer_length)
{ {
continue; continue;
} }
resultsOfPredictions[n].start_index = start_array[i].index; resultsOfPredictions[n].start_index = start_array[i].index;
resultsOfPredictions[n].end_index = end_array[j].index; resultsOfPredictions[n].end_index = end_array[j].index;
resultsOfPredictions[n].start_predictionvalue = start_array[i].value; resultsOfPredictions[n].start_predictionvalue = start_array[i].value;
resultsOfPredictions[n].end_predictionvalue = end_array[j].value; resultsOfPredictions[n].end_predictionvalue = end_array[j].value;
++n; ++n;
} }
} }
// 排序,将开始索引加结束索引的概率值和最大的排在前面 // 排序,将开始索引加结束索引的概率值和最大的排在前面
std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM); std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);
int start_index = 0; int start_index = 0;
int end_index = 0; int end_index = 0;
for(int i=0;i<400;++i) for(int i = 0; i < 400; ++i)
{ {
if(resultsOfPredictions[i].start_predictionvalue==0 && resultsOfPredictions[i].end_predictionvalue==0) if(resultsOfPredictions[i].start_predictionvalue == 0 &&
resultsOfPredictions[i].end_predictionvalue == 0)
{ {
continue; continue;
} }
start_index = resultsOfPredictions[i].start_index; start_index = resultsOfPredictions[i].start_index;
end_index = resultsOfPredictions[i].end_index; end_index = resultsOfPredictions[i].end_index;
break; break;
} }
// 映射回上下文文本的索引,(当前的索引值-问题的长度-4) // 映射回上下文文本的索引,(当前的索引值-问题的长度-4)
int answer_start_index = start_index - tokens_question.size()- 4; int answer_start_index = start_index - tokens_question.size() - 4;
int answer_end_index = end_index - tokens_question.size() - 4 + 1; int answer_end_index = end_index - tokens_question.size() - 4 + 1;
// 根据开始索引和结束索引,获取区间内的数据 // 根据开始索引和结束索引,获取区间内的数据
int j=0; int j = 0;
for(int i=answer_start_index;i<answer_end_index;++i) for(int i = answer_start_index; i < answer_end_index; ++i)
{ {
if(tokens_text[i].find('#') != -1) if(tokens_text[i].find('#') != -1)
{ {
j=i-1; j = i - 1;
break; break;
} }
} }
for(int i=answer_start_index;i<answer_end_index;++i) for(int i = answer_start_index; i < answer_end_index; ++i)
{ {
answer += tokens_text[i]; answer += tokens_text[i];
if(tokens_text[i].find('#') != -1 || i==j) if(tokens_text[i].find('#') != -1 || i == j)
{ {
continue; continue;
} }
answer += " "; answer += " ";
} }
int index = 0; int index = 0;
while( (index = answer.find('#',index)) != string::npos) while((index = answer.find('#', index)) != string::npos)
{ {
answer.erase(index,1); answer.erase(index, 1);
} }
tokens_text.clear(); tokens_text.clear();
tokens_question.clear(); tokens_question.clear();
return SUCCESS; return SUCCESS;
} }
} } // namespace migraphxSamples
#ifndef __BERT_H__ #ifndef __BERT_H__
#define __BERT_H__ #define __BERT_H__
#include <tokenization.h>
#include <cstdint> #include <cstdint>
#include <string>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <tokenization.h> #include <string>
namespace migraphxSamples namespace migraphxSamples {
typedef enum _ErrorCode
{ {
typedef enum _ErrorCode SUCCESS = 0,
{ MODEL_NOT_EXIST,
SUCCESS=0, CONFIG_FILE_NOT_EXIST,
MODEL_NOT_EXIST, FAIL_TO_LOAD_MODEL,
CONFIG_FILE_NOT_EXIST, FAIL_TO_OPEN_CONFIG_FILE,
FAIL_TO_LOAD_MODEL, } ErrorCode;
FAIL_TO_OPEN_CONFIG_FILE,
}ErrorCode;
typedef struct _Sort_st typedef struct _Sort_st
{ {
int index; int index;
float value; float value;
}Sort_st; } Sort_st;
typedef struct _ResultOfPredictions typedef struct _ResultOfPredictions
{ {
int start_index; int start_index;
int end_index; int end_index;
float start_predictionvalue; float start_predictionvalue;
float end_predictionvalue; float end_predictionvalue;
}ResultOfPredictions; } ResultOfPredictions;
class Bert class Bert
{ {
public: public:
Bert(); Bert();
~Bert(); ~Bert();
ErrorCode Initialize(); ErrorCode Initialize();
ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>>& input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids, const std::vector<std::vector<long unsigned int>>& segment_ids,
std::vector<float> &start_position, std::vector<float>& start_position,
std::vector<float> &end_position); std::vector<float>& end_position);
ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size, int batch_size,
int max_seq_length, int max_seq_length,
const char *text, const char* text,
char *question, char* question,
std::vector<std::vector<long unsigned int>> &input_ids, std::vector<std::vector<long unsigned int>>& input_ids,
std::vector<std::vector<long unsigned int>> &input_masks, std::vector<std::vector<long unsigned int>>& input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids); std::vector<std::vector<long unsigned int>>& segment_ids);
ErrorCode Postprocessing(int n_best_size, ErrorCode Postprocessing(int n_best_size,
int max_answer_length, int max_answer_length,
const std::vector<float> &start_position, const std::vector<float>& start_position,
const std::vector<float> &end_position, const std::vector<float>& end_position,
std::string &answer); std::string& answer);
private: private:
std::vector<std::string> tokens_text; std::vector<std::string> tokens_text;
std::vector<std::string> tokens_question; std::vector<std::string> tokens_question;
...@@ -74,9 +74,8 @@ private: ...@@ -74,9 +74,8 @@ private:
migraphx::shape inputShape2; migraphx::shape inputShape2;
migraphx::shape inputShape3; migraphx::shape inputShape3;
migraphx::shape inputShape4; migraphx::shape inputShape4;
}; };
} } // namespace migraphxSamples
#endif #endif
\ No newline at end of file
#include <Filesystem.h> #include <Filesystem.h>
#include <algorithm>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/types.h> #include <sys/types.h>
#include <algorithm>
#include <fstream> #include <fstream>
#ifdef _WIN32 #ifdef _WIN32
#include <io.h>
#include <direct.h>
#include <Windows.h> #include <Windows.h>
#include <direct.h>
#include <io.h>
#else #else
#include <unistd.h>
#include <dirent.h> #include <dirent.h>
#include <unistd.h>
#endif #endif
// 路径分隔符(Linux:‘/’,Windows:’\\’) // 路径分隔符(Linux:‘/’,Windows:’\\’)
#ifdef _WIN32 #ifdef _WIN32
#define PATH_SEPARATOR '\\' #define PATH_SEPARATOR '\\'
#else #else
#define PATH_SEPARATOR '/' #define PATH_SEPARATOR '/'
#endif #endif
using namespace std; using namespace std;
namespace migraphxSamples namespace migraphxSamples {
{
static std::vector<std::string> SplitString(std::string str, std::string separator) static std::vector<std::string> SplitString(std::string str, std::string separator)
{ {
std::string::size_type pos; std::string::size_type pos;
std::vector<std::string> result; std::vector<std::string> result;
str+=separator;//扩展字符串以方便操作 str += separator; // 扩展字符串以方便操作
int size=str.size(); int size = str.size();
for(int i=0; i<size; i++) for(int i = 0; i < size; i++)
{ {
pos=str.find(separator,i); pos = str.find(separator, i);
if(pos<size) if(pos < size)
{ {
std::string s=str.substr(i,pos-i); std::string s = str.substr(i, pos - i);
result.push_back(s); result.push_back(s);
i=pos+separator.size()-1; i = pos + separator.size() - 1;
} }
} }
return result; return result;
} }
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
const char dir_separators[] = "/\\"; const char dir_separators[] = "/\\";
struct dirent struct dirent
{ {
const char* d_name; const char* d_name;
}; };
struct DIR struct DIR
{ {
#ifdef WINRT #ifdef WINRT
WIN32_FIND_DATAW data; WIN32_FIND_DATAW data;
#else #else
WIN32_FIND_DATAA data; WIN32_FIND_DATAA data;
#endif #endif
HANDLE handle; HANDLE handle;
dirent ent; dirent ent;
#ifdef WINRT #ifdef WINRT
DIR() { } DIR() {}
~DIR() ~DIR()
{ {
if (ent.d_name) if(ent.d_name)
delete[] ent.d_name; delete[] ent.d_name;
} }
#endif #endif
}; };
DIR* opendir(const char* path) DIR* opendir(const char* path)
{ {
DIR* dir = new DIR; DIR* dir = new DIR;
dir->ent.d_name = 0; dir->ent.d_name = 0;
#ifdef WINRT #ifdef WINRT
string full_path = string(path) + "\\*"; string full_path = string(path) + "\\*";
wchar_t wfull_path[MAX_PATH]; wchar_t wfull_path[MAX_PATH];
size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH); size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
dir->handle = ::FindFirstFileExW(wfull_path, FindExInfoStandard, dir->handle = ::FindFirstFileExW(
&dir->data, FindExSearchNameMatch, NULL, 0); wfull_path, FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0);
#else #else
dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(), dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(),
FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0); FindExInfoStandard,
&dir->data,
FindExSearchNameMatch,
NULL,
0);
#endif #endif
if (dir->handle == INVALID_HANDLE_VALUE) if(dir->handle == INVALID_HANDLE_VALUE)
{ {
/*closedir will do all cleanup*/ /*closedir will do all cleanup*/
delete dir; delete dir;
return 0; return 0;
} }
return dir; return dir;
} }
dirent* readdir(DIR* dir) dirent* readdir(DIR* dir)
{ {
#ifdef WINRT #ifdef WINRT
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileW(dir->handle, &dir->data) != TRUE) if(::FindNextFileW(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
size_t asize = wcstombs(NULL, dir->data.cFileName, 0); size_t asize = wcstombs(NULL, dir->data.cFileName, 0);
CV_Assert((asize != 0) && (asize != (size_t)-1)); CV_Assert((asize != 0) && (asize != (size_t)-1));
char* aname = new char[asize + 1]; char* aname = new char[asize + 1];
aname[asize] = 0; aname[asize] = 0;
wcstombs(aname, dir->data.cFileName, asize); wcstombs(aname, dir->data.cFileName, asize);
dir->ent.d_name = aname; dir->ent.d_name = aname;
#else #else
if (dir->ent.d_name != 0) if(dir->ent.d_name != 0)
{ {
if (::FindNextFileA(dir->handle, &dir->data) != TRUE) if(::FindNextFileA(dir->handle, &dir->data) != TRUE)
return 0; return 0;
} }
dir->ent.d_name = dir->data.cFileName; dir->ent.d_name = dir->data.cFileName;
#endif #endif
return &dir->ent; return &dir->ent;
} }
void closedir(DIR* dir) void closedir(DIR* dir)
{ {
::FindClose(dir->handle); ::FindClose(dir->handle);
delete dir; delete dir;
} }
#else #else
# include <dirent.h> #include <dirent.h>
# include <sys/stat.h> #include <sys/stat.h>
const char dir_separators[] = "/"; const char dir_separators[] = "/";
#endif #endif
static bool isDir(const string &path, DIR* dir) static bool isDir(const string& path, DIR* dir)
{ {
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
DWORD attributes; DWORD attributes;
BOOL status = TRUE; BOOL status = TRUE;
if (dir) if(dir)
attributes = dir->data.dwFileAttributes; attributes = dir->data.dwFileAttributes;
else else
{ {
WIN32_FILE_ATTRIBUTE_DATA all_attrs; WIN32_FILE_ATTRIBUTE_DATA all_attrs;
#ifdef WINRT #ifdef WINRT
wchar_t wpath[MAX_PATH]; wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH); size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs); status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs);
#else #else
status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs); status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs);
#endif #endif
attributes = all_attrs.dwFileAttributes; attributes = all_attrs.dwFileAttributes;
} }
return status && ((attributes & FILE_ATTRIBUTE_DIRECTORY) != 0); return status && ((attributes & FILE_ATTRIBUTE_DIRECTORY) != 0);
#else #else
(void)dir; (void)dir;
struct stat stat_buf; struct stat stat_buf;
if (0 != stat(path.c_str(), &stat_buf)) if(0 != stat(path.c_str(), &stat_buf))
return false; return false;
int is_dir = S_ISDIR(stat_buf.st_mode); int is_dir = S_ISDIR(stat_buf.st_mode);
return is_dir != 0; return is_dir != 0;
#endif #endif
} }
bool IsDirectory(const string &path)
{
return isDir(path, NULL);
}
bool Exists(const string& path) bool IsDirectory(const string& path) { return isDir(path, NULL); }
{
bool Exists(const string& path)
{
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
BOOL status = TRUE; BOOL status = TRUE;
{ {
WIN32_FILE_ATTRIBUTE_DATA all_attrs; WIN32_FILE_ATTRIBUTE_DATA all_attrs;
#ifdef WINRT #ifdef WINRT
wchar_t wpath[MAX_PATH]; wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH); size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1)); CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs); status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs);
#else #else
status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs); status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs);
#endif #endif
} }
return !!status; return !!status;
#else #else
struct stat stat_buf; struct stat stat_buf;
return (0 == stat(path.c_str(), &stat_buf)); return (0 == stat(path.c_str(), &stat_buf));
#endif #endif
} }
bool IsPathSeparator(char c) bool IsPathSeparator(char c) { return c == '/' || c == '\\'; }
{
return c == '/' || c == '\\'; string JoinPath(const string& base, const string& path)
} {
if(base.empty())
string JoinPath(const string& base, const string& path) return path;
{ if(path.empty())
if (base.empty()) return base;
return path;
if (path.empty()) bool baseSep = IsPathSeparator(base[base.size() - 1]);
return base; bool pathSep = IsPathSeparator(path[0]);
string result;
bool baseSep = IsPathSeparator(base[base.size() - 1]); if(baseSep && pathSep)
bool pathSep = IsPathSeparator(path[0]); {
string result; result = base + path.substr(1);
if (baseSep && pathSep) }
{ else if(!baseSep && !pathSep)
result = base + path.substr(1); {
} result = base + PATH_SEPARATOR + path;
else if (!baseSep && !pathSep) }
{ else
result = base + PATH_SEPARATOR + path; {
} result = base + path;
else }
{ return result;
result = base + path; }
}
return result; static bool wildcmp(const char* string, const char* wild)
} {
const char *cp = 0, *mp = 0;
static bool wildcmp(const char *string, const char *wild)
{ while((*string) && (*wild != '*'))
const char *cp = 0, *mp = 0; {
if((*wild != *string) && (*wild != '?'))
while ((*string) && (*wild != '*')) {
{ return false;
if ((*wild != *string) && (*wild != '?')) }
{
return false; wild++;
} string++;
}
wild++;
string++; while(*string)
} {
if(*wild == '*')
while (*string) {
{ if(!*++wild)
if (*wild == '*') {
{ return true;
if (!*++wild) }
{
return true; mp = wild;
} cp = string + 1;
}
mp = wild; else if((*wild == *string) || (*wild == '?'))
cp = string + 1; {
} wild++;
else if ((*wild == *string) || (*wild == '?')) string++;
{ }
wild++; else
string++; {
} wild = mp;
else string = cp++;
{ }
wild = mp; }
string = cp++;
} while(*wild == '*')
} {
wild++;
while (*wild == '*') }
{
wild++; return *wild == 0;
} }
return *wild == 0; static void glob_rec(const string& directory,
} const string& wildchart,
std::vector<string>& result,
static void glob_rec(const string &directory, const string& wildchart, std::vector<string>& result, bool recursive,
bool recursive, bool includeDirectories, const string& pathPrefix) bool includeDirectories,
{ const string& pathPrefix)
DIR *dir; {
DIR* dir;
if ((dir = opendir(directory.c_str())) != 0)
{ if((dir = opendir(directory.c_str())) != 0)
/* find all the files and directories within directory */ {
try /* find all the files and directories within directory */
{ try
struct dirent *ent; {
while ((ent = readdir(dir)) != 0) struct dirent* ent;
{ while((ent = readdir(dir)) != 0)
const char* name = ent->d_name; {
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0)) const char* name = ent->d_name;
continue; if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
(name[0] == '.' && name[1] == '.' && name[2] == 0))
string path = JoinPath(directory, name); continue;
string entry = JoinPath(pathPrefix, name);
string path = JoinPath(directory, name);
if (isDir(path, dir)) string entry = JoinPath(pathPrefix, name);
{
if (recursive) if(isDir(path, dir))
glob_rec(path, wildchart, result, recursive, includeDirectories, entry); {
if (!includeDirectories) if(recursive)
continue; glob_rec(path, wildchart, result, recursive, includeDirectories, entry);
} if(!includeDirectories)
continue;
if (wildchart.empty() || wildcmp(name, wildchart.c_str())) }
result.push_back(entry);
} if(wildchart.empty() || wildcmp(name, wildchart.c_str()))
} result.push_back(entry);
catch (...) }
{ }
closedir(dir); catch(...)
throw; {
} closedir(dir);
closedir(dir); throw;
} }
else closedir(dir);
{ }
printf("could not open directory: %s", directory.c_str()); else
} {
} printf("could not open directory: %s", directory.c_str());
}
void GetFileNameList(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath) }
{
// split pattern void GetFileNameList(const string& directory,
vector<string> patterns=SplitString(pattern,","); const string& pattern,
std::vector<string>& result,
result.clear(); bool recursive,
bool addPath)
for(int i=0;i<patterns.size();++i) {
// split pattern
vector<string> patterns = SplitString(pattern, ",");
result.clear();
for(int i = 0; i < patterns.size(); ++i)
{
string eachPattern = patterns[i];
std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j = 0; j < eachResult.size(); ++j)
{
if(IsDirectory(eachResult[j]))
continue;
if(addPath)
{
result.push_back(eachResult[j]);
}
else
{
result.push_back(GetFileName(eachResult[j]));
}
}
}
std::sort(result.begin(), result.end());
}
void GetFileNameList2(const string& directory,
const string& pattern,
std::vector<string>& result,
bool recursive,
bool addPath)
{
// split pattern
vector<string> patterns = SplitString(pattern, ",");
result.clear();
for(int i = 0; i < patterns.size(); ++i)
{
string eachPattern = patterns[i];
std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j = 0; j < eachResult.size(); ++j)
{ {
string eachPattern=patterns[i]; string filePath = eachResult[j];
std::vector<string> eachResult; if(IsDirectory(filePath))
glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j=0;j<eachResult.size();++j)
{ {
if (IsDirectory(eachResult[j])) filePath = filePath + "/";
continue; for(int k = 0; k < filePath.size(); ++k)
if(addPath)
{ {
result.push_back(eachResult[j]); if(IsPathSeparator(filePath[k]))
{
filePath[k] = '/';
}
} }
else }
if(addPath)
{
result.push_back(filePath);
}
else
{
if(!IsDirectory(filePath))
{ {
result.push_back(GetFileName(eachResult[j])); result.push_back(GetFileName(filePath));
} }
} }
} }
std::sort(result.begin(), result.end()); }
} std::sort(result.begin(), result.end());
}
void GetFileNameList2(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath)
{ void RemoveAll(const string& path)
// split pattern {
vector<string> patterns = SplitString(pattern, ","); if(!Exists(path))
return;
result.clear();
if(IsDirectory(path))
for (int i = 0; i<patterns.size(); ++i) {
{ std::vector<string> entries;
string eachPattern = patterns[i]; GetFileNameList2(path, string(), entries, false, true);
std::vector<string> eachResult; for(size_t i = 0; i < entries.size(); i++)
glob_rec(directory, eachPattern, eachResult, recursive, true, directory); {
for (int j = 0; j<eachResult.size(); ++j) const string& e = entries[i];
{ RemoveAll(e);
string filePath = eachResult[j]; }
if (IsDirectory(filePath))
{
filePath = filePath + "/";
for (int k = 0; k < filePath.size(); ++k)
{
if (IsPathSeparator(filePath[k]))
{
filePath[k] = '/';
}
}
}
if (addPath)
{
result.push_back(filePath);
}
else
{
if (!IsDirectory(filePath))
{
result.push_back(GetFileName(filePath));
}
}
}
}
std::sort(result.begin(), result.end());
}
void RemoveAll(const string& path)
{
if (!Exists(path))
return;
if (IsDirectory(path))
{
std::vector<string> entries;
GetFileNameList2(path, string(), entries, false, true);
for (size_t i = 0; i < entries.size(); i++)
{
const string& e = entries[i];
RemoveAll(e);
}
#ifdef _MSC_VER #ifdef _MSC_VER
bool result = _rmdir(path.c_str()) == 0; bool result = _rmdir(path.c_str()) == 0;
#else #else
bool result = rmdir(path.c_str()) == 0; bool result = rmdir(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{ {
printf("can't remove directory: %s\n", path.c_str()); printf("can't remove directory: %s\n", path.c_str());
} }
} }
else else
{ {
#ifdef _MSC_VER #ifdef _MSC_VER
bool result = _unlink(path.c_str()) == 0; bool result = _unlink(path.c_str()) == 0;
#else #else
bool result = unlink(path.c_str()) == 0; bool result = unlink(path.c_str()) == 0;
#endif #endif
if (!result) if(!result)
{
printf("can't remove file: %s\n", path.c_str());
}
}
}
void Remove(const string &directory, const string &extension)
{
DIR *dir;
static int numberOfFiles = 0;
if ((dir = opendir(directory.c_str())) != 0)
{
/* find all the files and directories within directory */
try
{
struct dirent *ent;
while ((ent = readdir(dir)) != 0)
{
const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0))
continue;
string path = JoinPath(directory, name);
if (isDir(path, dir))
{
Remove(path, extension);
}
// �ж���չ��
if (extension.empty() || wildcmp(name, extension.c_str()))
{
RemoveAll(path);
++numberOfFiles;
printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
}
}
}
catch (...)
{
closedir(dir);
throw;
}
closedir(dir);
}
else
{
printf("could not open directory: %s", directory.c_str());
}
// ����RemoveAllɾ��Ŀ¼
RemoveAll(directory);
}
string GetFileName(const string &path)
{
string fileName;
int indexOfPathSeparator = -1;
for (int i = path.size() - 1; i >= 0; --i)
{ {
if (IsPathSeparator(path[i])) printf("can't remove file: %s\n", path.c_str());
{
fileName = path.substr(i + 1, path.size() - i - 1);
indexOfPathSeparator = i;
break;
}
}
if (indexOfPathSeparator == -1)
{
fileName = path;
} }
}
}
void Remove(const string& directory, const string& extension)
{
DIR* dir;
static int numberOfFiles = 0;
return fileName; if((dir = opendir(directory.c_str())) != 0)
}
string GetFileName_NoExtension(const string &path)
{ {
string fileName=GetFileName(path); /* find all the files and directories within directory */
string fileName_NoExtension; try
for(int i=fileName.size()-1;i>0;--i)
{ {
if(fileName[i]=='.') struct dirent* ent;
while((ent = readdir(dir)) != 0)
{ {
fileName_NoExtension=fileName.substr(0,i); const char* name = ent->d_name;
break; if((name[0] == 0) || (name[0] == '.' && name[1] == 0) ||
} (name[0] == '.' && name[1] == '.' && name[2] == 0))
} continue;
return fileName_NoExtension; string path = JoinPath(directory, name);
}
if(isDir(path, dir))
string GetExtension(const string &path) {
{ Remove(path, extension);
string fileName; }
for (int i = path.size() - 1; i >= 0; --i)
{ // �ж���չ��
if (path[i]=='.') if(extension.empty() || wildcmp(name, extension.c_str()))
{ {
fileName = path.substr(i, path.size() - i); RemoveAll(path);
break; ++numberOfFiles;
} printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
} }
return fileName;
}
string GetParentPath(const string &path)
{
string fileName;
for (int i = path.size() - 1; i >= 0; --i)
{
if (IsPathSeparator(path[i]))
{
fileName = path.substr(0, i+1);
break;
}
}
return fileName;
}
static bool CreateDirectory(const string &path)
{
#if defined WIN32 || defined _WIN32 || defined WINCE
#ifdef WINRT
wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
int result = CreateDirectoryA(wpath, NULL) ? 0 : -1;
#else
int result = _mkdir(path.c_str());
#endif
#elif defined __linux__ || defined __APPLE__
int result = mkdir(path.c_str(), 0777);
#else
int result = -1;
#endif
if (result == -1)
{
return IsDirectory(path);
} }
return true; }
} catch(...)
bool CreateDirectories(const string &directoryPath)
{
string path = directoryPath;
for (;;)
{
char last_char = path.empty() ? 0 : path[path.length() - 1];
if (IsPathSeparator(last_char))
{
path = path.substr(0, path.length() - 1);
continue;
}
break;
}
if (path.empty() || path == "./" || path == ".\\" || path == ".")
return true;
if (IsDirectory(path))
return true;
size_t pos = path.rfind('/');
if (pos == string::npos)
pos = path.rfind('\\');
if (pos != string::npos)
{
string parent_directory = path.substr(0, pos);
if (!parent_directory.empty())
{
if (!CreateDirectories(parent_directory))
return false;
}
}
return CreateDirectory(path);
}
bool CopyFile(const string srcPath, const string dstPath)
{
std::ifstream srcFile(srcPath,ios::binary);
std::ofstream dstFile(dstPath,ios::binary);
if(!srcFile.is_open())
{ {
printf("can not open %s\n",srcPath.c_str()); closedir(dir);
return false; throw;
} }
if(!dstFile.is_open()) closedir(dir);
}
else
{
printf("could not open directory: %s", directory.c_str());
}
// ����RemoveAllɾ��Ŀ¼
RemoveAll(directory);
}
string GetFileName(const string& path)
{
string fileName;
int indexOfPathSeparator = -1;
for(int i = path.size() - 1; i >= 0; --i)
{
if(IsPathSeparator(path[i]))
{ {
printf("can not open %s\n", dstPath.c_str()); fileName = path.substr(i + 1, path.size() - i - 1);
return false; indexOfPathSeparator = i;
break;
} }
if(srcPath==dstPath) }
if(indexOfPathSeparator == -1)
{
fileName = path;
}
return fileName;
}
string GetFileName_NoExtension(const string& path)
{
string fileName = GetFileName(path);
string fileName_NoExtension;
for(int i = fileName.size() - 1; i > 0; --i)
{
if(fileName[i] == '.')
{ {
printf("src can not be same with dst\n"); fileName_NoExtension = fileName.substr(0, i);
return false; break;
} }
char buffer[2048]; }
unsigned int numberOfBytes=0;
while(srcFile) return fileName_NoExtension;
}
string GetExtension(const string& path)
{
string fileName;
for(int i = path.size() - 1; i >= 0; --i)
{
if(path[i] == '.')
{ {
srcFile.read(buffer,2048); fileName = path.substr(i, path.size() - i);
dstFile.write(buffer,srcFile.gcount()); break;
numberOfBytes+=srcFile.gcount();
} }
srcFile.close();
dstFile.close();
return true;
} }
bool CopyDirectories(string srcPath, const string dstPath) return fileName;
}
string GetParentPath(const string& path)
{
string fileName;
for(int i = path.size() - 1; i >= 0; --i)
{ {
if(srcPath==dstPath) if(IsPathSeparator(path[i]))
{ {
printf("src can not be same with dst\n"); fileName = path.substr(0, i + 1);
return false; break;
} }
}
// ȥ������·���ָ��� return fileName;
srcPath = srcPath.substr(0, srcPath.size() - 1); }
vector<string> fileNameList; static bool CreateDirectory(const string& path)
GetFileNameList2(srcPath, "", fileNameList, true, true); {
#if defined WIN32 || defined _WIN32 || defined WINCE
#ifdef WINRT
wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
int result = CreateDirectoryA(wpath, NULL) ? 0 : -1;
#else
int result = _mkdir(path.c_str());
#endif
#elif defined __linux__ || defined __APPLE__
int result = mkdir(path.c_str(), 0777);
#else
int result = -1;
#endif
string parentPathOfSrc=GetParentPath(srcPath); if(result == -1)
int length=parentPathOfSrc.size(); {
return IsDirectory(path);
}
return true;
}
bool CreateDirectories(const string& directoryPath)
{
string path = directoryPath;
// create all directories for(;;)
for(int i=0;i<fileNameList.size();++i) {
char last_char = path.empty() ? 0 : path[path.length() - 1];
if(IsPathSeparator(last_char))
{ {
// create directory path = path.substr(0, path.length() - 1);
string srcFilePath=fileNameList[i]; continue;
string subStr=srcFilePath.substr(length,srcFilePath.size()-length);
string dstFilePath=dstPath+subStr;
string parentPathOfDst=GetParentPath(dstFilePath);
CreateDirectories(parentPathOfDst);
} }
break;
}
// copy file if(path.empty() || path == "./" || path == ".\\" || path == ".")
for(int i=0;i<fileNameList.size();++i) return true;
if(IsDirectory(path))
return true;
size_t pos = path.rfind('/');
if(pos == string::npos)
pos = path.rfind('\\');
if(pos != string::npos)
{
string parent_directory = path.substr(0, pos);
if(!parent_directory.empty())
{ {
string srcFilePath=fileNameList[i]; if(!CreateDirectories(parent_directory))
if (IsDirectory(srcFilePath)) return false;
{
continue;
}
string subStr=srcFilePath.substr(length,srcFilePath.size()-length);
string dstFilePath=dstPath+subStr;
// copy file
CopyFile(srcFilePath,dstFilePath);
// process
double process = (1.0*(i + 1) / fileNameList.size()) * 100;
printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
} }
printf("all done!(the number of files:%d)\n", fileNameList.size()); }
return true; return CreateDirectory(path);
}
bool CopyFile(const string srcPath, const string dstPath)
{
std::ifstream srcFile(srcPath, ios::binary);
std::ofstream dstFile(dstPath, ios::binary);
if(!srcFile.is_open())
{
printf("can not open %s\n", srcPath.c_str());
return false;
} }
if(!dstFile.is_open())
{
printf("can not open %s\n", dstPath.c_str());
return false;
}
if(srcPath == dstPath)
{
printf("src can not be same with dst\n");
return false;
}
char buffer[2048];
unsigned int numberOfBytes = 0;
while(srcFile)
{
srcFile.read(buffer, 2048);
dstFile.write(buffer, srcFile.gcount());
numberOfBytes += srcFile.gcount();
}
srcFile.close();
dstFile.close();
return true;
} }
bool CopyDirectories(string srcPath, const string dstPath)
{
if(srcPath == dstPath)
{
printf("src can not be same with dst\n");
return false;
}
srcPath = srcPath.substr(0, srcPath.size() - 1);
vector<string> fileNameList;
GetFileNameList2(srcPath, "", fileNameList, true, true);
string parentPathOfSrc = GetParentPath(srcPath);
int length = parentPathOfSrc.size();
// create all directories
for(int i = 0; i < fileNameList.size(); ++i)
{
// create directory
string srcFilePath = fileNameList[i];
string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath = dstPath + subStr;
string parentPathOfDst = GetParentPath(dstFilePath);
CreateDirectories(parentPathOfDst);
}
// copy file
for(int i = 0; i < fileNameList.size(); ++i)
{
string srcFilePath = fileNameList[i];
if(IsDirectory(srcFilePath))
{
continue;
}
string subStr = srcFilePath.substr(length, srcFilePath.size() - length);
string dstFilePath = dstPath + subStr;
// copy file
CopyFile(srcFilePath, dstFilePath);
// process
double process = (1.0 * (i + 1) / fileNameList.size()) * 100;
printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
}
printf("all done!(the number of files:%d)\n", fileNameList.size());
return true;
}
} // namespace migraphxSamples
...@@ -5,66 +5,74 @@ ...@@ -5,66 +5,74 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace migraphxSamples namespace migraphxSamples {
{
// 路径是否存在 // 路径是否存在
bool Exists(const std::string &path); bool Exists(const std::string& path);
// 路径是否为目录 // 路径是否为目录
bool IsDirectory(const std::string &path); bool IsDirectory(const std::string& path);
// 是否是路径分隔符(Linux:‘/’,Windows:’\\’) // 是否是路径分隔符(Linux:‘/’,Windows:’\\’)
bool IsPathSeparator(char c); bool IsPathSeparator(char c);
// 路径拼接 // 路径拼接
std::string JoinPath(const std::string &base, const std::string &path); std::string JoinPath(const std::string& base, const std::string& path);
// 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的 // 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的
bool CreateDirectories(const std::string &directoryPath); bool CreateDirectories(const std::string& directoryPath);
/** 生成符合指定模式的文件名列表(支持递归遍历) /** 生成符合指定模式的文件名列表(支持递归遍历)
* *
* pattern: 模式,比如"*.jpg","*.png","*.jpg,*.png" * pattern: 模式,比如"*.jpg","*.png","*.jpg,*.png"
* addPath:是否包含父路径 * addPath:是否包含父路径
* 注意: * 注意:
1. 多个模式使用","分割,比如"*.jpg,*.png" 1. 多个模式使用","分割,比如"*.jpg,*.png"
2. 支持通配符'*','?' ,比如第一个字符是7的所有文件名:"7*.*", 以512结尾的所有jpg文件名:"*512.jpg" 2. 支持通配符'*','?' ,比如第一个字符是7的所有文件名:"7*.*",
以512结尾的所有jpg文件名:"*512.jpg"
3. 使用"*.jpg",而不是".jpg" 3. 使用"*.jpg",而不是".jpg"
4. 空string表示返回所有结果 4. 空string表示返回所有结果
5. 不能返回子目录名 5. 不能返回子目录名
* *
*/ */
void GetFileNameList(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/") // 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/")
void GetFileNameList2(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath); void GetFileNameList2(const std::string& directory,
const std::string& pattern,
std::vector<std::string>& result,
bool recursive,
bool addPath);
// 删除文件或者目录,支持递归删除 // 删除文件或者目录,支持递归删除
void Remove(const std::string &directory, const std::string &extension=""); void Remove(const std::string& directory, const std::string& extension = "");
/** 获取路径的文件名和扩展名 /** 获取路径的文件名和扩展名
* *
* 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/ * 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/
*/ */
std::string GetFileName(const std::string &path); std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string &path); std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string &path); std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string &path); std::string GetParentPath(const std::string& path);
// 拷贝文件 // 拷贝文件
bool CopyFile(const std::string srcPath,const std::string dstPath); bool CopyFile(const std::string srcPath, const std::string dstPath);
/** 拷贝目录 /** 拷贝目录
* *
* 示例:CopyDirectories("D:/0/1/2/","E:/3/");实现把D:/0/1/2/目录拷贝到E:/3/目录中(即拷贝完成后的目录结构为E:/3/2/) * 示例:CopyDirectories("D:/0/1/2/","E:/3/");实现把D:/0/1/2/目录拷贝到E:/3/目录中(即拷贝完成后的目录结构为E:/3/2/)
* 注意: * 注意:
1.第一个参数的最后不能加”/” 1.第一个参数的最后不能加”/”
2.不能拷贝隐藏文件 2.不能拷贝隐藏文件
*/ */
bool CopyDirectories(std::string srcPath,const std::string dstPath); bool CopyDirectories(std::string srcPath, const std::string dstPath);
} } // namespace migraphxSamples
#endif #endif
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
#define __SIMPLE_LOG_H__ #define __SIMPLE_LOG_H__
#include <time.h> #include <time.h>
#include <string>
#include <map> #include <map>
#include <thread>
#include <mutex> #include <mutex>
#if (defined WIN32 || defined _WIN32) #include <string>
#include <thread>
#if(defined WIN32 || defined _WIN32)
#include <Windows.h> #include <Windows.h>
#else #else
#include <sys/time.h> #include <sys/time.h>
...@@ -16,29 +18,31 @@ ...@@ -16,29 +18,31 @@
using namespace std; using namespace std;
/** 简易日志 /** 简易日志
* *
* 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。 * 不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
* *
* 示例1: * 示例1:
// 初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败) //
初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败)
LogManager::GetInstance()->Initialize("./Log/","log1"); LogManager::GetInstance()->Initialize("./Log/","log1");
LogManager::GetInstance()->Initialize("./Log/","log2"); LogManager::GetInstance()->Initialize("./Log/","log2");
// 写日志 // 写日志
string log = "Hello World"; string log = "Hello World";
LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", log.c_str()); // 写入log1.log LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n",
LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n", log.c_str()); // 写入log2.log log.c_str()); // 写入log1.log
LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n",
log.c_str()); // 写入log2.log
// 关闭日志 // 关闭日志
LogManager::GetInstance()->Close("log1"); LogManager::GetInstance()->Close("log1");
LogManager::GetInstance()->Close("log2"); LogManager::GetInstance()->Close("log2");
* 示例2: * 示例2:
// 将日志输出到控制台 // 将日志输出到控制台
string log = "Hello World"; string log = "Hello World";
LOG_INFO(stdout, "%s\n", log.c_str()); LOG_INFO(stdout, "%s\n", log.c_str());
* 注意: * 注意:
1. 需要C++11 1. 需要C++11
...@@ -50,44 +54,43 @@ using namespace std; ...@@ -50,44 +54,43 @@ using namespace std;
class LogManager class LogManager
{ {
private: private:
LogManager(){} LogManager() {}
public: public:
~LogManager(){} ~LogManager() {}
inline void Initialize(const string &parentPath,const string &logName) inline void Initialize(const string& parentPath, const string& logName)
{ {
// 日志名为空表示输出到控制台 // 日志名为空表示输出到控制台
if(logName.size()==0) if(logName.size() == 0)
return; return;
// 查找该日志文件,如果没有则创建 // 查找该日志文件,如果没有则创建
std::map<string, FILE*>::const_iterator iter = logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if (iter == logMap.end()) if(iter == logMap.end())
{ {
string pathOfLog = parentPath+ logName + ".log"; string pathOfLog = parentPath + logName + ".log";
FILE *logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加 FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加
if(logFile!=NULL) if(logFile != NULL)
{ {
logMap.insert(std::make_pair(logName, logFile)); logMap.insert(std::make_pair(logName, logFile));
} }
} }
} }
inline FILE* GetLogFile(const string &logName) inline FILE* GetLogFile(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return NULL; return NULL;
} }
return (*iter).second; return (*iter).second;
} }
inline void Close(const string &logName) inline void Close(const string& logName)
{ {
std::map<string, FILE*>::const_iterator iter=logMap.find(logName); std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if(iter==logMap.end()) if(iter == logMap.end())
{ {
return; return;
} }
...@@ -95,10 +98,7 @@ public: ...@@ -95,10 +98,7 @@ public:
fclose((*iter).second); fclose((*iter).second);
logMap.erase(iter); logMap.erase(iter);
} }
inline std::mutex &GetLogMutex() inline std::mutex& GetLogMutex() { return logMutex; }
{
return logMutex;
}
// Singleton // Singleton
static LogManager* GetInstance() static LogManager* GetInstance()
...@@ -106,21 +106,22 @@ public: ...@@ -106,21 +106,22 @@ public:
static LogManager logManager; static LogManager logManager;
return &logManager; return &logManager;
} }
private:
private:
std::map<string, FILE*> logMap; std::map<string, FILE*> logMap;
std::mutex logMutex; std::mutex logMutex;
}; };
#ifdef LOG_MUTEX #ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock() #define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock() #define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else #else
#define LOCK #define LOCK
#define UNLOCK #define UNLOCK
#endif #endif
// log time // log time
typedef struct _LogTime typedef struct _LogTime
{ {
string year; string year;
string month; string month;
...@@ -131,53 +132,53 @@ typedef struct _LogTime ...@@ -131,53 +132,53 @@ typedef struct _LogTime
string millisecond; // ms string millisecond; // ms
string microsecond; // us string microsecond; // us
string weekDay; string weekDay;
}LogTime; } LogTime;
inline LogTime GetTime() inline LogTime GetTime()
{ {
LogTime currentTime; LogTime currentTime;
#if (defined WIN32 || defined _WIN32) #if(defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime; SYSTEMTIME systemTime;
GetLocalTime(&systemTime); GetLocalTime(&systemTime);
char temp[8] = { 0 }; char temp[8] = {0};
sprintf(temp, "%04d", systemTime.wYear); sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp, "%02d", systemTime.wMonth); sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp, "%02d", systemTime.wDay); sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp, "%02d", systemTime.wHour); sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp, "%02d", systemTime.wMinute); sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp, "%02d", systemTime.wSecond); sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds); sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek); sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp); currentTime.weekDay = string(temp);
#else #else
struct timeval tv; struct timeval tv;
struct tm *p; struct tm* p;
gettimeofday(&tv, NULL); gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec); p = localtime(&tv.tv_sec);
char temp[8]={0}; char temp[8] = {0};
sprintf(temp,"%04d",1900+p->tm_year); sprintf(temp, "%04d", 1900 + p->tm_year);
currentTime.year=string(temp); currentTime.year = string(temp);
sprintf(temp,"%02d",1+p->tm_mon); sprintf(temp, "%02d", 1 + p->tm_mon);
currentTime.month=string(temp); currentTime.month = string(temp);
sprintf(temp,"%02d",p->tm_mday); sprintf(temp, "%02d", p->tm_mday);
currentTime.day=string(temp); currentTime.day = string(temp);
sprintf(temp,"%02d",p->tm_hour); sprintf(temp, "%02d", p->tm_hour);
currentTime.hour=string(temp); currentTime.hour = string(temp);
sprintf(temp,"%02d",p->tm_min); sprintf(temp, "%02d", p->tm_min);
currentTime.minute=string(temp); currentTime.minute = string(temp);
sprintf(temp,"%02d",p->tm_sec); sprintf(temp, "%02d", p->tm_sec);
currentTime.second=string(temp); currentTime.second = string(temp);
sprintf(temp,"%03d",(int)(tv.tv_usec/1000)); sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
currentTime.millisecond = string(temp); currentTime.millisecond = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec % 1000)); sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
currentTime.microsecond = string(temp); currentTime.microsecond = string(temp);
...@@ -187,61 +188,83 @@ inline LogTime GetTime() ...@@ -187,61 +188,83 @@ inline LogTime GetTime()
return currentTime; return currentTime;
} }
#define LOG_TIME(logFile) \ #define LOG_TIME(logFile) \
do\ do \
{\ { \
LogTime currentTime=GetTime(); \ LogTime currentTime = GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), "%s-%s-%s %s:%s:%s.%s\t",currentTime.year.c_str(),currentTime.month.c_str(),currentTime.day.c_str(),currentTime.hour.c_str(),currentTime.minute.c_str(),currentTime.second.c_str(),currentTime.millisecond.c_str()); \ fprintf(((logFile == NULL) ? stdout : logFile), \
}while (0) "%s-%s-%s %s:%s:%s.%s\t", \
currentTime.year.c_str(), \
currentTime.month.c_str(), \
#define LOG_INFO(logFile,logInfo, ...) \ currentTime.day.c_str(), \
do\ currentTime.hour.c_str(), \
{\ currentTime.minute.c_str(), \
LOCK; \ currentTime.second.c_str(), \
LOG_TIME(logFile); \ currentTime.millisecond.c_str()); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \ } while(0)
fprintf(((logFile == NULL) ? stdout : logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_DEBUG(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "DEBUG\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_ERROR(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "ERROR\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_WARN(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "WARN\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#endif // __SIMPLE_LOG_H__ #define LOG_INFO(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_DEBUG(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_ERROR(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#define LOG_WARN(logFile, logInfo, ...) \
do \
{ \
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), \
"[%s:%d (%s) ]: ", \
__FILE__, \
__LINE__, \
__FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while(0)
#endif // __SIMPLE_LOG_H__
#include <stdexcept>
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <fstream> #include <fstream>
#include "utf8proc.h" #include <stdexcept>
#include "./tokenization.h" #include "./tokenization.h"
#include "utf8proc.h"
namespace cuBERT { namespace cuBERT {
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) { void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
for (int i = 0; i < tokens.size(); ++i) { {
ids[i] = convert_token_to_id(tokens[i]); for(int i = 0; i < tokens.size(); ++i)
} {
ids[i] = convert_token_to_id(tokens[i]);
} }
}
// trim from start (in place) // trim from start (in place)
static inline void ltrim(std::string &s) { static inline void ltrim(std::string& s)
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { {
return !std::isspace(ch); s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
})); }
}
// trim from end (in place) // trim from end (in place)
static inline void rtrim(std::string &s) { static inline void rtrim(std::string& s)
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { {
return !std::isspace(ch); s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
}).base(), s.end()); s.end());
} }
// trim from both ends (in place) // trim from both ends (in place)
static inline void trim(std::string &s) { static inline void trim(std::string& s)
ltrim(s); {
rtrim(s); ltrim(s);
} rtrim(s);
}
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) {
std::ifstream file(vocab_file);
if (!file) {
throw std::invalid_argument("Unable to open vocab file");
}
unsigned int index = 0;
std::string line;
while (std::getline(file, line)) {
trim(line);
(*vocab)[line] = index;
index++;
}
file.close(); void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
std::ifstream file(vocab_file);
if(!file)
{
throw std::invalid_argument("Unable to open vocab file");
} }
inline bool _is_whitespace(int c, const char *cat) { unsigned int index = 0;
if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { std::string line;
return true; while(std::getline(file, line))
} {
return cat[0] == 'Z' && cat[1] == 's'; trim(line);
(*vocab)[line] = index;
index++;
} }
inline bool _is_control(int c, const char *cat) { file.close();
// These are technically control characters but we count them as whitespace characters. }
if (c == '\t' || c == '\n' || c == '\r') {
return false;
}
return 'C' == *cat;
}
inline bool _is_punctuation(int cp, const char *cat) { inline bool _is_whitespace(int c, const char* cat)
// We treat all non-letter/number ASCII as punctuation. {
// Characters such as "^", "$", and "`" are not in the Unicode if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
// Punctuation class but we treat them as punctuation anyways, for {
// consistency. return true;
if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) ||
(cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
return true;
}
return 'P' == *cat;
} }
return cat[0] == 'Z' && cat[1] == 's';
}
bool _is_whitespace(int c) { inline bool _is_control(int c, const char* cat)
return _is_whitespace(c, utf8proc_category_string(c)); {
// These are technically control characters but we count them as whitespace
// characters.
if(c == '\t' || c == '\n' || c == '\r')
{
return false;
} }
return 'C' == *cat;
}
bool _is_control(int c) { inline bool _is_punctuation(int cp, const char* cat)
return _is_control(c, utf8proc_category_string(c)); {
// We treat all non-letter/number ASCII as punctuation.
// Characters such as "^", "$", and "`" are not in the Unicode
// Punctuation class but we treat them as punctuation anyways, for
// consistency.
if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
(cp >= 123 && cp <= 126))
{
return true;
} }
return 'P' == *cat;
}
bool _is_punctuation(int cp) { bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }
return _is_punctuation(cp, utf8proc_category_string(cp));
} bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }
bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }
bool BasicTokenizer::_is_chinese_char(int cp)
{
// This defines a "chinese character" as anything in the CJK Unicode block:
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
//
// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
// despite its name. The modern Korean Hangul alphabet is a different block,
// as is Japanese Hiragana and Katakana. Those alphabets are used to write
// space-separated words, so they are not treated specially and handled
// like the all of the other languages.
return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
(cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
(cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}
bool BasicTokenizer::_is_chinese_char(int cp) { void BasicTokenizer::tokenize(const char* text,
// This defines a "chinese character" as anything in the CJK Unicode block: std::vector<std::string>* output_tokens,
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) size_t max_length)
// {
// Note that the CJK Unicode block is NOT all Japanese and Korean characters, // This was added on November 1st, 2018 for the multilingual and Chinese
// despite its name. The modern Korean Hangul alphabet is a different block, // models. This is also applied to the English models now, but it doesn't
// as is Japanese Hiragana and Katakana. Those alphabets are used to write // matter since the English models were not trained on any Chinese data
// space-separated words, so they are not treated specially and handled // and generally don't have any Chinese data in them (there are Chinese
// like the all of the other languages. // characters in the vocabulary because Wikipedia does have some Chinese
return (cp >= 0x4E00 && cp <= 0x9FFF) || // words in the English Wikipedia.).
(cp >= 0x3400 && cp <= 0x4DBF) || if(do_lower_case)
(cp >= 0x20000 && cp <= 0x2A6DF) || {
(cp >= 0x2A700 && cp <= 0x2B73F) || text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
(cp >= 0x2B740 && cp <= 0x2B81F) ||
(cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) ||
(cp >= 0x2F800 && cp <= 0x2FA1F);
} }
void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { size_t word_bytes = std::strlen(text);
// This was added on November 1st, 2018 for the multilingual and Chinese bool new_token = true;
// models. This is also applied to the English models now, but it doesn't size_t subpos = 0;
// matter since the English models were not trained on any Chinese data int cp;
// and generally don't have any Chinese data in them (there are Chinese char dst[4];
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.). while(word_bytes > 0)
if (do_lower_case) { {
text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text); int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
if(len < 0)
{
std::cerr << "UTF-8 decode error: " << text << std::endl;
break;
}
if(do_lower_case)
{
cp = utf8proc_tolower(cp);
} }
size_t word_bytes = std::strlen(text); const char* cat = utf8proc_category_string(cp);
bool new_token = true; if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
size_t subpos = 0; {
int cp; // pass
char dst[4]; }
else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
while (word_bytes > 0) { {
int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp); // pass
if (len < 0) { }
std::cerr << "UTF-8 decode error: " << text << std::endl; else if(_is_whitespace(cp, cat))
break; {
} new_token = true;
if (do_lower_case) { }
cp = utf8proc_tolower(cp); else
{
size_t dst_len = len;
const char* dst_ptr = text + subpos;
if(do_lower_case)
{
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
dst_ptr = dst;
} }
const char *cat = utf8proc_category_string(cp); if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) { {
// pass output_tokens->emplace_back(dst_ptr, dst_len);
} else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') {
// pass
} else if (_is_whitespace(cp, cat)) {
new_token = true; new_token = true;
} else { }
size_t dst_len = len; else
const char *dst_ptr = text + subpos; {
if (do_lower_case) { if(new_token)
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst); {
dst_ptr = dst;
}
if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) {
output_tokens->emplace_back(dst_ptr, dst_len); output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true; new_token = false;
} else { }
if (new_token) { else
output_tokens->emplace_back(dst_ptr, dst_len); {
new_token = false; output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
} else {
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
}
} }
} }
}
word_bytes = word_bytes - len; word_bytes = word_bytes - len;
subpos = subpos + len; subpos = subpos + len;
// early terminate // early terminate
if (output_tokens->size() >= max_length) { if(output_tokens->size() >= max_length)
break; {
} break;
} }
}
if (do_lower_case) { if(do_lower_case)
free((void *) text); {
} free((void*)text);
} }
}
void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
if(token.size() > max_input_chars_per_word)
{ // FIXME: slightly different
output_tokens->push_back(unk_token);
return;
}
size_t output_tokens_len = output_tokens->size();
for(size_t start = 0; start < token.size();)
{
bool is_bad = true;
// TODO: can be optimized by prefix-tree
for(size_t end = token.size(); start < end; --end)
{ // FIXME: slightly different
std::string substr = start > 0 ? "##" + token.substr(start, end - start)
: token.substr(start, end - start);
if(vocab->count(substr))
{
is_bad = false;
output_tokens->push_back(substr);
start = end;
break;
}
}
void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) { if(is_bad)
if (token.size() > max_input_chars_per_word) { // FIXME: slightly different {
output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token); output_tokens->push_back(unk_token);
return; return;
} }
size_t output_tokens_len = output_tokens->size();
for (size_t start = 0; start < token.size();) {
bool is_bad = true;
// TODO: can be optimized by prefix-tree
for (size_t end = token.size(); start < end; --end) { // FIXME: slightly different
std::string substr = start > 0
? "##" + token.substr(start, end - start)
: token.substr(start, end - start);
if (vocab->count(substr)) {
is_bad = false;
output_tokens->push_back(substr);
start = end;
break;
}
}
if (is_bad) {
output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token);
return;
}
}
} }
}
void FullTokenizer::tokenize(const char* text,
void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) { std::vector<std::string>* output_tokens,
std::vector<std::string> tokens; size_t max_length)
tokens.reserve(max_length); {
basic_tokenizer->tokenize(text, &tokens, max_length); std::vector<std::string> tokens;
tokens.reserve(max_length);
for (const auto &token : tokens) { basic_tokenizer->tokenize(text, &tokens, max_length);
wordpiece_tokenizer->tokenize(token, output_tokens);
for(const auto& token : tokens)
// early terminate {
if (output_tokens->size() >= max_length) { wordpiece_tokenizer->tokenize(token, output_tokens);
break;
} // early terminate
if(output_tokens->size() >= max_length)
{
break;
} }
} }
} }
} // namespace cuBERT
#ifndef CUBERT_TOKENIZATION_H #ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H #define CUBERT_TOKENIZATION_H
#include <iostream>
#include <string> #include <string>
#include <vector>
#include <unordered_map> #include <unordered_map>
#include <iostream> #include <vector>
namespace cuBERT { namespace cuBERT {
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab); void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
/** /**
* Checks whether `chars` is a whitespace character. * Checks whether `chars` is a whitespace character.
* @param c * @param c
* @return * @return
*/ */
bool _is_whitespace(int c); bool _is_whitespace(int c);
/** /**
* Checks whether `chars` is a control character. * Checks whether `chars` is a control character.
* @param c * @param c
* @return * @return
*/ */
bool _is_control(int c); bool _is_control(int c);
/** /**
* Checks whether `chars` is a punctuation character. * Checks whether `chars` is a punctuation character.
* @param cp * @param cp
* @return * @return
*/ */
bool _is_punctuation(int cp); bool _is_punctuation(int cp);
/** /**
* Runs basic tokenization (punctuation splitting, lower casing, etc.). * Runs basic tokenization (punctuation splitting, lower casing, etc.).
*/ */
class BasicTokenizer { class BasicTokenizer
{
public: public:
/** /**
* Constructs a BasicTokenizer. * Constructs a BasicTokenizer.
* @param do_lower_case Whether to lower case the input. * @param do_lower_case Whether to lower case the input.
*/ */
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {} explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer &other) = delete; BasicTokenizer(const BasicTokenizer& other) = delete;
virtual ~BasicTokenizer() = default; virtual ~BasicTokenizer() = default;
/** /**
* Tokenizes a piece of text. * Tokenizes a piece of text.
* *
* to_lower * to_lower
* _run_strip_accents Strips accents from a piece of text. * _run_strip_accents Strips accents from a piece of text.
* _clean_text Performs invalid character removal and whitespace cleanup on text. * _clean_text Performs invalid character removal and whitespace cleanup on
* _tokenize_chinese_chars Adds whitespace around any CJK character. * text. _tokenize_chinese_chars Adds whitespace around any CJK character.
* _run_split_on_punc Splits punctuation on a piece of text. * _run_split_on_punc Splits punctuation on a piece of text.
* whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text. * whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece
* * of text.
* @param text *
* @param output_tokens * @param text
*/ * @param output_tokens
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); */
void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
private: private:
const bool do_lower_case; const bool do_lower_case;
/** /**
* Checks whether CP is the codepoint of a CJK character. * Checks whether CP is the codepoint of a CJK character.
* @param cp * @param cp
* @return * @return
*/ */
inline static bool _is_chinese_char(int cp); inline static bool _is_chinese_char(int cp);
}; };
/** /**
* Runs WordPiece tokenziation. * Runs WordPiece tokenziation.
*/ */
class WordpieceTokenizer { class WordpieceTokenizer
{
public: public:
explicit WordpieceTokenizer( explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
std::unordered_map<std::string, uint64_t> *vocab, std::string unk_token = "[UNK]",
std::string unk_token = "[UNK]", int max_input_chars_per_word = 200)
int max_input_chars_per_word = 200 : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {} {
}
WordpieceTokenizer(const WordpieceTokenizer &other) = delete;
WordpieceTokenizer(const WordpieceTokenizer& other) = delete;
virtual ~WordpieceTokenizer() = default;
virtual ~WordpieceTokenizer() = default;
/**
* Tokenizes a piece of text into its word pieces. /**
* * Tokenizes a piece of text into its word pieces.
* This uses a greedy longest-match-first algorithm to perform tokenization *
* using the given vocabulary. * This uses a greedy longest-match-first algorithm to perform tokenization
* * using the given vocabulary.
* For example: *
* input = "unaffable" * For example:
* output = ["un", "##aff", "##able"] * input = "unaffable"
* * output = ["un", "##aff", "##able"]
* @param text A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer. *
* @param output_tokens A list of wordpiece tokens. * @param text A single token or whitespace separated tokens. This should have
*/ * already been passed through `BasicTokenizer.
void tokenize(const std::string &text, std::vector<std::string> *output_tokens); * @param output_tokens A list of wordpiece tokens.
*/
void tokenize(const std::string& text, std::vector<std::string>* output_tokens);
private: private:
const std::unordered_map<std::string, uint64_t> *vocab; const std::unordered_map<std::string, uint64_t>* vocab;
const std::string unk_token; const std::string unk_token;
const int max_input_chars_per_word; const int max_input_chars_per_word;
}; };
/** /**
* Runs end-to-end tokenziation. * Runs end-to-end tokenziation.
*/ */
class FullTokenizer { class FullTokenizer
{
public: public:
FullTokenizer(const char *vocab_file, bool do_lower_case = true) { FullTokenizer(const char* vocab_file, bool do_lower_case = true)
vocab = new std::unordered_map<std::string, uint64_t>(); {
load_vocab(vocab_file, vocab); vocab = new std::unordered_map<std::string, uint64_t>();
basic_tokenizer = new BasicTokenizer(do_lower_case); load_vocab(vocab_file, vocab);
wordpiece_tokenizer = new WordpieceTokenizer(vocab); basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab);
}
~FullTokenizer()
{
if(wordpiece_tokenizer != NULL)
{
wordpiece_tokenizer = NULL;
} }
delete wordpiece_tokenizer;
~FullTokenizer() { if(basic_tokenizer != NULL)
if (wordpiece_tokenizer != NULL){ {
wordpiece_tokenizer = NULL; basic_tokenizer = NULL;
}
delete wordpiece_tokenizer;
if (basic_tokenizer != NULL){
basic_tokenizer = NULL;
}
delete basic_tokenizer;
if (vocab != NULL){
vocab = NULL;
}
delete vocab;
} }
delete basic_tokenizer;
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length); if(vocab != NULL)
{
inline uint64_t convert_token_to_id(const std::string &token) { vocab = NULL;
auto item = vocab->find(token); }
if (item == vocab->end()) { delete vocab;
std::cerr << "vocab missing key: " << token << std::endl; }
return 0;
} else { void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);
return item->second;
} inline uint64_t convert_token_to_id(const std::string& token)
{
auto item = vocab->find(token);
if(item == vocab->end())
{
std::cerr << "vocab missing key: " << token << std::endl;
return 0;
}
else
{
return item->second;
} }
}
void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids); void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);
private: private:
std::unordered_map<std::string, uint64_t> *vocab; std::unordered_map<std::string, uint64_t>* vocab;
BasicTokenizer *basic_tokenizer; BasicTokenizer* basic_tokenizer;
WordpieceTokenizer *wordpiece_tokenizer; WordpieceTokenizer* wordpiece_tokenizer;
}; };
} } // namespace cuBERT
#endif //CUBERT_TOKENIZATION_H #endif // CUBERT_TOKENIZATION_H
/* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */ /* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Kelman, Scott P. Jones, and other contributors. Copyright (c) 2009 Public
* Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -32,7 +33,6 @@ ...@@ -32,7 +33,6 @@
* Please notice the copyright statement in the file "utf8proc_data.c". * Please notice the copyright statement in the file "utf8proc_data.c".
*/ */
/* /*
* File name: utf8proc.c * File name: utf8proc.c
* *
...@@ -40,36 +40,26 @@ ...@@ -40,36 +40,26 @@
* Implementation of libutf8proc. * Implementation of libutf8proc.
*/ */
#include "utf8proc.h" #include "utf8proc.h"
#ifndef SSIZE_MAX #ifndef SSIZE_MAX
#define SSIZE_MAX ((size_t)SIZE_MAX/2) #define SSIZE_MAX ((size_t)SIZE_MAX / 2)
#endif #endif
#ifndef UINT16_MAX #ifndef UINT16_MAX
# define UINT16_MAX 65535U #define UINT16_MAX 65535U
#endif #endif
#include "utf8proc_data.c" #include "utf8proc_data.c"
UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = { UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0};
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0 };
#define UTF8PROC_HANGUL_SBASE 0xAC00 #define UTF8PROC_HANGUL_SBASE 0xAC00
#define UTF8PROC_HANGUL_LBASE 0x1100 #define UTF8PROC_HANGUL_LBASE 0x1100
...@@ -81,166 +71,198 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = { ...@@ -81,166 +71,198 @@ UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
#define UTF8PROC_HANGUL_NCOUNT 588 #define UTF8PROC_HANGUL_NCOUNT 588
#define UTF8PROC_HANGUL_SCOUNT 11172 #define UTF8PROC_HANGUL_SCOUNT 11172
/* END is exclusive */ /* END is exclusive */
#define UTF8PROC_HANGUL_L_START 0x1100 #define UTF8PROC_HANGUL_L_START 0x1100
#define UTF8PROC_HANGUL_L_END 0x115A #define UTF8PROC_HANGUL_L_END 0x115A
#define UTF8PROC_HANGUL_L_FILLER 0x115F #define UTF8PROC_HANGUL_L_FILLER 0x115F
#define UTF8PROC_HANGUL_V_START 0x1160 #define UTF8PROC_HANGUL_V_START 0x1160
#define UTF8PROC_HANGUL_V_END 0x11A3 #define UTF8PROC_HANGUL_V_END 0x11A3
#define UTF8PROC_HANGUL_T_START 0x11A8 #define UTF8PROC_HANGUL_T_START 0x11A8
#define UTF8PROC_HANGUL_T_END 0x11FA #define UTF8PROC_HANGUL_T_END 0x11FA
#define UTF8PROC_HANGUL_S_START 0xAC00 #define UTF8PROC_HANGUL_S_START 0xAC00
#define UTF8PROC_HANGUL_S_END 0xD7A4 #define UTF8PROC_HANGUL_S_END 0xD7A4
/* Should follow semantic-versioning rules (semver.org) based on API /* Should follow semantic-versioning rules (semver.org) based on API
compatibility. (Note that the shared-library version number will compatibility. (Note that the shared-library version number will
be different, being based on ABI compatibility.): */ be different, being based on ABI compatibility.): */
#define STRINGIZEx(x) #x #define STRINGIZEx(x) #x
#define STRINGIZE(x) STRINGIZEx(x) #define STRINGIZE(x) STRINGIZEx(x)
UTF8PROC_DLLEXPORT const char *utf8proc_version(void) { UTF8PROC_DLLEXPORT const char* utf8proc_version(void)
return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) ""; {
} return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) "";
}
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void) {
return "15.0.0"; UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void) { return "15.0.0"; }
}
UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode)
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode) { {
switch (errcode) { switch(errcode)
case UTF8PROC_ERROR_NOMEM: {
return "Memory for processing UTF-8 data could not be allocated."; case UTF8PROC_ERROR_NOMEM: return "Memory for processing UTF-8 data could not be allocated.";
case UTF8PROC_ERROR_OVERFLOW: case UTF8PROC_ERROR_OVERFLOW: return "UTF-8 string is too long to be processed.";
return "UTF-8 string is too long to be processed."; case UTF8PROC_ERROR_INVALIDUTF8: return "Invalid UTF-8 string";
case UTF8PROC_ERROR_INVALIDUTF8: case UTF8PROC_ERROR_NOTASSIGNED: return "Unassigned Unicode code point found in UTF-8 string.";
return "Invalid UTF-8 string"; case UTF8PROC_ERROR_INVALIDOPTS: return "Invalid options for UTF-8 processing chosen.";
case UTF8PROC_ERROR_NOTASSIGNED: default: return "An unknown error occurred while processing UTF-8 data.";
return "Unassigned Unicode code point found in UTF-8 string."; }
case UTF8PROC_ERROR_INVALIDOPTS: }
return "Invalid options for UTF-8 processing chosen.";
default: #define utf_cont(ch) (((ch) & 0xc0) == 0x80)
return "An unknown error occurred while processing UTF-8 data."; UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
} utf8proc_ssize_t strlen,
} utf8proc_int32_t* dst)
{
#define utf_cont(ch) (((ch) & 0xc0) == 0x80) utf8proc_int32_t uc;
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate( const utf8proc_uint8_t* end;
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *dst
) { *dst = -1;
utf8proc_int32_t uc; if(!strlen)
const utf8proc_uint8_t *end; return 0;
end = str + ((strlen < 0) ? 4 : strlen);
*dst = -1; uc = *str++;
if (!strlen) return 0; if(uc < 0x80)
end = str + ((strlen < 0) ? 4 : strlen); {
uc = *str++; *dst = uc;
if (uc < 0x80) { return 1;
*dst = uc; }
return 1; // Must be between 0xc2 and 0xf4 inclusive to be valid
} if((utf8proc_uint32_t)(uc - 0xc2) > (0xf4 - 0xc2))
// Must be between 0xc2 and 0xf4 inclusive to be valid
if ((utf8proc_uint32_t)(uc - 0xc2) > (0xf4-0xc2)) return UTF8PROC_ERROR_INVALIDUTF8;
if (uc < 0xe0) { // 2-byte sequence
// Must have valid continuation character
if (str >= end || !utf_cont(*str)) return UTF8PROC_ERROR_INVALIDUTF8;
*dst = ((uc & 0x1f)<<6) | (*str & 0x3f);
return 2;
}
if (uc < 0xf0) { // 3-byte sequence
if ((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1]))
return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
// Check for surrogate chars if(uc < 0xe0)
if (uc == 0xed && *str > 0x9f) { // 2-byte sequence
return UTF8PROC_ERROR_INVALIDUTF8; // Must have valid continuation character
uc = ((uc & 0xf)<<12) | ((*str & 0x3f)<<6) | (str[1] & 0x3f); if(str >= end || !utf_cont(*str))
if (uc < 0x800) return UTF8PROC_ERROR_INVALIDUTF8;
return UTF8PROC_ERROR_INVALIDUTF8; *dst = ((uc & 0x1f) << 6) | (*str & 0x3f);
*dst = uc; return 2;
return 3; }
} if(uc < 0xf0)
// 4-byte sequence { // 3-byte sequence
// Must have 3 valid continuation characters if((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1]))
if ((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2])) return UTF8PROC_ERROR_INVALIDUTF8;
return UTF8PROC_ERROR_INVALIDUTF8; // Check for surrogate chars
// Make sure in correct range (0x10000 - 0x10ffff) if(uc == 0xed && *str > 0x9f)
if (uc == 0xf0) { return UTF8PROC_ERROR_INVALIDUTF8;
if (*str < 0x90) return UTF8PROC_ERROR_INVALIDUTF8; uc = ((uc & 0xf) << 12) | ((*str & 0x3f) << 6) | (str[1] & 0x3f);
} else if (uc == 0xf4) { if(uc < 0x800)
if (*str > 0x8f) return UTF8PROC_ERROR_INVALIDUTF8; return UTF8PROC_ERROR_INVALIDUTF8;
} *dst = uc;
*dst = ((uc & 7)<<18) | ((*str & 0x3f)<<12) | ((str[1] & 0x3f)<<6) | (str[2] & 0x3f); return 3;
return 4; }
} // 4-byte sequence
// Must have 3 valid continuation characters
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc) { if((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2]))
return (((utf8proc_uint32_t)uc)-0xd800 > 0x07ff) && ((utf8proc_uint32_t)uc < 0x110000); return UTF8PROC_ERROR_INVALIDUTF8;
} // Make sure in correct range (0x10000 - 0x10ffff)
if(uc == 0xf0)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) { {
if (uc < 0x00) { if(*str < 0x90)
return 0; return UTF8PROC_ERROR_INVALIDUTF8;
} else if (uc < 0x80) { }
dst[0] = (utf8proc_uint8_t) uc; else if(uc == 0xf4)
return 1; {
} else if (uc < 0x800) { if(*str > 0x8f)
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6)); return UTF8PROC_ERROR_INVALIDUTF8;
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); }
return 2; *dst = ((uc & 7) << 18) | ((*str & 0x3f) << 12) | ((str[1] & 0x3f) << 6) | (str[2] & 0x3f);
// Note: we allow encoding 0xd800-0xdfff here, so as not to change
// the API, however, these are actually invalid in UTF-8
} else if (uc < 0x10000) {
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3;
} else if (uc < 0x110000) {
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4; return 4;
} else return 0; }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc)
{
return (((utf8proc_uint32_t)uc) - 0xd800 > 0x07ff) && ((utf8proc_uint32_t)uc < 0x110000);
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
{
if(uc < 0x00)
{
return 0;
}
else if(uc < 0x80)
{
dst[0] = (utf8proc_uint8_t)uc;
return 1;
}
else if(uc < 0x800)
{
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 2;
// Note: we allow encoding 0xd800-0xdfff here, so as not to change
// the API, however, these are actually invalid in UTF-8
}
else if(uc < 0x10000)
{
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3;
}
else if(uc < 0x110000)
{
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4;
}
else
return 0;
} }
/* internal version used for inserting 0xff bytes between graphemes */ /* internal version used for inserting 0xff bytes between graphemes */
static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) { static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t* dst)
if (uc < 0x00) { {
if (uc == -1) { /* internal value used for grapheme breaks */ if(uc < 0x00)
dst[0] = (utf8proc_uint8_t)0xFF; {
if(uc == -1)
{ /* internal value used for grapheme breaks */
dst[0] = (utf8proc_uint8_t)0xFF;
return 1;
}
return 0;
}
else if(uc < 0x80)
{
dst[0] = (utf8proc_uint8_t)uc;
return 1; return 1;
} }
return 0; else if(uc < 0x800)
} else if (uc < 0x80) { {
dst[0] = (utf8proc_uint8_t)uc; dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
return 1; dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
} else if (uc < 0x800) { return 2;
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6)); }
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); else if(uc < 0x10000)
return 2; {
} else if (uc < 0x10000) { dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); return 3;
return 3; }
} else if (uc < 0x110000) { else if(uc < 0x110000)
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18)); {
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F)); dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F)); dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F)); dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
return 4; dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
} else return 0; return 4;
}
else
return 0;
} }
/* internal "unsafe" version that does not check whether uc is in range */ /* internal "unsafe" version that does not check whether uc is in range */
static const utf8proc_property_t *unsafe_get_property(utf8proc_int32_t uc) { static const utf8proc_property_t* unsafe_get_property(utf8proc_int32_t uc)
/* ASSERT: uc >= 0 && uc < 0x110000 */ {
return utf8proc_properties + ( /* ASSERT: uc >= 0 && uc < 0x110000 */
utf8proc_stage2table[ return utf8proc_properties +
utf8proc_stage1table[uc >> 8] + (uc & 0xFF) (utf8proc_stage2table[utf8proc_stage1table[uc >> 8] + (uc & 0xFF)]);
]
);
} }
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t uc) { UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t uc)
return uc < 0 || uc >= 0x110000 ? utf8proc_properties : unsafe_get_property(uc); {
return uc < 0 || uc >= 0x110000 ? utf8proc_properties : unsafe_get_property(uc);
} }
/* return whether there is a grapheme break between boundclasses lbc and tbc /* return whether there is a grapheme break between boundclasses lbc and tbc
...@@ -250,543 +272,707 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -250,543 +272,707 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
http://www.unicode.org/reports/tr29/tr29-29.html http://www.unicode.org/reports/tr29/tr29-29.html
CAVEATS: CAVEATS:
Please note that evaluation of GB10 (grapheme breaks between emoji zwj sequences) Please note that evaluation of GB10 (grapheme breaks between emoji zwj
and GB 12/13 (regional indicator code points) require knowledge of previous characters sequences) and GB 12/13 (regional indicator code points) require knowledge of
and are thus not handled by this function. This may result in an incorrect break before previous characters and are thus not handled by this function. This may result
an E_Modifier class codepoint and an incorrectly missing break between two in an incorrect break before an E_Modifier class codepoint and an incorrectly
REGIONAL_INDICATOR class code points if such support does not exist in the caller. missing break between two REGIONAL_INDICATOR class code points if such support
does not exist in the caller.
See the special support in grapheme_break_extended, for required bookkeeping by the caller.
See the special support in grapheme_break_extended, for required bookkeeping
by the caller.
*/ */
static utf8proc_bool grapheme_break_simple(int lbc, int tbc) { static utf8proc_bool grapheme_break_simple(int lbc, int tbc)
return {
(lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1 return (lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1
(lbc == UTF8PROC_BOUNDCLASS_CR && // GB3 (lbc == UTF8PROC_BOUNDCLASS_CR && // GB3
tbc == UTF8PROC_BOUNDCLASS_LF) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LF)
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB4 ? false
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB5 : // ---
(lbc == UTF8PROC_BOUNDCLASS_L && // GB6 (lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
(tbc == UTF8PROC_BOUNDCLASS_L || // --- : // GB4
tbc == UTF8PROC_BOUNDCLASS_V || // --- (tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true
tbc == UTF8PROC_BOUNDCLASS_LV || // --- : // GB5
tbc == UTF8PROC_BOUNDCLASS_LVT)) ? false : // --- (lbc == UTF8PROC_BOUNDCLASS_L && // GB6
((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7 (tbc == UTF8PROC_BOUNDCLASS_L || // ---
lbc == UTF8PROC_BOUNDCLASS_V) && // --- tbc == UTF8PROC_BOUNDCLASS_V || // ---
(tbc == UTF8PROC_BOUNDCLASS_V || // --- tbc == UTF8PROC_BOUNDCLASS_LV || // ---
tbc == UTF8PROC_BOUNDCLASS_T)) ? false : // --- tbc == UTF8PROC_BOUNDCLASS_LVT))
((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8 ? false
lbc == UTF8PROC_BOUNDCLASS_T) && // --- : // ---
tbc == UTF8PROC_BOUNDCLASS_T) ? false : // --- ((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7
(tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9 lbc == UTF8PROC_BOUNDCLASS_V) && // ---
tbc == UTF8PROC_BOUNDCLASS_ZWJ || // --- (tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a tbc == UTF8PROC_BOUNDCLASS_T))
lbc == UTF8PROC_BOUNDCLASS_PREPEND) ? false : // GB9b ? false
(lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional handling below) : // ---
tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) ? false : // ---- ((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13 (requires additional handling below) lbc == UTF8PROC_BOUNDCLASS_T) && // ---
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) ? false : // ---- tbc == UTF8PROC_BOUNDCLASS_T)
true; // GB999 ? false
} : // ---
(tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9
static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t *state) tbc == UTF8PROC_BOUNDCLASS_ZWJ || // ---
{ tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a
if (state) { lbc == UTF8PROC_BOUNDCLASS_PREPEND)
int lbc_override; ? false
if (*state == UTF8PROC_BOUNDCLASS_START) : // GB9b
*state = lbc_override = lbc; (lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional
else // handling below)
lbc_override = *state; tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
utf8proc_bool break_permitted = grapheme_break_simple(lbc_override, tbc); ? false
: // ----
// Special support for GB 12/13 made possible by GB999. After two RI (lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13
// class codepoints we want to force a break. Do this by resetting the // (requires
// second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break // additional
// after that character according to GB999 (unless of course such a break is // handling below)
// forbidden by a different rule such as GB9). tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
if (*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) ? false
*state = UTF8PROC_BOUNDCLASS_OTHER; : // ----
// Special support for GB11 (emoji extend* zwj / emoji) true; // GB999
else if (*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) { }
if (tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji
*state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC; static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t* state)
else if (tbc == UTF8PROC_BOUNDCLASS_ZWJ) {
*state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo if(state)
else {
*state = tbc; int lbc_override;
if(*state == UTF8PROC_BOUNDCLASS_START)
*state = lbc_override = lbc;
else
lbc_override = *state;
utf8proc_bool break_permitted = grapheme_break_simple(lbc_override, tbc);
// Special support for GB 12/13 made possible by GB999. After two RI
// class codepoints we want to force a break. Do this by resetting the
// second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break
// after that character according to GB999 (unless of course such a break is
// forbidden by a different rule such as GB9).
if(*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
*state = UTF8PROC_BOUNDCLASS_OTHER;
// Special support for GB11 (emoji extend* zwj / emoji)
else if(*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC)
{
if(tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji
*state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC;
else if(tbc == UTF8PROC_BOUNDCLASS_ZWJ)
*state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo
else
*state = tbc;
}
else
*state = tbc;
return break_permitted;
} }
else else
*state = tbc; return grapheme_break_simple(lbc, tbc);
return break_permitted;
}
else
return grapheme_break_simple(lbc, tbc);
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t c1,
utf8proc_int32_t c1, utf8proc_int32_t c2, utf8proc_int32_t *state) { utf8proc_int32_t c2,
utf8proc_int32_t* state)
return grapheme_break_extended(utf8proc_get_property(c1)->boundclass, {
utf8proc_get_property(c2)->boundclass, return grapheme_break_extended(
state); utf8proc_get_property(c1)->boundclass, utf8proc_get_property(c2)->boundclass, state);
} }
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t c1, utf8proc_int32_t c2)
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break( {
utf8proc_int32_t c1, utf8proc_int32_t c2) { return utf8proc_grapheme_break_stateful(c1, c2, NULL);
return utf8proc_grapheme_break_stateful(c1, c2, NULL);
} }
static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry) static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t** entry)
{ {
utf8proc_int32_t entry_cp = **entry; utf8proc_int32_t entry_cp = **entry;
if ((entry_cp & 0xF800) == 0xD800) { if((entry_cp & 0xF800) == 0xD800)
*entry = *entry + 1; {
entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF); *entry = *entry + 1;
entry_cp += 0x10000; entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF);
} entry_cp += 0x10000;
return entry_cp; }
return entry_cp;
} }
static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex) static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex)
{ {
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex]; const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex];
return seqindex_decode_entry(&entry); return seqindex_decode_entry(&entry);
} }
static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) { static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex,
utf8proc_ssize_t written = 0; utf8proc_int32_t* dst,
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex & 0x3FFF]; utf8proc_ssize_t bufsize,
int len = seqindex >> 14; utf8proc_option_t options,
if (len >= 3) { int* last_boundclass)
len = *entry; {
entry++; utf8proc_ssize_t written = 0;
} const utf8proc_uint16_t* entry = &utf8proc_sequences[seqindex & 0x3FFF];
for (; len >= 0; entry++, len--) { int len = seqindex >> 14;
utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry); if(len >= 3)
{
written += utf8proc_decompose_char(entry_cp, dst+written, len = *entry;
(bufsize > written) ? (bufsize - written) : 0, options, entry++;
last_boundclass); }
if (written < 0) return UTF8PROC_ERROR_OVERFLOW; for(; len >= 0; entry++, len--)
} {
return written; utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry);
written += utf8proc_decompose_char(entry_cp,
dst + written,
(bufsize > written) ? (bufsize - written) : 0,
options,
last_boundclass);
if(written < 0)
return UTF8PROC_ERROR_OVERFLOW;
}
return written;
} }
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_tolower(utf8proc_int32_t c) UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_tolower(utf8proc_int32_t c)
{ {
utf8proc_int32_t cl = utf8proc_get_property(c)->lowercase_seqindex; utf8proc_int32_t cl = utf8proc_get_property(c)->lowercase_seqindex;
return cl != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cl) : c; return cl != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cl) : c;
} }
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c) UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c)
{ {
utf8proc_int32_t cu = utf8proc_get_property(c)->uppercase_seqindex; utf8proc_int32_t cu = utf8proc_get_property(c)->uppercase_seqindex;
return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c; return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c;
} }
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c) UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c)
{ {
utf8proc_int32_t cu = utf8proc_get_property(c)->titlecase_seqindex; utf8proc_int32_t cu = utf8proc_get_property(c)->titlecase_seqindex;
return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c; return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c;
} }
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c) UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c)
{ {
const utf8proc_property_t *p = utf8proc_get_property(c); const utf8proc_property_t* p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->lowercase_seqindex == UINT16_MAX; return p->lowercase_seqindex != p->uppercase_seqindex && p->lowercase_seqindex == UINT16_MAX;
} }
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c) UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c)
{ {
const utf8proc_property_t *p = utf8proc_get_property(c); const utf8proc_property_t* p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->uppercase_seqindex == UINT16_MAX && p->category != UTF8PROC_CATEGORY_LT; return p->lowercase_seqindex != p->uppercase_seqindex && p->uppercase_seqindex == UINT16_MAX &&
p->category != UTF8PROC_CATEGORY_LT;
} }
/* return a character width analogous to wcwidth (except portable and /* return a character width analogous to wcwidth (except portable and
hopefully less buggy than most system wcwidth functions). */ hopefully less buggy than most system wcwidth functions). */
UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c)
return utf8proc_get_property(c)->charwidth; {
return utf8proc_get_property(c)->charwidth;
} }
UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c)
return (utf8proc_category_t) utf8proc_get_property(c)->category; {
return (utf8proc_category_t)utf8proc_get_property(c)->category;
} }
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t c) { UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t c)
static const char s[][3] = {"Cn","Lu","Ll","Lt","Lm","Lo","Mn","Mc","Me","Nd","Nl","No","Pc","Pd","Ps","Pe","Pi","Pf","Po","Sm","Sc","Sk","So","Zs","Zl","Zp","Cc","Cf","Cs","Co"}; {
return s[utf8proc_category(c)]; static const char s[][3] = {"Cn", "Lu", "Ll", "Lt", "Lm", "Lo", "Mn", "Mc", "Me", "Nd",
"Nl", "No", "Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po", "Sm",
"Sc", "Sk", "So", "Zs", "Zl", "Zp", "Cc", "Cf", "Cs", "Co"};
return s[utf8proc_category(c)];
} }
#define utf8proc_decompose_lump(replacement_uc) \ #define utf8proc_decompose_lump(replacement_uc) \
return utf8proc_decompose_char((replacement_uc), dst, bufsize, \ return utf8proc_decompose_char( \
options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass) (replacement_uc), dst, bufsize, options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) { UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc,
const utf8proc_property_t *property; utf8proc_int32_t* dst,
utf8proc_propval_t category; utf8proc_ssize_t bufsize,
utf8proc_int32_t hangul_sindex; utf8proc_option_t options,
if (uc < 0 || uc >= 0x110000) return UTF8PROC_ERROR_NOTASSIGNED; int* last_boundclass)
property = unsafe_get_property(uc); {
category = property->category; const utf8proc_property_t* property;
hangul_sindex = uc - UTF8PROC_HANGUL_SBASE; utf8proc_propval_t category;
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) { utf8proc_int32_t hangul_sindex;
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT) { if(uc < 0 || uc >= 0x110000)
utf8proc_int32_t hangul_tindex; return UTF8PROC_ERROR_NOTASSIGNED;
if (bufsize >= 1) { property = unsafe_get_property(uc);
dst[0] = UTF8PROC_HANGUL_LBASE + category = property->category;
hangul_sindex / UTF8PROC_HANGUL_NCOUNT; hangul_sindex = uc - UTF8PROC_HANGUL_SBASE;
if (bufsize >= 2) dst[1] = UTF8PROC_HANGUL_VBASE + if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
(hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT; {
} if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT)
hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT; {
if (!hangul_tindex) return 2; utf8proc_int32_t hangul_tindex;
if (bufsize >= 3) dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex; if(bufsize >= 1)
return 3; {
} dst[0] = UTF8PROC_HANGUL_LBASE + hangul_sindex / UTF8PROC_HANGUL_NCOUNT;
} if(bufsize >= 2)
if (options & UTF8PROC_REJECTNA) { dst[1] = UTF8PROC_HANGUL_VBASE +
if (!category) return UTF8PROC_ERROR_NOTASSIGNED; (hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT;
} }
if (options & UTF8PROC_IGNORE) { hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT;
if (property->ignorable) return 0; if(!hangul_tindex)
} return 2;
if (options & UTF8PROC_STRIPNA) { if(bufsize >= 3)
if (!category) return 0; dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex;
} return 3;
if (options & UTF8PROC_LUMP) { }
if (category == UTF8PROC_CATEGORY_ZS) utf8proc_decompose_lump(0x0020); }
if (uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8) if(options & UTF8PROC_REJECTNA)
utf8proc_decompose_lump(0x0027); {
if (category == UTF8PROC_CATEGORY_PD || uc == 0x2212) if(!category)
utf8proc_decompose_lump(0x002D); return UTF8PROC_ERROR_NOTASSIGNED;
if (uc == 0x2044 || uc == 0x2215) utf8proc_decompose_lump(0x002F); }
if (uc == 0x2236) utf8proc_decompose_lump(0x003A); if(options & UTF8PROC_IGNORE)
if (uc == 0x2039 || uc == 0x2329 || uc == 0x3008) {
utf8proc_decompose_lump(0x003C); if(property->ignorable)
if (uc == 0x203A || uc == 0x232A || uc == 0x3009) return 0;
utf8proc_decompose_lump(0x003E); }
if (uc == 0x2216) utf8proc_decompose_lump(0x005C); if(options & UTF8PROC_STRIPNA)
if (uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303) {
utf8proc_decompose_lump(0x005E); if(!category)
if (category == UTF8PROC_CATEGORY_PC || uc == 0x02CD) return 0;
utf8proc_decompose_lump(0x005F); }
if (uc == 0x02CB) utf8proc_decompose_lump(0x0060); if(options & UTF8PROC_LUMP)
if (uc == 0x2223) utf8proc_decompose_lump(0x007C); {
if (uc == 0x223C) utf8proc_decompose_lump(0x007E); if(category == UTF8PROC_CATEGORY_ZS)
if ((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS)) { utf8proc_decompose_lump(0x0020);
if (category == UTF8PROC_CATEGORY_ZL || if(uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8)
category == UTF8PROC_CATEGORY_ZP) utf8proc_decompose_lump(0x0027);
utf8proc_decompose_lump(0x000A); if(category == UTF8PROC_CATEGORY_PD || uc == 0x2212)
} utf8proc_decompose_lump(0x002D);
} if(uc == 0x2044 || uc == 0x2215)
if (options & UTF8PROC_STRIPMARK) { utf8proc_decompose_lump(0x002F);
if (category == UTF8PROC_CATEGORY_MN || if(uc == 0x2236)
category == UTF8PROC_CATEGORY_MC || utf8proc_decompose_lump(0x003A);
category == UTF8PROC_CATEGORY_ME) return 0; if(uc == 0x2039 || uc == 0x2329 || uc == 0x3008)
} utf8proc_decompose_lump(0x003C);
if (options & UTF8PROC_CASEFOLD) { if(uc == 0x203A || uc == 0x232A || uc == 0x3009)
if (property->casefold_seqindex != UINT16_MAX) { utf8proc_decompose_lump(0x003E);
return seqindex_write_char_decomposed(property->casefold_seqindex, dst, bufsize, options, last_boundclass); if(uc == 0x2216)
} utf8proc_decompose_lump(0x005C);
} if(uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303)
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) { utf8proc_decompose_lump(0x005E);
if (property->decomp_seqindex != UINT16_MAX && if(category == UTF8PROC_CATEGORY_PC || uc == 0x02CD)
(!property->decomp_type || (options & UTF8PROC_COMPAT))) { utf8proc_decompose_lump(0x005F);
return seqindex_write_char_decomposed(property->decomp_seqindex, dst, bufsize, options, last_boundclass); if(uc == 0x02CB)
} utf8proc_decompose_lump(0x0060);
} if(uc == 0x2223)
if (options & UTF8PROC_CHARBOUND) { utf8proc_decompose_lump(0x007C);
utf8proc_bool boundary; if(uc == 0x223C)
int tbc = property->boundclass; utf8proc_decompose_lump(0x007E);
boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass); if((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS))
if (boundary) { {
if (bufsize >= 1) dst[0] = -1; /* sentinel value for grapheme break */ if(category == UTF8PROC_CATEGORY_ZL || category == UTF8PROC_CATEGORY_ZP)
if (bufsize >= 2) dst[1] = uc; utf8proc_decompose_lump(0x000A);
return 2; }
} }
} if(options & UTF8PROC_STRIPMARK)
if (bufsize >= 1) *dst = uc; {
return 1; if(category == UTF8PROC_CATEGORY_MN || category == UTF8PROC_CATEGORY_MC ||
} category == UTF8PROC_CATEGORY_ME)
return 0;
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( }
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, if(options & UTF8PROC_CASEFOLD)
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options {
) { if(property->casefold_seqindex != UINT16_MAX)
{
return seqindex_write_char_decomposed(
property->casefold_seqindex, dst, bufsize, options, last_boundclass);
}
}
if(options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE))
{
if(property->decomp_seqindex != UINT16_MAX &&
(!property->decomp_type || (options & UTF8PROC_COMPAT)))
{
return seqindex_write_char_decomposed(
property->decomp_seqindex, dst, bufsize, options, last_boundclass);
}
}
if(options & UTF8PROC_CHARBOUND)
{
utf8proc_bool boundary;
int tbc = property->boundclass;
boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass);
if(boundary)
{
if(bufsize >= 1)
dst[0] = -1; /* sentinel value for grapheme break */
if(bufsize >= 2)
dst[1] = uc;
return 2;
}
}
if(bufsize >= 1)
*dst = uc;
return 1;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
utf8proc_ssize_t strlen,
utf8proc_int32_t* buffer,
utf8proc_ssize_t bufsize,
utf8proc_option_t options)
{
return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, NULL, NULL); return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, NULL, NULL);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_int32_t* buffer,
utf8proc_custom_func custom_func, void *custom_data utf8proc_ssize_t bufsize,
) { utf8proc_option_t options,
/* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */ utf8proc_custom_func custom_func,
utf8proc_ssize_t wpos = 0; void* custom_data)
if ((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE)) {
return UTF8PROC_ERROR_INVALIDOPTS; /* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */
if ((options & UTF8PROC_STRIPMARK) &&
!(options & UTF8PROC_COMPOSE) && !(options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS;
{
utf8proc_int32_t uc;
utf8proc_ssize_t rpos = 0;
utf8proc_ssize_t decomp_result;
int boundclass = UTF8PROC_BOUNDCLASS_START;
while (1) {
if (options & UTF8PROC_NULLTERM) {
rpos += utf8proc_iterate(str + rpos, -1, &uc);
/* checking of return value is not necessary,
as 'uc' is < 0 in case of error */
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8;
if (rpos < 0) return UTF8PROC_ERROR_OVERFLOW;
if (uc == 0) break;
} else {
if (rpos >= strlen) break;
rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc);
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8;
}
if (custom_func != NULL) {
uc = custom_func(uc, custom_data); /* user-specified custom mapping */
}
decomp_result = utf8proc_decompose_char(
uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options,
&boundclass
);
if (decomp_result < 0) return decomp_result;
wpos += decomp_result;
/* prohibiting integer overflows due to too long strings: */
if (wpos < 0 ||
wpos > (utf8proc_ssize_t)(SSIZE_MAX/sizeof(utf8proc_int32_t)/2))
return UTF8PROC_ERROR_OVERFLOW;
}
}
if ((options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) && bufsize >= wpos) {
utf8proc_ssize_t pos = 0;
while (pos < wpos-1) {
utf8proc_int32_t uc1, uc2;
const utf8proc_property_t *property1, *property2;
uc1 = buffer[pos];
uc2 = buffer[pos+1];
property1 = unsafe_get_property(uc1);
property2 = unsafe_get_property(uc2);
if (property1->combining_class > property2->combining_class &&
property2->combining_class > 0) {
buffer[pos] = uc2;
buffer[pos+1] = uc1;
if (pos > 0) pos--; else pos++;
} else {
pos++;
}
}
}
return wpos;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) {
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */
if (options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC)) {
utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0;
utf8proc_int32_t uc;
for (rpos = 0; rpos < length; rpos++) {
uc = buffer[rpos];
if (uc == 0x000D && rpos < length-1 && buffer[rpos+1] == 0x000A) rpos++;
if (uc == 0x000A || uc == 0x000D || uc == 0x0085 ||
((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C))) {
if (options & UTF8PROC_NLF2LS) {
if (options & UTF8PROC_NLF2PS) {
buffer[wpos++] = 0x000A;
} else {
buffer[wpos++] = 0x2028;
}
} else {
if (options & UTF8PROC_NLF2PS) {
buffer[wpos++] = 0x2029;
} else {
buffer[wpos++] = 0x0020;
}
}
} else if ((options & UTF8PROC_STRIPCC) &&
(uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0))) {
if (uc == 0x0009) buffer[wpos++] = 0x0020;
} else {
buffer[wpos++] = uc;
}
}
length = wpos;
}
if (options & UTF8PROC_COMPOSE) {
utf8proc_int32_t *starter = NULL;
utf8proc_int32_t current_char;
const utf8proc_property_t *starter_property = NULL, *current_property;
utf8proc_propval_t max_combining_class = -1;
utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0; utf8proc_ssize_t wpos = 0;
utf8proc_int32_t composition; if((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE))
for (rpos = 0; rpos < length; rpos++) { return UTF8PROC_ERROR_INVALIDOPTS;
current_char = buffer[rpos]; if((options & UTF8PROC_STRIPMARK) && !(options & UTF8PROC_COMPOSE) &&
current_property = unsafe_get_property(current_char); !(options & UTF8PROC_DECOMPOSE))
if (starter && current_property->combining_class > max_combining_class) { return UTF8PROC_ERROR_INVALIDOPTS;
/* combination perhaps possible */ {
utf8proc_int32_t hangul_lindex; utf8proc_int32_t uc;
utf8proc_int32_t hangul_sindex; utf8proc_ssize_t rpos = 0;
hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE; utf8proc_ssize_t decomp_result;
if (hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT) { int boundclass = UTF8PROC_BOUNDCLASS_START;
utf8proc_int32_t hangul_vindex; while(1)
hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE; {
if (hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT) { if(options & UTF8PROC_NULLTERM)
*starter = UTF8PROC_HANGUL_SBASE + {
(hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) * rpos += utf8proc_iterate(str + rpos, -1, &uc);
UTF8PROC_HANGUL_TCOUNT; /* checking of return value is not necessary,
starter_property = NULL; as 'uc' is < 0 in case of error */
continue; if(uc < 0)
} return UTF8PROC_ERROR_INVALIDUTF8;
} if(rpos < 0)
hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE; return UTF8PROC_ERROR_OVERFLOW;
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT && if(uc == 0)
(hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0) { break;
utf8proc_int32_t hangul_tindex; }
hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE; else
if (hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT) { {
*starter += hangul_tindex; if(rpos >= strlen)
starter_property = NULL; break;
continue; rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc);
} if(uc < 0)
return UTF8PROC_ERROR_INVALIDUTF8;
}
if(custom_func != NULL)
{
uc = custom_func(uc, custom_data); /* user-specified custom mapping */
}
decomp_result = utf8proc_decompose_char(
uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options, &boundclass);
if(decomp_result < 0)
return decomp_result;
wpos += decomp_result;
/* prohibiting integer overflows due to too long strings: */
if(wpos < 0 || wpos > (utf8proc_ssize_t)(SSIZE_MAX / sizeof(utf8proc_int32_t) / 2))
return UTF8PROC_ERROR_OVERFLOW;
} }
if (!starter_property) { }
starter_property = unsafe_get_property(*starter); if((options & (UTF8PROC_COMPOSE | UTF8PROC_DECOMPOSE)) && bufsize >= wpos)
{
utf8proc_ssize_t pos = 0;
while(pos < wpos - 1)
{
utf8proc_int32_t uc1, uc2;
const utf8proc_property_t *property1, *property2;
uc1 = buffer[pos];
uc2 = buffer[pos + 1];
property1 = unsafe_get_property(uc1);
property2 = unsafe_get_property(uc2);
if(property1->combining_class > property2->combining_class &&
property2->combining_class > 0)
{
buffer[pos] = uc2;
buffer[pos + 1] = uc1;
if(pos > 0)
pos--;
else
pos++;
}
else
{
pos++;
}
} }
if (starter_property->comb_index < 0x8000 && }
current_property->comb_index != UINT16_MAX && return wpos;
current_property->comb_index >= 0x8000) { }
int sidx = starter_property->comb_index;
int idx = current_property->comb_index & 0x3FFF; UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
if (idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1] ) { utf8proc_ssize_t length,
idx += sidx + 2 - utf8proc_combinations[sidx]; utf8proc_option_t options)
if (current_property->comb_index & 0x4000) { {
composition = (utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx+1]; /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */
} else if(options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC))
composition = utf8proc_combinations[idx]; {
utf8proc_ssize_t rpos;
if (composition > 0 && (!(options & UTF8PROC_STABLE) || utf8proc_ssize_t wpos = 0;
!(unsafe_get_property(composition)->comp_exclusion))) { utf8proc_int32_t uc;
*starter = composition; for(rpos = 0; rpos < length; rpos++)
starter_property = NULL; {
continue; uc = buffer[rpos];
if(uc == 0x000D && rpos < length - 1 && buffer[rpos + 1] == 0x000A)
rpos++;
if(uc == 0x000A || uc == 0x000D || uc == 0x0085 ||
((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C)))
{
if(options & UTF8PROC_NLF2LS)
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x000A;
}
else
{
buffer[wpos++] = 0x2028;
}
}
else
{
if(options & UTF8PROC_NLF2PS)
{
buffer[wpos++] = 0x2029;
}
else
{
buffer[wpos++] = 0x0020;
}
}
}
else if((options & UTF8PROC_STRIPCC) && (uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0)))
{
if(uc == 0x0009)
buffer[wpos++] = 0x0020;
}
else
{
buffer[wpos++] = uc;
} }
}
} }
} length = wpos;
buffer[wpos] = current_char; }
if (current_property->combining_class) { if(options & UTF8PROC_COMPOSE)
if (current_property->combining_class > max_combining_class) { {
max_combining_class = current_property->combining_class; utf8proc_int32_t* starter = NULL;
utf8proc_int32_t current_char;
const utf8proc_property_t *starter_property = NULL, *current_property;
utf8proc_propval_t max_combining_class = -1;
utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0;
utf8proc_int32_t composition;
for(rpos = 0; rpos < length; rpos++)
{
current_char = buffer[rpos];
current_property = unsafe_get_property(current_char);
if(starter && current_property->combining_class > max_combining_class)
{
/* combination perhaps possible */
utf8proc_int32_t hangul_lindex;
utf8proc_int32_t hangul_sindex;
hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE;
if(hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT)
{
utf8proc_int32_t hangul_vindex;
hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE;
if(hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT)
{
*starter = UTF8PROC_HANGUL_SBASE +
(hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) *
UTF8PROC_HANGUL_TCOUNT;
starter_property = NULL;
continue;
}
}
hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE;
if(hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT &&
(hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0)
{
utf8proc_int32_t hangul_tindex;
hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE;
if(hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT)
{
*starter += hangul_tindex;
starter_property = NULL;
continue;
}
}
if(!starter_property)
{
starter_property = unsafe_get_property(*starter);
}
if(starter_property->comb_index < 0x8000 &&
current_property->comb_index != UINT16_MAX &&
current_property->comb_index >= 0x8000)
{
int sidx = starter_property->comb_index;
int idx = current_property->comb_index & 0x3FFF;
if(idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1])
{
idx += sidx + 2 - utf8proc_combinations[sidx];
if(current_property->comb_index & 0x4000)
{
composition =
(utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx + 1];
}
else
composition = utf8proc_combinations[idx];
if(composition > 0 && (!(options & UTF8PROC_STABLE) ||
!(unsafe_get_property(composition)->comp_exclusion)))
{
*starter = composition;
starter_property = NULL;
continue;
}
}
}
}
buffer[wpos] = current_char;
if(current_property->combining_class)
{
if(current_property->combining_class > max_combining_class)
{
max_combining_class = current_property->combining_class;
}
}
else
{
starter = buffer + wpos;
starter_property = NULL;
max_combining_class = -1;
}
wpos++;
} }
} else { length = wpos;
starter = buffer + wpos; }
starter_property = NULL; return length;
max_combining_class = -1; }
}
wpos++; UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
} utf8proc_ssize_t length,
length = wpos; utf8proc_option_t options)
} {
return length; /* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored
} ASSERT: 'buffer' has one spare byte of free space at the end! */
length = utf8proc_normalize_utf32(buffer, length, options);
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) { if(length < 0)
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored return length;
ASSERT: 'buffer' has one spare byte of free space at the end! */ {
length = utf8proc_normalize_utf32(buffer, length, options); utf8proc_ssize_t rpos, wpos = 0;
if (length < 0) return length; utf8proc_int32_t uc;
{ if(options & UTF8PROC_CHARBOUND)
utf8proc_ssize_t rpos, wpos = 0; {
utf8proc_int32_t uc; for(rpos = 0; rpos < length; rpos++)
if (options & UTF8PROC_CHARBOUND) { {
for (rpos = 0; rpos < length; rpos++) { uc = buffer[rpos];
uc = buffer[rpos]; wpos += charbound_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
wpos += charbound_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); }
} }
} else { else
for (rpos = 0; rpos < length; rpos++) { {
uc = buffer[rpos]; for(rpos = 0; rpos < length; rpos++)
wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos); {
uc = buffer[rpos];
wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t*)buffer) + wpos);
}
} }
((utf8proc_uint8_t*)buffer)[wpos] = 0;
return wpos;
} }
((utf8proc_uint8_t *)buffer)[wpos] = 0;
return wpos;
}
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
) { utf8proc_uint8_t** dstptr,
utf8proc_option_t options)
{
return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL); return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL);
} }
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
) { utf8proc_option_t options,
utf8proc_int32_t *buffer; utf8proc_custom_func custom_func,
utf8proc_ssize_t result; void* custom_data)
*dstptr = NULL; {
result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data); utf8proc_int32_t* buffer;
if (result < 0) return result; utf8proc_ssize_t result;
buffer = (utf8proc_int32_t *) malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1); *dstptr = NULL;
if (!buffer) return UTF8PROC_ERROR_NOMEM; result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data);
result = utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data); if(result < 0)
if (result < 0) { return result;
free(buffer); buffer = (utf8proc_int32_t*)malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1);
return result; if(!buffer)
} return UTF8PROC_ERROR_NOMEM;
result = utf8proc_reencode(buffer, result, options); result =
if (result < 0) { utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data);
free(buffer); if(result < 0)
{
free(buffer);
return result;
}
result = utf8proc_reencode(buffer, result, options);
if(result < 0)
{
free(buffer);
return result;
}
{
utf8proc_int32_t* newptr;
newptr = (utf8proc_int32_t*)realloc(buffer, (size_t)result + 1);
if(newptr)
buffer = newptr;
}
*dstptr = (utf8proc_uint8_t*)buffer;
return result; return result;
}
{
utf8proc_int32_t *newptr;
newptr = (utf8proc_int32_t *) realloc(buffer, (size_t)result+1);
if (newptr) buffer = newptr;
}
*dstptr = (utf8proc_uint8_t *)buffer;
return result;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE); utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE);
return retval; return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT); utf8proc_map(str,
return retval; 0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT);
return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT); utf8proc_map(
return retval; str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT);
return retval;
} }
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str) { UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str)
utf8proc_uint8_t *retval; {
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE | utf8proc_uint8_t* retval;
UTF8PROC_COMPOSE | UTF8PROC_COMPAT | UTF8PROC_CASEFOLD | UTF8PROC_IGNORE); utf8proc_map(str,
return retval; 0,
&retval,
UTF8PROC_NULLTERM | UTF8PROC_STABLE | UTF8PROC_COMPOSE | UTF8PROC_COMPAT |
UTF8PROC_CASEFOLD | UTF8PROC_IGNORE);
return retval;
} }
/* /*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors. * Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany * Kelman, Scott P. Jones, and other contributors. Copyright (c) 2009 Public
* Software Group e. V., Berlin, Germany
* *
* Permission is hereby granted, free of charge, to any person obtaining a * Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"), * copy of this software and associated documentation files (the "Software"),
...@@ -21,7 +22,6 @@ ...@@ -21,7 +22,6 @@
* DEALINGS IN THE SOFTWARE. * DEALINGS IN THE SOFTWARE.
*/ */
/** /**
* @mainpage * @mainpage
* *
...@@ -37,15 +37,23 @@ ...@@ -37,15 +37,23 @@
* The features of utf8proc include: * The features of utf8proc include:
* *
* - Transformation of strings (@ref utf8proc_map) to: * - Transformation of strings (@ref utf8proc_map) to:
* - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE) Unicode combining characters (http://en.wikipedia.org/wiki/Combining_character) * - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE)
* Unicode combining characters
* (http://en.wikipedia.org/wiki/Combining_character)
* - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT) * - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT)
* - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref UTF8PROC_STRIPMARK) * - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters
* (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref
* UTF8PROC_STRIPMARK)
* - case-folding (@ref UTF8PROC_CASEFOLD) * - case-folding (@ref UTF8PROC_CASEFOLD)
* - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref utf8proc_NFKD, @ref utf8proc_NFKC * - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref UTF8PROC_CHARBOUND) * utf8proc_NFKD, @ref utf8proc_NFKC
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref
* UTF8PROC_CHARBOUND)
* - Character-width computation: @ref utf8proc_charwidth * - Character-width computation: @ref utf8proc_charwidth
* - Classification of characters by Unicode category: @ref utf8proc_category and @ref utf8proc_category_string * - Classification of characters by Unicode category: @ref utf8proc_category
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate) Unicode codepoints to/from UTF-8. * and @ref utf8proc_category_string
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate)
* Unicode codepoints to/from UTF-8.
*/ */
/** @file */ /** @file */
...@@ -68,9 +76,11 @@ ...@@ -68,9 +76,11 @@
* being based on ABI compatibility rather than API compatibility. * being based on ABI compatibility rather than API compatibility.
*/ */
/** @{ */ /** @{ */
/** The MAJOR version number (increased when backwards API compatibility is broken). */ /** The MAJOR version number (increased when backwards API compatibility is
* broken). */
#define UTF8PROC_VERSION_MAJOR 2 #define UTF8PROC_VERSION_MAJOR 2
/** The MINOR version number (increased when new functionality is added in a backwards-compatible manner). */ /** The MINOR version number (increased when new functionality is added in a
* backwards-compatible manner). */
#define UTF8PROC_VERSION_MINOR 8 #define UTF8PROC_VERSION_MINOR 8
/** The PATCH version (increased for fixes that do not change the API). */ /** The PATCH version (increased for fixes that do not change the API). */
#define UTF8PROC_VERSION_PATCH 0 #define UTF8PROC_VERSION_PATCH 0
...@@ -86,28 +96,29 @@ typedef short utf8proc_int16_t; ...@@ -86,28 +96,29 @@ typedef short utf8proc_int16_t;
typedef unsigned short utf8proc_uint16_t; typedef unsigned short utf8proc_uint16_t;
typedef int utf8proc_int32_t; typedef int utf8proc_int32_t;
typedef unsigned int utf8proc_uint32_t; typedef unsigned int utf8proc_uint32_t;
# ifdef _WIN64 #ifdef _WIN64
typedef __int64 utf8proc_ssize_t; typedef __int64 utf8proc_ssize_t;
typedef unsigned __int64 utf8proc_size_t; typedef unsigned __int64 utf8proc_size_t;
# else #else
typedef int utf8proc_ssize_t; typedef int utf8proc_ssize_t;
typedef unsigned int utf8proc_size_t; typedef unsigned int utf8proc_size_t;
# endif #endif
# ifndef __cplusplus #ifndef __cplusplus
// emulate C99 bool // emulate C99 bool
typedef unsigned char utf8proc_bool; typedef unsigned char utf8proc_bool;
# ifndef __bool_true_false_are_defined #ifndef __bool_true_false_are_defined
# define false 0 #define false 0
# define true 1 #define true 1
# define __bool_true_false_are_defined 1 #define __bool_true_false_are_defined 1
# endif #endif
# else #else
typedef bool utf8proc_bool; typedef bool utf8proc_bool;
# endif #endif
#else #else
# include <stddef.h> #include <inttypes.h>
# include <stdbool.h> #include <stdbool.h>
# include <inttypes.h> #include <stddef.h>
typedef int8_t utf8proc_int8_t; typedef int8_t utf8proc_int8_t;
typedef uint8_t utf8proc_uint8_t; typedef uint8_t utf8proc_uint8_t;
typedef int16_t utf8proc_int16_t; typedef int16_t utf8proc_int16_t;
...@@ -121,19 +132,19 @@ typedef bool utf8proc_bool; ...@@ -121,19 +132,19 @@ typedef bool utf8proc_bool;
#include <limits.h> #include <limits.h>
#ifdef UTF8PROC_STATIC #ifdef UTF8PROC_STATIC
# define UTF8PROC_DLLEXPORT #define UTF8PROC_DLLEXPORT
#else #else
# ifdef _WIN32 #ifdef _WIN32
# ifdef UTF8PROC_EXPORTS #ifdef UTF8PROC_EXPORTS
# define UTF8PROC_DLLEXPORT __declspec(dllexport) #define UTF8PROC_DLLEXPORT __declspec(dllexport)
# else #else
# define UTF8PROC_DLLEXPORT __declspec(dllimport) #define UTF8PROC_DLLEXPORT __declspec(dllimport)
# endif #endif
# elif __GNUC__ >= 4 #elif __GNUC__ >= 4
# define UTF8PROC_DLLEXPORT __attribute__ ((visibility("default"))) #define UTF8PROC_DLLEXPORT __attribute__((visibility("default")))
# else #else
# define UTF8PROC_DLLEXPORT #define UTF8PROC_DLLEXPORT
# endif #endif
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
...@@ -143,72 +154,74 @@ extern "C" { ...@@ -143,72 +154,74 @@ extern "C" {
/** /**
* Option flags used by several functions in the library. * Option flags used by several functions in the library.
*/ */
typedef enum { typedef enum
/** The given UTF-8 input is NULL terminated. */ {
UTF8PROC_NULLTERM = (1<<0), /** The given UTF-8 input is NULL terminated. */
/** Unicode Versioning Stability has to be respected. */ UTF8PROC_NULLTERM = (1 << 0),
UTF8PROC_STABLE = (1<<1), /** Unicode Versioning Stability has to be respected. */
/** Compatibility decomposition (i.e. formatting information is lost). */ UTF8PROC_STABLE = (1 << 1),
UTF8PROC_COMPAT = (1<<2), /** Compatibility decomposition (i.e. formatting information is lost). */
/** Return a result with decomposed characters. */ UTF8PROC_COMPAT = (1 << 2),
UTF8PROC_COMPOSE = (1<<3), /** Return a result with decomposed characters. */
/** Return a result with decomposed characters. */ UTF8PROC_COMPOSE = (1 << 3),
UTF8PROC_DECOMPOSE = (1<<4), /** Return a result with decomposed characters. */
/** Strip "default ignorable characters" such as SOFT-HYPHEN or ZERO-WIDTH-SPACE. */ UTF8PROC_DECOMPOSE = (1 << 4),
UTF8PROC_IGNORE = (1<<5), /** Strip "default ignorable characters" such as SOFT-HYPHEN or
/** Return an error, if the input contains unassigned codepoints. */ ZERO-WIDTH-SPACE. */
UTF8PROC_REJECTNA = (1<<6), UTF8PROC_IGNORE = (1 << 5),
/** /** Return an error, if the input contains unassigned codepoints. */
* Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a UTF8PROC_REJECTNA = (1 << 6),
* line break, and should be converted to the codepoint for line /**
* separation (LS). * Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a
*/ * line break, and should be converted to the codepoint for line
UTF8PROC_NLF2LS = (1<<7), * separation (LS).
/** */
* Indicating that NLF-sequences are representing a paragraph break, and UTF8PROC_NLF2LS = (1 << 7),
* should be converted to the codepoint for paragraph separation /**
* (PS). * Indicating that NLF-sequences are representing a paragraph break, and
*/ * should be converted to the codepoint for paragraph separation
UTF8PROC_NLF2PS = (1<<8), * (PS).
/** Indicating that the meaning of NLF-sequences is unknown. */ */
UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS), UTF8PROC_NLF2PS = (1 << 8),
/** Strips and/or convers control characters. /** Indicating that the meaning of NLF-sequences is unknown. */
* UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS),
* NLF-sequences are transformed into space, except if one of the /** Strips and/or convers control characters.
* NLF2LS/PS/LF options is given. HorizontalTab (HT) and FormFeed (FF) *
* are treated as a NLF-sequence in this case. All other control * NLF-sequences are transformed into space, except if one of the
* characters are simply removed. * NLF2LS/PS/LF options is given. HorizontalTab (HT) and FormFeed (FF)
*/ * are treated as a NLF-sequence in this case. All other control
UTF8PROC_STRIPCC = (1<<9), * characters are simply removed.
/** */
* Performs unicode case folding, to be able to do a case-insensitive UTF8PROC_STRIPCC = (1 << 9),
* string comparison. /**
*/ * Performs unicode case folding, to be able to do a case-insensitive
UTF8PROC_CASEFOLD = (1<<10), * string comparison.
/** */
* Inserts 0xFF bytes at the beginning of each sequence which is UTF8PROC_CASEFOLD = (1 << 10),
* representing a single grapheme cluster (see UAX#29). /**
*/ * Inserts 0xFF bytes at the beginning of each sequence which is
UTF8PROC_CHARBOUND = (1<<11), * representing a single grapheme cluster (see UAX#29).
/** Lumps certain characters together. */
* UTF8PROC_CHARBOUND = (1 << 11),
* E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details. /** Lumps certain characters together.
* *
* If NLF2LF is set, this includes a transformation of paragraph and * E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details.
* line separators to ASCII line-feed (LF). *
*/ * If NLF2LF is set, this includes a transformation of paragraph and
UTF8PROC_LUMP = (1<<12), * line separators to ASCII line-feed (LF).
/** Strips all character markings. */
* UTF8PROC_LUMP = (1 << 12),
* This includes non-spacing, spacing and enclosing (i.e. accents). /** Strips all character markings.
* @note This option works only with @ref UTF8PROC_COMPOSE or *
* @ref UTF8PROC_DECOMPOSE * This includes non-spacing, spacing and enclosing (i.e. accents).
*/ * @note This option works only with @ref UTF8PROC_COMPOSE or
UTF8PROC_STRIPMARK = (1<<13), * @ref UTF8PROC_DECOMPOSE
/** */
* Strip unassigned codepoints. UTF8PROC_STRIPMARK = (1 << 13),
*/ /**
UTF8PROC_STRIPNA = (1<<14), * Strip unassigned codepoints.
*/
UTF8PROC_STRIPNA = (1 << 14),
} utf8proc_option_t; } utf8proc_option_t;
/** @name Error codes /** @name Error codes
...@@ -221,7 +234,8 @@ typedef enum { ...@@ -221,7 +234,8 @@ typedef enum {
#define UTF8PROC_ERROR_OVERFLOW -2 #define UTF8PROC_ERROR_OVERFLOW -2
/** The given string is not a legal UTF-8 string. */ /** The given string is not a legal UTF-8 string. */
#define UTF8PROC_ERROR_INVALIDUTF8 -3 #define UTF8PROC_ERROR_INVALIDUTF8 -3
/** The @ref UTF8PROC_REJECTNA flag was set and an unassigned codepoint was found. */ /** The @ref UTF8PROC_REJECTNA flag was set and an unassigned codepoint was
* found. */
#define UTF8PROC_ERROR_NOTASSIGNED -4 #define UTF8PROC_ERROR_NOTASSIGNED -4
/** Invalid options have been used. */ /** Invalid options have been used. */
#define UTF8PROC_ERROR_INVALIDOPTS -5 #define UTF8PROC_ERROR_INVALIDOPTS -5
...@@ -233,159 +247,164 @@ typedef enum { ...@@ -233,159 +247,164 @@ typedef enum {
typedef utf8proc_int16_t utf8proc_propval_t; typedef utf8proc_int16_t utf8proc_propval_t;
/** Struct containing information about a codepoint. */ /** Struct containing information about a codepoint. */
typedef struct utf8proc_property_struct { typedef struct utf8proc_property_struct
/** {
* Unicode category. /**
* @see utf8proc_category_t. * Unicode category.
*/ * @see utf8proc_category_t.
utf8proc_propval_t category; */
utf8proc_propval_t combining_class; utf8proc_propval_t category;
/** utf8proc_propval_t combining_class;
* Bidirectional class. /**
* @see utf8proc_bidi_class_t. * Bidirectional class.
*/ * @see utf8proc_bidi_class_t.
utf8proc_propval_t bidi_class; */
/** utf8proc_propval_t bidi_class;
* @anchor Decomposition type. /**
* @see utf8proc_decomp_type_t. * @anchor Decomposition type.
*/ * @see utf8proc_decomp_type_t.
utf8proc_propval_t decomp_type; */
utf8proc_uint16_t decomp_seqindex; utf8proc_propval_t decomp_type;
utf8proc_uint16_t casefold_seqindex; utf8proc_uint16_t decomp_seqindex;
utf8proc_uint16_t uppercase_seqindex; utf8proc_uint16_t casefold_seqindex;
utf8proc_uint16_t lowercase_seqindex; utf8proc_uint16_t uppercase_seqindex;
utf8proc_uint16_t titlecase_seqindex; utf8proc_uint16_t lowercase_seqindex;
utf8proc_uint16_t comb_index; utf8proc_uint16_t titlecase_seqindex;
unsigned bidi_mirrored:1; utf8proc_uint16_t comb_index;
unsigned comp_exclusion:1; unsigned bidi_mirrored : 1;
/** unsigned comp_exclusion : 1;
* Can this codepoint be ignored? /**
* * Can this codepoint be ignored?
* Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is *
* passed as an option. * Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is
*/ * passed as an option.
unsigned ignorable:1; */
unsigned control_boundary:1; unsigned ignorable : 1;
/** The width of the codepoint. */ unsigned control_boundary : 1;
unsigned charwidth:2; /** The width of the codepoint. */
unsigned pad:2; unsigned charwidth : 2;
/** unsigned pad : 2;
* Boundclass. /**
* @see utf8proc_boundclass_t. * Boundclass.
*/ * @see utf8proc_boundclass_t.
unsigned boundclass:8; */
unsigned boundclass : 8;
} utf8proc_property_t; } utf8proc_property_t;
/** Unicode categories. */ /** Unicode categories. */
typedef enum { typedef enum
UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */ {
UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */ UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */
UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */ UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */
UTF8PROC_CATEGORY_LT = 3, /**< Letter, titlecase */ UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */
UTF8PROC_CATEGORY_LM = 4, /**< Letter, modifier */ UTF8PROC_CATEGORY_LT = 3, /**< Letter, titlecase */
UTF8PROC_CATEGORY_LO = 5, /**< Letter, other */ UTF8PROC_CATEGORY_LM = 4, /**< Letter, modifier */
UTF8PROC_CATEGORY_MN = 6, /**< Mark, nonspacing */ UTF8PROC_CATEGORY_LO = 5, /**< Letter, other */
UTF8PROC_CATEGORY_MC = 7, /**< Mark, spacing combining */ UTF8PROC_CATEGORY_MN = 6, /**< Mark, nonspacing */
UTF8PROC_CATEGORY_ME = 8, /**< Mark, enclosing */ UTF8PROC_CATEGORY_MC = 7, /**< Mark, spacing combining */
UTF8PROC_CATEGORY_ND = 9, /**< Number, decimal digit */ UTF8PROC_CATEGORY_ME = 8, /**< Mark, enclosing */
UTF8PROC_CATEGORY_NL = 10, /**< Number, letter */ UTF8PROC_CATEGORY_ND = 9, /**< Number, decimal digit */
UTF8PROC_CATEGORY_NO = 11, /**< Number, other */ UTF8PROC_CATEGORY_NL = 10, /**< Number, letter */
UTF8PROC_CATEGORY_PC = 12, /**< Punctuation, connector */ UTF8PROC_CATEGORY_NO = 11, /**< Number, other */
UTF8PROC_CATEGORY_PD = 13, /**< Punctuation, dash */ UTF8PROC_CATEGORY_PC = 12, /**< Punctuation, connector */
UTF8PROC_CATEGORY_PS = 14, /**< Punctuation, open */ UTF8PROC_CATEGORY_PD = 13, /**< Punctuation, dash */
UTF8PROC_CATEGORY_PE = 15, /**< Punctuation, close */ UTF8PROC_CATEGORY_PS = 14, /**< Punctuation, open */
UTF8PROC_CATEGORY_PI = 16, /**< Punctuation, initial quote */ UTF8PROC_CATEGORY_PE = 15, /**< Punctuation, close */
UTF8PROC_CATEGORY_PF = 17, /**< Punctuation, final quote */ UTF8PROC_CATEGORY_PI = 16, /**< Punctuation, initial quote */
UTF8PROC_CATEGORY_PO = 18, /**< Punctuation, other */ UTF8PROC_CATEGORY_PF = 17, /**< Punctuation, final quote */
UTF8PROC_CATEGORY_SM = 19, /**< Symbol, math */ UTF8PROC_CATEGORY_PO = 18, /**< Punctuation, other */
UTF8PROC_CATEGORY_SC = 20, /**< Symbol, currency */ UTF8PROC_CATEGORY_SM = 19, /**< Symbol, math */
UTF8PROC_CATEGORY_SK = 21, /**< Symbol, modifier */ UTF8PROC_CATEGORY_SC = 20, /**< Symbol, currency */
UTF8PROC_CATEGORY_SO = 22, /**< Symbol, other */ UTF8PROC_CATEGORY_SK = 21, /**< Symbol, modifier */
UTF8PROC_CATEGORY_ZS = 23, /**< Separator, space */ UTF8PROC_CATEGORY_SO = 22, /**< Symbol, other */
UTF8PROC_CATEGORY_ZL = 24, /**< Separator, line */ UTF8PROC_CATEGORY_ZS = 23, /**< Separator, space */
UTF8PROC_CATEGORY_ZP = 25, /**< Separator, paragraph */ UTF8PROC_CATEGORY_ZL = 24, /**< Separator, line */
UTF8PROC_CATEGORY_CC = 26, /**< Other, control */ UTF8PROC_CATEGORY_ZP = 25, /**< Separator, paragraph */
UTF8PROC_CATEGORY_CF = 27, /**< Other, format */ UTF8PROC_CATEGORY_CC = 26, /**< Other, control */
UTF8PROC_CATEGORY_CS = 28, /**< Other, surrogate */ UTF8PROC_CATEGORY_CF = 27, /**< Other, format */
UTF8PROC_CATEGORY_CO = 29, /**< Other, private use */ UTF8PROC_CATEGORY_CS = 28, /**< Other, surrogate */
UTF8PROC_CATEGORY_CO = 29, /**< Other, private use */
} utf8proc_category_t; } utf8proc_category_t;
/** Bidirectional character classes. */ /** Bidirectional character classes. */
typedef enum { typedef enum
UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */ {
UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */ UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */
UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */ UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */
UTF8PROC_BIDI_CLASS_R = 4, /**< Right-to-Left */ UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */
UTF8PROC_BIDI_CLASS_AL = 5, /**< Right-to-Left Arabic */ UTF8PROC_BIDI_CLASS_R = 4, /**< Right-to-Left */
UTF8PROC_BIDI_CLASS_RLE = 6, /**< Right-to-Left Embedding */ UTF8PROC_BIDI_CLASS_AL = 5, /**< Right-to-Left Arabic */
UTF8PROC_BIDI_CLASS_RLO = 7, /**< Right-to-Left Override */ UTF8PROC_BIDI_CLASS_RLE = 6, /**< Right-to-Left Embedding */
UTF8PROC_BIDI_CLASS_PDF = 8, /**< Pop Directional Format */ UTF8PROC_BIDI_CLASS_RLO = 7, /**< Right-to-Left Override */
UTF8PROC_BIDI_CLASS_EN = 9, /**< European Number */ UTF8PROC_BIDI_CLASS_PDF = 8, /**< Pop Directional Format */
UTF8PROC_BIDI_CLASS_ES = 10, /**< European Separator */ UTF8PROC_BIDI_CLASS_EN = 9, /**< European Number */
UTF8PROC_BIDI_CLASS_ET = 11, /**< European Number Terminator */ UTF8PROC_BIDI_CLASS_ES = 10, /**< European Separator */
UTF8PROC_BIDI_CLASS_AN = 12, /**< Arabic Number */ UTF8PROC_BIDI_CLASS_ET = 11, /**< European Number Terminator */
UTF8PROC_BIDI_CLASS_CS = 13, /**< Common Number Separator */ UTF8PROC_BIDI_CLASS_AN = 12, /**< Arabic Number */
UTF8PROC_BIDI_CLASS_NSM = 14, /**< Nonspacing Mark */ UTF8PROC_BIDI_CLASS_CS = 13, /**< Common Number Separator */
UTF8PROC_BIDI_CLASS_BN = 15, /**< Boundary Neutral */ UTF8PROC_BIDI_CLASS_NSM = 14, /**< Nonspacing Mark */
UTF8PROC_BIDI_CLASS_B = 16, /**< Paragraph Separator */ UTF8PROC_BIDI_CLASS_BN = 15, /**< Boundary Neutral */
UTF8PROC_BIDI_CLASS_S = 17, /**< Segment Separator */ UTF8PROC_BIDI_CLASS_B = 16, /**< Paragraph Separator */
UTF8PROC_BIDI_CLASS_WS = 18, /**< Whitespace */ UTF8PROC_BIDI_CLASS_S = 17, /**< Segment Separator */
UTF8PROC_BIDI_CLASS_ON = 19, /**< Other Neutrals */ UTF8PROC_BIDI_CLASS_WS = 18, /**< Whitespace */
UTF8PROC_BIDI_CLASS_LRI = 20, /**< Left-to-Right Isolate */ UTF8PROC_BIDI_CLASS_ON = 19, /**< Other Neutrals */
UTF8PROC_BIDI_CLASS_RLI = 21, /**< Right-to-Left Isolate */ UTF8PROC_BIDI_CLASS_LRI = 20, /**< Left-to-Right Isolate */
UTF8PROC_BIDI_CLASS_FSI = 22, /**< First Strong Isolate */ UTF8PROC_BIDI_CLASS_RLI = 21, /**< Right-to-Left Isolate */
UTF8PROC_BIDI_CLASS_PDI = 23, /**< Pop Directional Isolate */ UTF8PROC_BIDI_CLASS_FSI = 22, /**< First Strong Isolate */
UTF8PROC_BIDI_CLASS_PDI = 23, /**< Pop Directional Isolate */
} utf8proc_bidi_class_t; } utf8proc_bidi_class_t;
/** Decomposition type. */ /** Decomposition type. */
typedef enum { typedef enum
UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */ {
UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */ UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */
UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */ UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */
UTF8PROC_DECOMP_TYPE_MEDIAL = 4, /**< Medial */ UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */
UTF8PROC_DECOMP_TYPE_FINAL = 5, /**< Final */ UTF8PROC_DECOMP_TYPE_MEDIAL = 4, /**< Medial */
UTF8PROC_DECOMP_TYPE_ISOLATED = 6, /**< Isolated */ UTF8PROC_DECOMP_TYPE_FINAL = 5, /**< Final */
UTF8PROC_DECOMP_TYPE_CIRCLE = 7, /**< Circle */ UTF8PROC_DECOMP_TYPE_ISOLATED = 6, /**< Isolated */
UTF8PROC_DECOMP_TYPE_SUPER = 8, /**< Super */ UTF8PROC_DECOMP_TYPE_CIRCLE = 7, /**< Circle */
UTF8PROC_DECOMP_TYPE_SUB = 9, /**< Sub */ UTF8PROC_DECOMP_TYPE_SUPER = 8, /**< Super */
UTF8PROC_DECOMP_TYPE_VERTICAL = 10, /**< Vertical */ UTF8PROC_DECOMP_TYPE_SUB = 9, /**< Sub */
UTF8PROC_DECOMP_TYPE_WIDE = 11, /**< Wide */ UTF8PROC_DECOMP_TYPE_VERTICAL = 10, /**< Vertical */
UTF8PROC_DECOMP_TYPE_NARROW = 12, /**< Narrow */ UTF8PROC_DECOMP_TYPE_WIDE = 11, /**< Wide */
UTF8PROC_DECOMP_TYPE_SMALL = 13, /**< Small */ UTF8PROC_DECOMP_TYPE_NARROW = 12, /**< Narrow */
UTF8PROC_DECOMP_TYPE_SQUARE = 14, /**< Square */ UTF8PROC_DECOMP_TYPE_SMALL = 13, /**< Small */
UTF8PROC_DECOMP_TYPE_FRACTION = 15, /**< Fraction */ UTF8PROC_DECOMP_TYPE_SQUARE = 14, /**< Square */
UTF8PROC_DECOMP_TYPE_COMPAT = 16, /**< Compat */ UTF8PROC_DECOMP_TYPE_FRACTION = 15, /**< Fraction */
UTF8PROC_DECOMP_TYPE_COMPAT = 16, /**< Compat */
} utf8proc_decomp_type_t; } utf8proc_decomp_type_t;
/** Boundclass property. (TR29) */ /** Boundclass property. (TR29) */
typedef enum { typedef enum
UTF8PROC_BOUNDCLASS_START = 0, /**< Start */ {
UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */ UTF8PROC_BOUNDCLASS_START = 0, /**< Start */
UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */ UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */
UTF8PROC_BOUNDCLASS_LF = 3, /**< Lf */ UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */
UTF8PROC_BOUNDCLASS_CONTROL = 4, /**< Control */ UTF8PROC_BOUNDCLASS_LF = 3, /**< Lf */
UTF8PROC_BOUNDCLASS_EXTEND = 5, /**< Extend */ UTF8PROC_BOUNDCLASS_CONTROL = 4, /**< Control */
UTF8PROC_BOUNDCLASS_L = 6, /**< L */ UTF8PROC_BOUNDCLASS_EXTEND = 5, /**< Extend */
UTF8PROC_BOUNDCLASS_V = 7, /**< V */ UTF8PROC_BOUNDCLASS_L = 6, /**< L */
UTF8PROC_BOUNDCLASS_T = 8, /**< T */ UTF8PROC_BOUNDCLASS_V = 7, /**< V */
UTF8PROC_BOUNDCLASS_LV = 9, /**< Lv */ UTF8PROC_BOUNDCLASS_T = 8, /**< T */
UTF8PROC_BOUNDCLASS_LVT = 10, /**< Lvt */ UTF8PROC_BOUNDCLASS_LV = 9, /**< Lv */
UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR = 11, /**< Regional indicator */ UTF8PROC_BOUNDCLASS_LVT = 10, /**< Lvt */
UTF8PROC_BOUNDCLASS_SPACINGMARK = 12, /**< Spacingmark */ UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR = 11, /**< Regional indicator */
UTF8PROC_BOUNDCLASS_PREPEND = 13, /**< Prepend */ UTF8PROC_BOUNDCLASS_SPACINGMARK = 12, /**< Spacingmark */
UTF8PROC_BOUNDCLASS_ZWJ = 14, /**< Zero Width Joiner */ UTF8PROC_BOUNDCLASS_PREPEND = 13, /**< Prepend */
UTF8PROC_BOUNDCLASS_ZWJ = 14, /**< Zero Width Joiner */
/* the following are no longer used in Unicode 11, but we keep
the constants here for backward compatibility */ /* the following are no longer used in Unicode 11, but we keep
UTF8PROC_BOUNDCLASS_E_BASE = 15, /**< Emoji Base */ the constants here for backward compatibility */
UTF8PROC_BOUNDCLASS_E_MODIFIER = 16, /**< Emoji Modifier */ UTF8PROC_BOUNDCLASS_E_BASE = 15, /**< Emoji Base */
UTF8PROC_BOUNDCLASS_GLUE_AFTER_ZWJ = 17, /**< Glue_After_ZWJ */ UTF8PROC_BOUNDCLASS_E_MODIFIER = 16, /**< Emoji Modifier */
UTF8PROC_BOUNDCLASS_E_BASE_GAZ = 18, /**< E_BASE + GLUE_AFTER_ZJW */ UTF8PROC_BOUNDCLASS_GLUE_AFTER_ZWJ = 17, /**< Glue_After_ZWJ */
UTF8PROC_BOUNDCLASS_E_BASE_GAZ = 18, /**< E_BASE + GLUE_AFTER_ZJW */
/* the Extended_Pictographic property is used in the Unicode 11
grapheme-boundary rules, so we store it in the boundclass field */ /* the Extended_Pictographic property is used in the Unicode 11
UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC = 19, grapheme-boundary rules, so we store it in the boundclass field */
UTF8PROC_BOUNDCLASS_E_ZWG = 20, /* UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC + ZWJ */ UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC = 19,
UTF8PROC_BOUNDCLASS_E_ZWG = 20, /* UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC + ZWJ */
} utf8proc_boundclass_t; } utf8proc_boundclass_t;
/** /**
...@@ -393,7 +412,7 @@ typedef enum { ...@@ -393,7 +412,7 @@ typedef enum {
* @ref utf8proc_decompose_custom, which is used to specify a user-defined * @ref utf8proc_decompose_custom, which is used to specify a user-defined
* mapping of codepoints to be applied in conjunction with other mappings. * mapping of codepoints to be applied in conjunction with other mappings.
*/ */
typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void *data); typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void* data);
/** /**
* Array containing the byte lengths of a UTF-8 encoded codepoint based * Array containing the byte lengths of a UTF-8 encoded codepoint based
...@@ -406,18 +425,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256]; ...@@ -406,18 +425,18 @@ UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256];
* (http://semver.org format), possibly with a "-dev" suffix for * (http://semver.org format), possibly with a "-dev" suffix for
* development versions. * development versions.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_version(void);
/** /**
* Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH. * Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH.
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void); UTF8PROC_DLLEXPORT const char* utf8proc_unicode_version(void);
/** /**
* Returns an informative error string for the given utf8proc error code * Returns an informative error string for the given utf8proc error code
* (e.g. the error codes returned by @ref utf8proc_map). * (e.g. the error codes returned by @ref utf8proc_map).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); UTF8PROC_DLLEXPORT const char* utf8proc_errmsg(utf8proc_ssize_t errcode);
/** /**
* Reads a single codepoint from the UTF-8 sequence being pointed to by `str`. * Reads a single codepoint from the UTF-8 sequence being pointed to by `str`.
...@@ -429,7 +448,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode); ...@@ -429,7 +448,9 @@ UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode);
* In case of success, the number of bytes read is returned; otherwise, a * In case of success, the number of bytes read is returned; otherwise, a
* negative error code is returned. * negative error code is returned.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *codepoint_ref); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t* str,
utf8proc_ssize_t strlen,
utf8proc_int32_t* codepoint_ref);
/** /**
* Check if a codepoint is valid (regardless of whether it has been * Check if a codepoint is valid (regardless of whether it has been
...@@ -448,7 +469,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep ...@@ -448,7 +469,8 @@ UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codep
* *
* This function does not check whether `codepoint` is valid Unicode. * This function does not check whether `codepoint` is valid Unicode.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint, utf8proc_uint8_t *dst); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint,
utf8proc_uint8_t* dst);
/** /**
* Look up the properties for a given codepoint. * Look up the properties for a given codepoint.
...@@ -462,7 +484,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo ...@@ -462,7 +484,7 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepo
* If the codepoint is unassigned or invalid, a pointer to a special struct is * If the codepoint is unassigned or invalid, a pointer to a special struct is
* returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN). * returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN).
*/ */
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const utf8proc_property_t* utf8proc_get_property(utf8proc_int32_t codepoint);
/** Decompose a codepoint into an array of codepoints. /** Decompose a codepoint into an array of codepoints.
* *
...@@ -492,10 +514,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int ...@@ -492,10 +514,11 @@ UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t codepoint,
utf8proc_int32_t codepoint, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_int32_t* dst,
utf8proc_option_t options, int *last_boundclass utf8proc_ssize_t bufsize,
); utf8proc_option_t options,
int* last_boundclass);
/** /**
* The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8 * The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8
...@@ -514,22 +537,26 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char( ...@@ -514,22 +537,26 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(
* required buffer size is returned, while the buffer will be overwritten with * required buffer size is returned, while the buffer will be overwritten with
* undefined data. * undefined data.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options utf8proc_int32_t* buffer,
); utf8proc_ssize_t bufsize,
utf8proc_option_t options);
/** /**
* The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping function * The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping
* that is called on each codepoint in `str` before any other transformations * function that is called on each codepoint in `str` before any other
* (along with a `custom_data` pointer that is passed through to `custom_func`). * transformations (along with a `custom_data` pointer that is passed through to
* The `custom_func` argument is ignored if it is `NULL`. See also @ref utf8proc_map_custom. * `custom_func`). The `custom_func` argument is ignored if it is `NULL`. See
* also @ref utf8proc_map_custom.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options, utf8proc_int32_t* buffer,
utf8proc_custom_func custom_func, void *custom_data utf8proc_ssize_t bufsize,
); utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** /**
* Normalizes the sequence of `length` codepoints pointed to by `buffer` * Normalizes the sequence of `length` codepoints pointed to by `buffer`
...@@ -541,20 +568,24 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom( ...@@ -541,20 +568,24 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS * - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS * - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF * - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters * - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control
* characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite * - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints * codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate * - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
* the unicode versioning stability * the unicode versioning stability
* *
* @return * @return
* In case of success, the length (in codepoints) of the normalized UTF-32 string is * In case of success, the length (in codepoints) of the normalized UTF-32
* returned; otherwise, a negative error code is returned (@ref utf8proc_errmsg). * string is returned; otherwise, a negative error code is returned (@ref
* utf8proc_errmsg).
* *
* @warning The entries of the array pointed to by `str` have to be in the * @warning The entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Reencodes the sequence of `length` codepoints pointed to by `buffer` * Reencodes the sequence of `length` codepoints pointed to by `buffer`
...@@ -567,7 +598,8 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -567,7 +598,8 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS * - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS * - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF * - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters * - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control
* characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite * - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints * codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate * - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
...@@ -584,35 +616,39 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b ...@@ -584,35 +616,39 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *b
* entries of the array pointed to by `str` have to be in the * entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash! * range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options); UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t* buffer,
utf8proc_ssize_t length,
utf8proc_option_t options);
/** /**
* Given a pair of consecutive codepoints, return whether a grapheme break is * Given a pair of consecutive codepoints, return whether a grapheme break is
* permitted between them (as defined by the extended grapheme clusters in UAX#29). * permitted between them (as defined by the extended grapheme clusters in
* UAX#29).
* *
* @param codepoint1 The first codepoint. * @param codepoint1 The first codepoint.
* @param codepoint2 The second codepoint, occurring consecutively after `codepoint1`. * @param codepoint2 The second codepoint, occurring consecutively after
* @param state Beginning with Version 29 (Unicode 9.0.0), this algorithm requires * `codepoint1`.
* state to break graphemes. This state can be passed in as a pointer * @param state Beginning with Version 29 (Unicode 9.0.0), this algorithm
* requires state to break graphemes. This state can be passed in as a pointer
* in the `state` argument and should initially be set to 0. If the * in the `state` argument and should initially be set to 0. If the
* state is not passed in (i.e. a null pointer is passed), UAX#29 rules * state is not passed in (i.e. a null pointer is passed), UAX#29
* GB10/12/13 which require this state will not be applied, essentially * rules GB10/12/13 which require this state will not be applied, essentially
* matching the rules in Unicode 8.0.0. * matching the rules in Unicode 8.0.0.
* *
* @warning If the state parameter is used, `utf8proc_grapheme_break_stateful` must * @warning If the state parameter is used, `utf8proc_grapheme_break_stateful`
* be called IN ORDER on ALL potential breaks in a string. However, it * must be called IN ORDER on ALL potential breaks in a string. However, it is
* is safe to reset the state to zero after a grapheme break. * safe to reset the state to zero after a grapheme break.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2, utf8proc_int32_t *state); utf8proc_int32_t codepoint2,
utf8proc_int32_t* state);
/** /**
* Same as @ref utf8proc_grapheme_break_stateful, except without support for the * Same as @ref utf8proc_grapheme_break_stateful, except without support for the
* Unicode 9 additions to the algorithm. Supported for legacy reasons. * Unicode 9 additions to the algorithm. Supported for legacy reasons.
*/ */
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break( UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(utf8proc_int32_t codepoint1,
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2); utf8proc_int32_t codepoint2);
/** /**
* Given a codepoint `c`, return the codepoint of the corresponding * Given a codepoint `c`, return the codepoint of the corresponding
...@@ -636,21 +672,21 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c); ...@@ -636,21 +672,21 @@ UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c);
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c); UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c);
/** /**
* Given a codepoint `c`, return `1` if the codepoint corresponds to a lower-case character * Given a codepoint `c`, return `1` if the codepoint corresponds to a
* and `0` otherwise. * lower-case character and `0` otherwise.
*/ */
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c); UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c);
/** /**
* Given a codepoint `c`, return `1` if the codepoint corresponds to an upper-case character * Given a codepoint `c`, return `1` if the codepoint corresponds to an
* and `0` otherwise. * upper-case character and `0` otherwise.
*/ */
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c); UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c);
/** /**
* Given a codepoint, return a character width analogous to `wcwidth(codepoint)`, * Given a codepoint, return a character width analogous to
* except that a width of 0 is returned for non-printable codepoints * `wcwidth(codepoint)`, except that a width of 0 is returned for non-printable
* instead of -1 as in `wcwidth`. * codepoints instead of -1 as in `wcwidth`.
* *
* @note * @note
* If you want to check for particular types of non-printable characters, * If you want to check for particular types of non-printable characters,
...@@ -667,7 +703,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo ...@@ -667,7 +703,7 @@ UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepo
* Return the two-letter (nul-terminated) Unicode category string for * Return the two-letter (nul-terminated) Unicode category string for
* the codepoint (e.g. `"Lu"` or `"Co"`). * the codepoint (e.g. `"Lu"` or `"Co"`).
*/ */
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoint); UTF8PROC_DLLEXPORT const char* utf8proc_category_string(utf8proc_int32_t codepoint);
/** /**
* Maps the given UTF-8 string pointed to by `str` to a new UTF-8 * Maps the given UTF-8 string pointed to by `str` to a new UTF-8
...@@ -688,9 +724,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi ...@@ -688,9 +724,10 @@ UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoi
* @note The memory of the new UTF-8 string will have been allocated * @note The memory of the new UTF-8 string will have been allocated
* with `malloc`, and should therefore be deallocated with `free`. * with `malloc`, and should therefore be deallocated with `free`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options utf8proc_ssize_t strlen,
); utf8proc_uint8_t** dstptr,
utf8proc_option_t options);
/** /**
* Like @ref utf8proc_map, but also takes a `custom_func` mapping function * Like @ref utf8proc_map, but also takes a `custom_func` mapping function
...@@ -698,10 +735,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map( ...@@ -698,10 +735,12 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(
* (along with a `custom_data` pointer that is passed through to `custom_func`). * (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`. * The `custom_func` argument is ignored if it is `NULL`.
*/ */
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(const utf8proc_uint8_t* str,
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options, utf8proc_ssize_t strlen,
utf8proc_custom_func custom_func, void *custom_data utf8proc_uint8_t** dstptr,
); utf8proc_option_t options,
utf8proc_custom_func custom_func,
void* custom_data);
/** @name Unicode normalization /** @name Unicode normalization
* *
...@@ -712,18 +751,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom( ...@@ -712,18 +751,18 @@ UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(
*/ */
/** @{ */ /** @{ */
/** NFD normalization (@ref UTF8PROC_DECOMPOSE). */ /** NFD normalization (@ref UTF8PROC_DECOMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFD(const utf8proc_uint8_t* str);
/** NFC normalization (@ref UTF8PROC_COMPOSE). */ /** NFC normalization (@ref UTF8PROC_COMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFC(const utf8proc_uint8_t* str);
/** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKD(const utf8proc_uint8_t* str);
/** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */ /** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC(const utf8proc_uint8_t* str);
/** /**
* NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT * NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT
* and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE). * and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE).
**/ **/
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str); UTF8PROC_DLLEXPORT utf8proc_uint8_t* utf8proc_NFKC_Casefold(const utf8proc_uint8_t* str);
/** @} */ /** @} */
#ifdef __cplusplus #ifdef __cplusplus
......
#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <Bert.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h> #include <tokenization.h>
int main(int argc, char *argv[]) int main(int argc, char* argv[])
{ {
// 加载Bert模型 // 加载Bert模型
migraphxSamples::Bert bert; migraphxSamples::Bert bert;
migraphxSamples::ErrorCode errorCode = bert.Initialize(); migraphxSamples::ErrorCode errorCode = bert.Initialize();
if (errorCode != migraphxSamples::SUCCESS) if(errorCode != migraphxSamples::SUCCESS)
{ {
LOG_ERROR(stdout, "fail to initialize Bert!\n"); LOG_ERROR(stdout, "fail to initialize Bert!\n");
exit(-1); exit(-1);
} }
LOG_INFO(stdout, "succeed to initialize Bert\n"); LOG_INFO(stdout, "succeed to initialize Bert\n");
int max_seq_length = 256; // 滑动窗口的长度 int max_seq_length = 256; // 滑动窗口的长度
int max_query_length = 64; // 问题的最大长度 int max_query_length = 64; // 问题的最大长度
int batch_size = 1; // batch_size值 int batch_size = 1; // batch_size值
int n_best_size = 20; // 索引数量 int n_best_size = 20; // 索引数量
int max_answer_length = 30; // 答案的最大长度 int max_answer_length = 30; // 答案的最大长度
// 上下文文本数据 // 上下文文本数据
const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." }; const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
u8"computing that’s also programming-language independent. It brings a "
u8"philosophy of choice, minimalism and modular software development to "
u8"GPU computing. You are free to choose or even develop tools and a "
u8"language run time for your application. ROCm is built for scale, it "
u8"supports multi-GPU computing and has a rich system run time with the "
u8"critical features that large-scale application, compiler and "
u8"language-run-time development requires. Since the ROCm ecosystem is "
u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
u8"platform is continually optimized for performance and extensibility."};
char question[100]; char question[100];
std::vector<std::vector<long unsigned int>> input_ids; std::vector<std::vector<long unsigned int>> input_ids;
...@@ -35,14 +46,22 @@ int main(int argc, char *argv[]) ...@@ -35,14 +46,22 @@ int main(int argc, char *argv[])
std::vector<float> end_position; std::vector<float> end_position;
std::string answer = {}; std::string answer = {};
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具 cuBERT::FullTokenizer tokenizer =
cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
while (true) while(true)
{ {
// 数据前处理 // 数据前处理
std::cout << "question: "; std::cout << "question: ";
cin.getline(question, 100); cin.getline(question, 100);
bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids); bert.Preprocessing(tokenizer,
batch_size,
max_seq_length,
text,
question,
input_ids,
input_masks,
segment_ids);
// 推理 // 推理
bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position); bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
...@@ -61,6 +80,6 @@ int main(int argc, char *argv[]) ...@@ -61,6 +80,6 @@ int main(int argc, char *argv[])
end_position.clear(); end_position.clear();
answer = {}; answer = {};
} }
return 0; return 0;
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment