Commit 1d125612 authored by liucong8560's avatar liucong8560
Browse files

Merge branch 'develop' into 'master'

Develop

See merge request modelzoo/bert_migraphx!1
parents b8fc3c49 d5a26d95
...@@ -6,40 +6,108 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一 ...@@ -6,40 +6,108 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一
## 模型结构 ## 模型结构
以往的预训练模型的结构会受到单向语言模型(从左到右或者从右到左)的限制,因而也限制了模型的表征能力,使其只能获取单方向的上下文信息。而BERT利用MLM进行预训练并且采用深层的双向Transformer组件(单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token)来构建整个模型,因此最终生成能融合左右上下文信息的深层双向语言表征。 以往的预训练模型的结构会受到单向语言模型(从左到右或者从右到左)的限制,因而也限制了模型的表征能力,使其只能获取单方向的上下文信息。而BERT利用MLM进行预训练并且采用深层的双向Transformer组件(单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token)来构建整个模型,因此最终生成能融合左右上下文信息的深层双向语言表征。
## 推理 ## 构建安装
### 环境配置
在光源可拉取推理的docker镜像,BERT模型推理的镜像如下: 在光源可拉取推理的docker镜像,BERT模型推理的镜像如下:
```python ```python
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort_dcu_1.14.0_migraphx2.5.2_dtk22.10.1 docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort1.14.0_migraphx3.0.0-dtk22.10.1
``` ```
在光合开发者社区可下载MIGraphX安装包,python依赖安装: ### 安装Opencv依赖
```python ```python
pip install -r requirement.txt cd <path_to_migraphx_samples>
sh ./3rdParty/InstallOpenCVDependences.sh
```
### 修改CMakeLists.txt
- 如果使用ubuntu系统,需要修改CMakeLists.txt中依赖库路径:
将"${CMAKE_CURRENT_SOURCE_DIR}/depend/lib64/"修改为"${CMAKE_CURRENT_SOURCE_DIR}/depend/lib/"
- **MIGraphX2.3.0及以上版本需要c++17**
### 安装OpenCV并构建工程
```
rbuild build -d depend
```
### 设置环境变量
将依赖库依赖加入环境变量LD_LIBRARY_PATH,在~/.bashrc中添加如下语句:
**Centos**:
```
export LD_LIBRARY_PATH=<path_to_migraphx_samples>/depend/lib64/:$LD_LIBRARY_PATH
``` ```
本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在model文件夹下。 **Ubuntu**:
### 运行示例 ```
export LD_LIBRARY_PATH=<path_to_migraphx_samples>/depend/lib/:$LD_LIBRARY_PATH
```
我们提供了基于MIGraphX的推理脚本,版本依赖: 然后执行:
- Migraphx(DCU版本) >= 2.5.2 ```
source ~/.bashrc
```
## 推理
bert.py是基于Migraphx的推理脚本,使用需安装好MIGraphX。使用方法: 本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource\Models\NLP\Bert文件夹下。下面介绍如何运行python代码和C++代码示例,具体推理代码解析,在Doc目录中有详细说明。
### python版本推理
1.参考《MIGraphX教程》中的安装方法安装MIGraphX并设置好PYTHONPATH
2.安装依赖:
```python
# 进入migraphx samples工程根目录
cd <path_to_migraphx_samples>
# 进入示例程序目录
cd Python/NLP/Bert
# 安装依赖
pip install -r requirements.txt
```
3.在Python/NLP/Bert目录下执行如下命令运行该示例程序:
```python ```python
# 执行推理
python bert.py python bert.py
``` ```
推理结果为: 输出结果为:
<img src="./Doc/Images/Bert_05.png" style="zoom:90%;" align=middle>
输出结果中,问题id对应预测概率值最大的答案。
### C++版本推理
切换到build目录中,执行如下命令:
```python
cd ./build/
./MIGraphX_Samples
```
根据提示选择运行BERT模型的示例程序
```python
./MIGraphX_Samples 0
```
如下所示,会在当前界面中提示输入问题,根据问题得到预测答案
<img src="./Sample_picture.png" style="zoom:100%;" align=middle> <img src="./Doc/Images/Bert_06.png" style="zoom:100%;" align=middle>
## 历史版本 ## 历史版本
......
<?xml version="1.0" encoding="GB2312"?>
<opencv_storage>
<!--Bert-->
<Bert>
<ModelPath>"../Resource/Models/NLP/Bert/bertsquad-10.onnx"</ModelPath>
</Bert>
</opencv_storage>
#include <fstream>
#include <sstream>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <CommonUtility.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <algorithm>
#include <string>
#include <vector>
#include <stdexcept>
#include <Bert.h>
#include <tokenization.h>
namespace migraphxSamples
{
Bert::Bert():logFile(NULL)
{
}
Bert::~Bert()
{
configurationFile.release();
}
ErrorCode Bert::Initialize(InitializationParameterOfNLP initParamOfNLPBert)
{
// 初始化(获取日志文件,加载配置文件等)
ErrorCode errorCode=DoCommonInitialization(initParamOfNLPBert);
if(errorCode!=SUCCESS)
{
LOG_ERROR(logFile,"fail to DoCommonInitialization\n");
return errorCode;
}
LOG_INFO(logFile,"succeed to DoCommonInitialization\n");
// 获取配置文件参数
FileNode netNode = configurationFile["Bert"];
std::string modelPath=initializationParameter.parentPath+(std::string)netNode["ModelPath"];
// 加载模型
if(Exists(modelPath)==false)
{
LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
}
net = migraphx::parse_onnx(modelPath);
LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性
std::unordered_map<std::string, migraphx::shape> input = net.get_parameter_shapes();
inputName1 = "unique_ids_raw_output___9:0";
inputShape1 = input.at(inputName1);
inputName2 = "segment_ids:0";
inputShape2 = input.at(inputName2);
inputName3 = "input_mask:0";
inputShape3 = input.at(inputName3);
inputName4 = "input_ids:0";
inputShape4 = input.at(inputName4);
// 设置模型为GPU模式
migraphx::target gpuTarget = migraphx::gpu::target{};
// 编译模型
migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; // 设置offload_copy
net.compile(gpuTarget,options);
LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
// Run once by itself
migraphx::parameter_map inputData;
inputData[inputName1]=migraphx::generate_argument(inputShape1);
inputData[inputName2]=migraphx::generate_argument(inputShape2);
inputData[inputName3]=migraphx::generate_argument(inputShape3);
inputData[inputName4]=migraphx::generate_argument(inputShape4);
net.eval(inputData);
return SUCCESS;
}
ErrorCode Bert::DoCommonInitialization(InitializationParameterOfNLP initParamOfNLPBert)
{
initializationParameter = initParamOfNLPBert;
// 获取日志文件
logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName);
// 加载配置文件
std::string configFilePath=initializationParameter.configFilePath;
if(!Exists(configFilePath))
{
LOG_ERROR(logFile, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
}
if(!configurationFile.open(configFilePath, FileStorage::READ))
{
LOG_ERROR(logFile, "fail to open configuration file\n");
return FAIL_TO_OPEN_CONFIG_FILE;
}
LOG_INFO(logFile, "succeed to open configuration file\n");
// 修改父路径
std::string &parentPath = initializationParameter.parentPath;
if (!parentPath.empty())
{
if(!IsPathSeparator(parentPath[parentPath.size() - 1]))
{
parentPath+=PATH_SEPARATOR;
}
}
return SUCCESS;
}
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids,
std::vector<float> &start_position,
std::vector<float> &end_position)
{
// 保存预处理后的数据
int num = input_ids.size();
long unsigned int input_id[num][256];
long unsigned int input_mask[num][256];
long unsigned int segment_id[num][256];
long unsigned int position_id[num][1];
for(int i=0;i<input_ids.size();++i)
{
for(int j=0;j<input_ids[0].size();++j)
{
input_id[i][j] = input_ids[i][j];
segment_id[i][j] = segment_ids[i][j];
input_mask[i][j] = input_masks[i][j];
position_id[i][0] = 1;
}
}
migraphx::parameter_map inputData;
std::vector<migraphx::argument> results;
migraphx::argument start_prediction;
migraphx::argument end_prediction;
float* start_data;
float* end_data;
for(int i=0;i<input_ids.size();++i)
{
// 输入数据
inputData[inputName1]=migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
inputData[inputName2]=migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
inputData[inputName3]=migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
inputData[inputName4]=migraphx::argument{inputShape4, (long unsigned int*)input_id[i]};
// 推理
results = net.eval(inputData);
// 获取输出节点的属性
start_prediction = results[1]; // 答案的开始位置
start_data = (float *)start_prediction.data(); // 开始位置的数据指针
end_prediction = results[0]; // 答案的结束位置
end_data = (float *)end_prediction.data(); // 结束位置的数据指针
// 保存推理结果
for(int i=0;i<256;++i)
{
start_position.push_back(start_data[i]);
end_position.push_back(end_data[i]);
}
}
return SUCCESS;
}
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size,
int max_seq_length,
const char *text,
char *question,
std::vector<std::vector<long unsigned int>> &input_ids,
std::vector<std::vector<long unsigned int>> &input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids)
{
std::vector<long unsigned int> input_id(max_seq_length);
std::vector<long unsigned int> input_mask(max_seq_length);
std::vector<long unsigned int> segment_id(max_seq_length);
// 对上下文文本和问题进行分词操作
tokens_text.reserve(max_seq_length);
tokens_question.reserve(max_seq_length);
tokenizer.tokenize(text, &tokens_text, max_seq_length);
tokenizer.tokenize(question, &tokens_question, max_seq_length);
// 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作
if(tokens_text.size() + tokens_question.size() > max_seq_length - 5)
{
int windows_len = max_seq_length - 5 -tokens_question.size();
std::vector<std::string> tokens_text_window(windows_len);
std::vector<std::vector<std::string>> tokens_text_windows;
int start_offset = 0;
int position = 0;
int n;
while (start_offset < tokens_text.size())
{
n = 0;
if(start_offset+windows_len>tokens_text.size())
{
for(int i=start_offset;i<tokens_text.size();++i)
{
tokens_text_window[n] = tokens_text[i];
++n;
}
}
else
{
for(int i=start_offset;i<start_offset+windows_len;++i)
{
tokens_text_window[n] = tokens_text[i];
++n;
}
}
tokens_text_windows.push_back(tokens_text_window);
start_offset += 256;
++position;
}
for(int i=0;i<position;++i)
{
input_id[0] = tokenizer.convert_token_to_id("[CLS]");
segment_id[0] = 0;
input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0;
for (int j=0;j<tokens_question.size();++j)
{
input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]);
segment_id[j + 2] = 0;
}
input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 2] = 0;
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0;
for (int j=0;j<tokens_question.size();++j)
{
input_id[j + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id(tokens_text_windows[i][j]);
segment_id[j + tokens_text_windows[i].size() + 4] = 1;
}
input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。
int len = tokens_text_windows[i].size() + tokens_question.size() + 5;
std::fill(input_mask.begin(), input_mask.begin() + len, 1);
std::fill(input_mask.begin() + len, input_mask.begin() + max_seq_length, 0);
std::fill(input_id.begin() + len, input_id.begin() + max_seq_length, 0);
std::fill(segment_id.begin() + len, segment_id.begin() + max_seq_length, 0);
input_ids.push_back(input_id);
input_masks.push_back(input_mask);
segment_ids.push_back(segment_id);
}
}
else
{
// 当上下文文本加问题文本的长度小于等于规定的最大长度,直接拼接处理
input_id[0] = tokenizer.convert_token_to_id("[CLS]");
segment_id[0] = 0;
input_id[1] = tokenizer.convert_token_to_id("[CLS]");
segment_id[1] = 0;
for (int i=0;i<tokens_question.size();++i)
{
input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]);
segment_id[i + 2] = 0;
}
input_id[tokens_question.size() + 2] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 2] = 0;
input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + 3] = 0;
for (int i=0;i<tokens_text.size();++i)
{
input_id[i + tokens_question.size() + 4] = tokenizer.convert_token_to_id(tokens_text[i]);
segment_id[i + tokens_question.size() + 4] = 1;
}
input_id[tokens_question.size() + tokens_text.size() + 4] = tokenizer.convert_token_to_id("[SEP]");
segment_id[tokens_question.size() + tokens_text.size() + 4] = 1;
// 掩码为1的表示为真实标记,0表示为填充标记。
int len = tokens_text.size() + tokens_question.size() + 5;
std::fill(input_mask.begin(), input_mask.begin() + len, 1);
std::fill(input_mask.begin() + len, input_mask.begin() + max_seq_length, 0);
std::fill(input_id.begin() + len, input_id.begin() + max_seq_length, 0);
std::fill(segment_id.begin() + len, segment_id.begin() + max_seq_length, 0);
input_ids.push_back(input_id);
input_masks.push_back(input_mask);
segment_ids.push_back(segment_id);
}
return SUCCESS;
}
static bool Compare(Sort_st a, Sort_st b)
{
return a.value > b.value;
}
static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{
return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue;
}
ErrorCode Bert::Postprocessing(int n_best_size,
int max_answer_length,
const std::vector<float> &start_position,
const std::vector<float> &end_position,
std::string &answer)
{
// 取前n_best_size个最大概率值的索引
std::vector<Sort_st> start_array(start_position.size());
std::vector<Sort_st> end_array(end_position.size());
for (int i=0;i<start_position.size();++i)
{
start_array[i].index = i;
start_array[i].value = start_position.at(i);
end_array[i].index = i;
end_array[i].value = end_position.at(i);
}
std::sort(start_array.begin(), start_array.end(), Compare);
std::sort(end_array.begin(), end_array.end(), Compare);
// 过滤和筛选,筛选掉不符合的索引
std::vector<ResultOfPredictions> resultsOfPredictions(400);
int num = start_position.size() / 256;
bool flag;
int n=0;
for(int i=0;i<n_best_size;++i)
{
for(int j=0;j<n_best_size;++j)
{
flag = false;
if(start_array[i].index > start_position.size())
{
continue;
}
if(end_array[j].index > end_position.size())
{
continue;
}
for(int t=0;t<num;++t)
{
if(start_array[i].index > t*256 && start_array[i].index < tokens_question.size()+4+t*256)
{
flag = true;
break;
}
if(end_array[j].index > t*256 && end_array[j].index < tokens_question.size()+4+t*256)
{
flag = true;
break;
}
}
if(flag)
{
continue;
}
if(start_array[i].index > end_array[j].index)
{
continue;
}
int length = end_array[j].index - start_array[i].index + 1;
if(length > max_answer_length)
{
continue;
}
resultsOfPredictions[n].start_index = start_array[i].index;
resultsOfPredictions[n].end_index = end_array[j].index;
resultsOfPredictions[n].start_predictionvalue = start_array[i].value;
resultsOfPredictions[n].end_predictionvalue = end_array[j].value;
++n;
}
}
// 排序,将开始索引加结束索引的概率值和最大的排在前面
std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);
int start_index = 0;
int end_index = 0;
for(int i=0;i<400;++i)
{
if(resultsOfPredictions[i].start_predictionvalue==0 && resultsOfPredictions[i].end_predictionvalue==0)
{
continue;
}
start_index = resultsOfPredictions[i].start_index;
end_index = resultsOfPredictions[i].end_index;
break;
}
// 映射回上下文文本的索引,(当前的索引值-问题的长度-4)
int answer_start_index = start_index - tokens_question.size()- 4;
int answer_end_index = end_index - tokens_question.size() - 4 + 1;
// 根据开始索引和结束索引,获取区间内的数据
int j=0;
for(int i=answer_start_index;i<answer_end_index;++i)
{
if(tokens_text[i].find('#') != -1)
{
j=i-1;
break;
}
}
for(int i=answer_start_index;i<answer_end_index;++i)
{
answer += tokens_text[i];
if(tokens_text[i].find('#') != -1 || i==j)
{
continue;
}
answer += " ";
}
int index = 0;
while( (index = answer.find('#',index)) != string::npos)
{
answer.erase(index,1);
}
tokens_text.clear();
tokens_question.clear();
return SUCCESS;
}
}
#ifndef BERT_H
#define BERT_H
#include <cstdint>
#include <string>
#include <migraphx/program.hpp>
#include <CommonDefinition.h>
#include <tokenization.h>
using namespace cuBERT;
namespace migraphxSamples
{
typedef struct _Sort_st
{
int index;
float value;
}Sort_st;
typedef struct _ResultOfPredictions
{
int start_index;
int end_index;
float start_predictionvalue;
float end_predictionvalue;
}ResultOfPredictions;
class Bert
{
public:
Bert();
~Bert();
ErrorCode Initialize(InitializationParameterOfNLP initParamOfNLPBert);
ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids,
std::vector<float> &start_position,
std::vector<float> &end_position);
ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
int batch_size,
int max_seq_length,
const char *text,
char *question,
std::vector<std::vector<long unsigned int>> &input_ids,
std::vector<std::vector<long unsigned int>> &input_masks,
std::vector<std::vector<long unsigned int>> &segment_ids);
ErrorCode Postprocessing(int n_best_size,
int max_answer_length,
const std::vector<float> &start_position,
const std::vector<float> &end_position,
std::string &answer);
private:
ErrorCode DoCommonInitialization(InitializationParameterOfNLP initParamOfNLPBert);
private:
FILE *logFile;
cv::FileStorage configurationFile;
InitializationParameterOfNLP initializationParameter;
std::vector<std::string> tokens_text;
std::vector<std::string> tokens_question;
migraphx::program net;
std::string inputName1;
std::string inputName2;
std::string inputName3;
std::string inputName4;
migraphx::shape inputShape1;
migraphx::shape inputShape2;
migraphx::shape inputShape3;
migraphx::shape inputShape4;
};
}
#endif
\ No newline at end of file
#include <stdexcept>
#include <algorithm>
#include <cstring>
#include <fstream>
#include "utf8proc.h"
#include "./tokenization.h"
namespace cuBERT {
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) {
for (int i = 0; i < tokens.size(); ++i) {
ids[i] = convert_token_to_id(tokens[i]);
}
}
// trim from start (in place)
static inline void ltrim(std::string &s) {
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) {
return !std::isspace(ch);
}));
}
// trim from end (in place)
static inline void rtrim(std::string &s) {
s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) {
return !std::isspace(ch);
}).base(), s.end());
}
// trim from both ends (in place)
static inline void trim(std::string &s) {
ltrim(s);
rtrim(s);
}
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) {
std::ifstream file(vocab_file);
if (!file) {
throw std::invalid_argument("Unable to open vocab file");
}
unsigned int index = 0;
std::string line;
while (std::getline(file, line)) {
trim(line);
(*vocab)[line] = index;
index++;
}
file.close();
}
inline bool _is_whitespace(int c, const char *cat) {
if (c == ' ' || c == '\t' || c == '\n' || c == '\r') {
return true;
}
return cat[0] == 'Z' && cat[1] == 's';
}
inline bool _is_control(int c, const char *cat) {
// These are technically control characters but we count them as whitespace characters.
if (c == '\t' || c == '\n' || c == '\r') {
return false;
}
return 'C' == *cat;
}
inline bool _is_punctuation(int cp, const char *cat) {
// We treat all non-letter/number ASCII as punctuation.
// Characters such as "^", "$", and "`" are not in the Unicode
// Punctuation class but we treat them as punctuation anyways, for
// consistency.
if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) ||
(cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
return true;
}
return 'P' == *cat;
}
bool _is_whitespace(int c) {
return _is_whitespace(c, utf8proc_category_string(c));
}
bool _is_control(int c) {
return _is_control(c, utf8proc_category_string(c));
}
bool _is_punctuation(int cp) {
return _is_punctuation(cp, utf8proc_category_string(cp));
}
bool BasicTokenizer::_is_chinese_char(int cp) {
// This defines a "chinese character" as anything in the CJK Unicode block:
// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
//
// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
// despite its name. The modern Korean Hangul alphabet is a different block,
// as is Japanese Hiragana and Katakana. Those alphabets are used to write
// space-separated words, so they are not treated specially and handled
// like the all of the other languages.
return (cp >= 0x4E00 && cp <= 0x9FFF) ||
(cp >= 0x3400 && cp <= 0x4DBF) ||
(cp >= 0x20000 && cp <= 0x2A6DF) ||
(cp >= 0x2A700 && cp <= 0x2B73F) ||
(cp >= 0x2B740 && cp <= 0x2B81F) ||
(cp >= 0x2B820 && cp <= 0x2CEAF) ||
(cp >= 0xF900 && cp <= 0xFAFF) ||
(cp >= 0x2F800 && cp <= 0x2FA1F);
}
void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if (do_lower_case) {
text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text);
}
size_t word_bytes = std::strlen(text);
bool new_token = true;
size_t subpos = 0;
int cp;
char dst[4];
while (word_bytes > 0) {
int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp);
if (len < 0) {
std::cerr << "UTF-8 decode error: " << text << std::endl;
break;
}
if (do_lower_case) {
cp = utf8proc_tolower(cp);
}
const char *cat = utf8proc_category_string(cp);
if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) {
// pass
} else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') {
// pass
} else if (_is_whitespace(cp, cat)) {
new_token = true;
} else {
size_t dst_len = len;
const char *dst_ptr = text + subpos;
if (do_lower_case) {
dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst);
dst_ptr = dst;
}
if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) {
output_tokens->emplace_back(dst_ptr, dst_len);
new_token = true;
} else {
if (new_token) {
output_tokens->emplace_back(dst_ptr, dst_len);
new_token = false;
} else {
output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
}
}
}
word_bytes = word_bytes - len;
subpos = subpos + len;
// early terminate
if (output_tokens->size() >= max_length) {
break;
}
}
if (do_lower_case) {
free((void *) text);
}
}
void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) {
if (token.size() > max_input_chars_per_word) { // FIXME: slightly different
output_tokens->push_back(unk_token);
return;
}
size_t output_tokens_len = output_tokens->size();
for (size_t start = 0; start < token.size();) {
bool is_bad = true;
// TODO: can be optimized by prefix-tree
for (size_t end = token.size(); start < end; --end) { // FIXME: slightly different
std::string substr = start > 0
? "##" + token.substr(start, end - start)
: token.substr(start, end - start);
if (vocab->count(substr)) {
is_bad = false;
output_tokens->push_back(substr);
start = end;
break;
}
}
if (is_bad) {
output_tokens->resize(output_tokens_len);
output_tokens->push_back(unk_token);
return;
}
}
}
void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
std::vector<std::string> tokens;
tokens.reserve(max_length);
basic_tokenizer->tokenize(text, &tokens, max_length);
for (const auto &token : tokens) {
wordpiece_tokenizer->tokenize(token, output_tokens);
// early terminate
if (output_tokens->size() >= max_length) {
break;
}
}
}
}
#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H
#include <string>
#include <vector>
#include <unordered_map>
#include <iostream>
namespace cuBERT {
void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab);
/**
* Checks whether `chars` is a whitespace character.
* @param c
* @return
*/
bool _is_whitespace(int c);
/**
* Checks whether `chars` is a control character.
* @param c
* @return
*/
bool _is_control(int c);
/**
* Checks whether `chars` is a punctuation character.
* @param cp
* @return
*/
bool _is_punctuation(int cp);
/**
* Runs basic tokenization (punctuation splitting, lower casing, etc.).
*/
class BasicTokenizer {
public:
/**
* Constructs a BasicTokenizer.
* @param do_lower_case Whether to lower case the input.
*/
explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}
BasicTokenizer(const BasicTokenizer &other) = delete;
virtual ~BasicTokenizer() = default;
/**
* Tokenizes a piece of text.
*
* to_lower
* _run_strip_accents Strips accents from a piece of text.
* _clean_text Performs invalid character removal and whitespace cleanup on text.
* _tokenize_chinese_chars Adds whitespace around any CJK character.
* _run_split_on_punc Splits punctuation on a piece of text.
* whitespace_tokenize Runs basic whitespace cleaning and splitting on a piece of text.
*
* @param text
* @param output_tokens
*/
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);
private:
const bool do_lower_case;
/**
* Checks whether CP is the codepoint of a CJK character.
* @param cp
* @return
*/
inline static bool _is_chinese_char(int cp);
};
/**
* Runs WordPiece tokenziation.
*/
class WordpieceTokenizer {
public:
explicit WordpieceTokenizer(
std::unordered_map<std::string, uint64_t> *vocab,
std::string unk_token = "[UNK]",
int max_input_chars_per_word = 200
) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {}
WordpieceTokenizer(const WordpieceTokenizer &other) = delete;
virtual ~WordpieceTokenizer() = default;
/**
* Tokenizes a piece of text into its word pieces.
*
* This uses a greedy longest-match-first algorithm to perform tokenization
* using the given vocabulary.
*
* For example:
* input = "unaffable"
* output = ["un", "##aff", "##able"]
*
* @param text A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer.
* @param output_tokens A list of wordpiece tokens.
*/
void tokenize(const std::string &text, std::vector<std::string> *output_tokens);
private:
const std::unordered_map<std::string, uint64_t> *vocab;
const std::string unk_token;
const int max_input_chars_per_word;
};
/**
* Runs end-to-end tokenziation.
*/
class FullTokenizer {
public:
FullTokenizer(const char *vocab_file, bool do_lower_case = true) {
vocab = new std::unordered_map<std::string, uint64_t>();
load_vocab(vocab_file, vocab);
basic_tokenizer = new BasicTokenizer(do_lower_case);
wordpiece_tokenizer = new WordpieceTokenizer(vocab);
}
~FullTokenizer() {
if (wordpiece_tokenizer != NULL){
wordpiece_tokenizer = NULL;
}
delete wordpiece_tokenizer;
if (basic_tokenizer != NULL){
basic_tokenizer = NULL;
}
delete basic_tokenizer;
if (vocab != NULL){
vocab = NULL;
}
delete vocab;
}
void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);
inline uint64_t convert_token_to_id(const std::string &token) {
auto item = vocab->find(token);
if (item == vocab->end()) {
std::cerr << "vocab missing key: " << token << std::endl;
return 0;
} else {
return item->second;
}
}
void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids);
private:
std::unordered_map<std::string, uint64_t> *vocab;
BasicTokenizer *basic_tokenizer;
WordpieceTokenizer *wordpiece_tokenizer;
};
}
#endif //CUBERT_TOKENIZATION_H
/* -*- mode: c; c-basic-offset: 2; tab-width: 2; indent-tabs-mode: nil -*- */
/*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors.
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
/*
* This library contains derived data from a modified version of the
* Unicode data files.
*
* The original data files are available at
* https://www.unicode.org/Public/UNIDATA/
*
* Please notice the copyright statement in the file "utf8proc_data.c".
*/
/*
* File name: utf8proc.c
*
* Description:
* Implementation of libutf8proc.
*/
#include "utf8proc.h"
#ifndef SSIZE_MAX
#define SSIZE_MAX ((size_t)SIZE_MAX/2)
#endif
#ifndef UINT16_MAX
# define UINT16_MAX 65535U
#endif
#include "utf8proc_data.c"
UTF8PROC_DLLEXPORT const utf8proc_int8_t utf8proc_utf8class[256] = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0 };
#define UTF8PROC_HANGUL_SBASE 0xAC00
#define UTF8PROC_HANGUL_LBASE 0x1100
#define UTF8PROC_HANGUL_VBASE 0x1161
#define UTF8PROC_HANGUL_TBASE 0x11A7
#define UTF8PROC_HANGUL_LCOUNT 19
#define UTF8PROC_HANGUL_VCOUNT 21
#define UTF8PROC_HANGUL_TCOUNT 28
#define UTF8PROC_HANGUL_NCOUNT 588
#define UTF8PROC_HANGUL_SCOUNT 11172
/* END is exclusive */
#define UTF8PROC_HANGUL_L_START 0x1100
#define UTF8PROC_HANGUL_L_END 0x115A
#define UTF8PROC_HANGUL_L_FILLER 0x115F
#define UTF8PROC_HANGUL_V_START 0x1160
#define UTF8PROC_HANGUL_V_END 0x11A3
#define UTF8PROC_HANGUL_T_START 0x11A8
#define UTF8PROC_HANGUL_T_END 0x11FA
#define UTF8PROC_HANGUL_S_START 0xAC00
#define UTF8PROC_HANGUL_S_END 0xD7A4
/* Should follow semantic-versioning rules (semver.org) based on API
compatibility. (Note that the shared-library version number will
be different, being based on ABI compatibility.): */
#define STRINGIZEx(x) #x
#define STRINGIZE(x) STRINGIZEx(x)
UTF8PROC_DLLEXPORT const char *utf8proc_version(void) {
return STRINGIZE(UTF8PROC_VERSION_MAJOR) "." STRINGIZE(UTF8PROC_VERSION_MINOR) "." STRINGIZE(UTF8PROC_VERSION_PATCH) "";
}
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void) {
return "15.0.0";
}
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode) {
switch (errcode) {
case UTF8PROC_ERROR_NOMEM:
return "Memory for processing UTF-8 data could not be allocated.";
case UTF8PROC_ERROR_OVERFLOW:
return "UTF-8 string is too long to be processed.";
case UTF8PROC_ERROR_INVALIDUTF8:
return "Invalid UTF-8 string";
case UTF8PROC_ERROR_NOTASSIGNED:
return "Unassigned Unicode code point found in UTF-8 string.";
case UTF8PROC_ERROR_INVALIDOPTS:
return "Invalid options for UTF-8 processing chosen.";
default:
return "An unknown error occurred while processing UTF-8 data.";
}
}
#define utf_cont(ch) (((ch) & 0xc0) == 0x80)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *dst
) {
utf8proc_int32_t uc;
const utf8proc_uint8_t *end;
*dst = -1;
if (!strlen) return 0;
end = str + ((strlen < 0) ? 4 : strlen);
uc = *str++;
if (uc < 0x80) {
*dst = uc;
return 1;
}
// Must be between 0xc2 and 0xf4 inclusive to be valid
if ((utf8proc_uint32_t)(uc - 0xc2) > (0xf4-0xc2)) return UTF8PROC_ERROR_INVALIDUTF8;
if (uc < 0xe0) { // 2-byte sequence
// Must have valid continuation character
if (str >= end || !utf_cont(*str)) return UTF8PROC_ERROR_INVALIDUTF8;
*dst = ((uc & 0x1f)<<6) | (*str & 0x3f);
return 2;
}
if (uc < 0xf0) { // 3-byte sequence
if ((str + 1 >= end) || !utf_cont(*str) || !utf_cont(str[1]))
return UTF8PROC_ERROR_INVALIDUTF8;
// Check for surrogate chars
if (uc == 0xed && *str > 0x9f)
return UTF8PROC_ERROR_INVALIDUTF8;
uc = ((uc & 0xf)<<12) | ((*str & 0x3f)<<6) | (str[1] & 0x3f);
if (uc < 0x800)
return UTF8PROC_ERROR_INVALIDUTF8;
*dst = uc;
return 3;
}
// 4-byte sequence
// Must have 3 valid continuation characters
if ((str + 2 >= end) || !utf_cont(*str) || !utf_cont(str[1]) || !utf_cont(str[2]))
return UTF8PROC_ERROR_INVALIDUTF8;
// Make sure in correct range (0x10000 - 0x10ffff)
if (uc == 0xf0) {
if (*str < 0x90) return UTF8PROC_ERROR_INVALIDUTF8;
} else if (uc == 0xf4) {
if (*str > 0x8f) return UTF8PROC_ERROR_INVALIDUTF8;
}
*dst = ((uc & 7)<<18) | ((*str & 0x3f)<<12) | ((str[1] & 0x3f)<<6) | (str[2] & 0x3f);
return 4;
}
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t uc) {
return (((utf8proc_uint32_t)uc)-0xd800 > 0x07ff) && ((utf8proc_uint32_t)uc < 0x110000);
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) {
if (uc < 0x00) {
return 0;
} else if (uc < 0x80) {
dst[0] = (utf8proc_uint8_t) uc;
return 1;
} else if (uc < 0x800) {
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 2;
// Note: we allow encoding 0xd800-0xdfff here, so as not to change
// the API, however, these are actually invalid in UTF-8
} else if (uc < 0x10000) {
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3;
} else if (uc < 0x110000) {
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4;
} else return 0;
}
/* internal version used for inserting 0xff bytes between graphemes */
static utf8proc_ssize_t charbound_encode_char(utf8proc_int32_t uc, utf8proc_uint8_t *dst) {
if (uc < 0x00) {
if (uc == -1) { /* internal value used for grapheme breaks */
dst[0] = (utf8proc_uint8_t)0xFF;
return 1;
}
return 0;
} else if (uc < 0x80) {
dst[0] = (utf8proc_uint8_t)uc;
return 1;
} else if (uc < 0x800) {
dst[0] = (utf8proc_uint8_t)(0xC0 + (uc >> 6));
dst[1] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 2;
} else if (uc < 0x10000) {
dst[0] = (utf8proc_uint8_t)(0xE0 + (uc >> 12));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 3;
} else if (uc < 0x110000) {
dst[0] = (utf8proc_uint8_t)(0xF0 + (uc >> 18));
dst[1] = (utf8proc_uint8_t)(0x80 + ((uc >> 12) & 0x3F));
dst[2] = (utf8proc_uint8_t)(0x80 + ((uc >> 6) & 0x3F));
dst[3] = (utf8proc_uint8_t)(0x80 + (uc & 0x3F));
return 4;
} else return 0;
}
/* internal "unsafe" version that does not check whether uc is in range */
static const utf8proc_property_t *unsafe_get_property(utf8proc_int32_t uc) {
/* ASSERT: uc >= 0 && uc < 0x110000 */
return utf8proc_properties + (
utf8proc_stage2table[
utf8proc_stage1table[uc >> 8] + (uc & 0xFF)
]
);
}
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t uc) {
return uc < 0 || uc >= 0x110000 ? utf8proc_properties : unsafe_get_property(uc);
}
/* return whether there is a grapheme break between boundclasses lbc and tbc
(according to the definition of extended grapheme clusters)
Rule numbering refers to TR29 Version 29 (Unicode 9.0.0):
http://www.unicode.org/reports/tr29/tr29-29.html
CAVEATS:
Please note that evaluation of GB10 (grapheme breaks between emoji zwj sequences)
and GB 12/13 (regional indicator code points) require knowledge of previous characters
and are thus not handled by this function. This may result in an incorrect break before
an E_Modifier class codepoint and an incorrectly missing break between two
REGIONAL_INDICATOR class code points if such support does not exist in the caller.
See the special support in grapheme_break_extended, for required bookkeeping by the caller.
*/
static utf8proc_bool grapheme_break_simple(int lbc, int tbc) {
return
(lbc == UTF8PROC_BOUNDCLASS_START) ? true : // GB1
(lbc == UTF8PROC_BOUNDCLASS_CR && // GB3
tbc == UTF8PROC_BOUNDCLASS_LF) ? false : // ---
(lbc >= UTF8PROC_BOUNDCLASS_CR && lbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB4
(tbc >= UTF8PROC_BOUNDCLASS_CR && tbc <= UTF8PROC_BOUNDCLASS_CONTROL) ? true : // GB5
(lbc == UTF8PROC_BOUNDCLASS_L && // GB6
(tbc == UTF8PROC_BOUNDCLASS_L || // ---
tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_LV || // ---
tbc == UTF8PROC_BOUNDCLASS_LVT)) ? false : // ---
((lbc == UTF8PROC_BOUNDCLASS_LV || // GB7
lbc == UTF8PROC_BOUNDCLASS_V) && // ---
(tbc == UTF8PROC_BOUNDCLASS_V || // ---
tbc == UTF8PROC_BOUNDCLASS_T)) ? false : // ---
((lbc == UTF8PROC_BOUNDCLASS_LVT || // GB8
lbc == UTF8PROC_BOUNDCLASS_T) && // ---
tbc == UTF8PROC_BOUNDCLASS_T) ? false : // ---
(tbc == UTF8PROC_BOUNDCLASS_EXTEND || // GB9
tbc == UTF8PROC_BOUNDCLASS_ZWJ || // ---
tbc == UTF8PROC_BOUNDCLASS_SPACINGMARK || // GB9a
lbc == UTF8PROC_BOUNDCLASS_PREPEND) ? false : // GB9b
(lbc == UTF8PROC_BOUNDCLASS_E_ZWG && // GB11 (requires additional handling below)
tbc == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) ? false : // ----
(lbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR && // GB12/13 (requires additional handling below)
tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR) ? false : // ----
true; // GB999
}
static utf8proc_bool grapheme_break_extended(int lbc, int tbc, utf8proc_int32_t *state)
{
if (state) {
int lbc_override;
if (*state == UTF8PROC_BOUNDCLASS_START)
*state = lbc_override = lbc;
else
lbc_override = *state;
utf8proc_bool break_permitted = grapheme_break_simple(lbc_override, tbc);
// Special support for GB 12/13 made possible by GB999. After two RI
// class codepoints we want to force a break. Do this by resetting the
// second RI's bound class to UTF8PROC_BOUNDCLASS_OTHER, to force a break
// after that character according to GB999 (unless of course such a break is
// forbidden by a different rule such as GB9).
if (*state == tbc && tbc == UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR)
*state = UTF8PROC_BOUNDCLASS_OTHER;
// Special support for GB11 (emoji extend* zwj / emoji)
else if (*state == UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC) {
if (tbc == UTF8PROC_BOUNDCLASS_EXTEND) // fold EXTEND codepoints into emoji
*state = UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC;
else if (tbc == UTF8PROC_BOUNDCLASS_ZWJ)
*state = UTF8PROC_BOUNDCLASS_E_ZWG; // state to record emoji+zwg combo
else
*state = tbc;
}
else
*state = tbc;
return break_permitted;
}
else
return grapheme_break_simple(lbc, tbc);
}
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(
utf8proc_int32_t c1, utf8proc_int32_t c2, utf8proc_int32_t *state) {
return grapheme_break_extended(utf8proc_get_property(c1)->boundclass,
utf8proc_get_property(c2)->boundclass,
state);
}
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(
utf8proc_int32_t c1, utf8proc_int32_t c2) {
return utf8proc_grapheme_break_stateful(c1, c2, NULL);
}
static utf8proc_int32_t seqindex_decode_entry(const utf8proc_uint16_t **entry)
{
utf8proc_int32_t entry_cp = **entry;
if ((entry_cp & 0xF800) == 0xD800) {
*entry = *entry + 1;
entry_cp = ((entry_cp & 0x03FF) << 10) | (**entry & 0x03FF);
entry_cp += 0x10000;
}
return entry_cp;
}
static utf8proc_int32_t seqindex_decode_index(const utf8proc_uint32_t seqindex)
{
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex];
return seqindex_decode_entry(&entry);
}
static utf8proc_ssize_t seqindex_write_char_decomposed(utf8proc_uint16_t seqindex, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) {
utf8proc_ssize_t written = 0;
const utf8proc_uint16_t *entry = &utf8proc_sequences[seqindex & 0x3FFF];
int len = seqindex >> 14;
if (len >= 3) {
len = *entry;
entry++;
}
for (; len >= 0; entry++, len--) {
utf8proc_int32_t entry_cp = seqindex_decode_entry(&entry);
written += utf8proc_decompose_char(entry_cp, dst+written,
(bufsize > written) ? (bufsize - written) : 0, options,
last_boundclass);
if (written < 0) return UTF8PROC_ERROR_OVERFLOW;
}
return written;
}
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_tolower(utf8proc_int32_t c)
{
utf8proc_int32_t cl = utf8proc_get_property(c)->lowercase_seqindex;
return cl != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cl) : c;
}
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c)
{
utf8proc_int32_t cu = utf8proc_get_property(c)->uppercase_seqindex;
return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c;
}
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c)
{
utf8proc_int32_t cu = utf8proc_get_property(c)->titlecase_seqindex;
return cu != UINT16_MAX ? seqindex_decode_index((utf8proc_uint32_t)cu) : c;
}
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c)
{
const utf8proc_property_t *p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->lowercase_seqindex == UINT16_MAX;
}
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c)
{
const utf8proc_property_t *p = utf8proc_get_property(c);
return p->lowercase_seqindex != p->uppercase_seqindex && p->uppercase_seqindex == UINT16_MAX && p->category != UTF8PROC_CATEGORY_LT;
}
/* return a character width analogous to wcwidth (except portable and
hopefully less buggy than most system wcwidth functions). */
UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t c) {
return utf8proc_get_property(c)->charwidth;
}
UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t c) {
return (utf8proc_category_t) utf8proc_get_property(c)->category;
}
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t c) {
static const char s[][3] = {"Cn","Lu","Ll","Lt","Lm","Lo","Mn","Mc","Me","Nd","Nl","No","Pc","Pd","Ps","Pe","Pi","Pf","Po","Sm","Sc","Sk","So","Zs","Zl","Zp","Cc","Cf","Cs","Co"};
return s[utf8proc_category(c)];
}
#define utf8proc_decompose_lump(replacement_uc) \
return utf8proc_decompose_char((replacement_uc), dst, bufsize, \
options & ~(unsigned int)UTF8PROC_LUMP, last_boundclass)
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(utf8proc_int32_t uc, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize, utf8proc_option_t options, int *last_boundclass) {
const utf8proc_property_t *property;
utf8proc_propval_t category;
utf8proc_int32_t hangul_sindex;
if (uc < 0 || uc >= 0x110000) return UTF8PROC_ERROR_NOTASSIGNED;
property = unsafe_get_property(uc);
category = property->category;
hangul_sindex = uc - UTF8PROC_HANGUL_SBASE;
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) {
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT) {
utf8proc_int32_t hangul_tindex;
if (bufsize >= 1) {
dst[0] = UTF8PROC_HANGUL_LBASE +
hangul_sindex / UTF8PROC_HANGUL_NCOUNT;
if (bufsize >= 2) dst[1] = UTF8PROC_HANGUL_VBASE +
(hangul_sindex % UTF8PROC_HANGUL_NCOUNT) / UTF8PROC_HANGUL_TCOUNT;
}
hangul_tindex = hangul_sindex % UTF8PROC_HANGUL_TCOUNT;
if (!hangul_tindex) return 2;
if (bufsize >= 3) dst[2] = UTF8PROC_HANGUL_TBASE + hangul_tindex;
return 3;
}
}
if (options & UTF8PROC_REJECTNA) {
if (!category) return UTF8PROC_ERROR_NOTASSIGNED;
}
if (options & UTF8PROC_IGNORE) {
if (property->ignorable) return 0;
}
if (options & UTF8PROC_STRIPNA) {
if (!category) return 0;
}
if (options & UTF8PROC_LUMP) {
if (category == UTF8PROC_CATEGORY_ZS) utf8proc_decompose_lump(0x0020);
if (uc == 0x2018 || uc == 0x2019 || uc == 0x02BC || uc == 0x02C8)
utf8proc_decompose_lump(0x0027);
if (category == UTF8PROC_CATEGORY_PD || uc == 0x2212)
utf8proc_decompose_lump(0x002D);
if (uc == 0x2044 || uc == 0x2215) utf8proc_decompose_lump(0x002F);
if (uc == 0x2236) utf8proc_decompose_lump(0x003A);
if (uc == 0x2039 || uc == 0x2329 || uc == 0x3008)
utf8proc_decompose_lump(0x003C);
if (uc == 0x203A || uc == 0x232A || uc == 0x3009)
utf8proc_decompose_lump(0x003E);
if (uc == 0x2216) utf8proc_decompose_lump(0x005C);
if (uc == 0x02C4 || uc == 0x02C6 || uc == 0x2038 || uc == 0x2303)
utf8proc_decompose_lump(0x005E);
if (category == UTF8PROC_CATEGORY_PC || uc == 0x02CD)
utf8proc_decompose_lump(0x005F);
if (uc == 0x02CB) utf8proc_decompose_lump(0x0060);
if (uc == 0x2223) utf8proc_decompose_lump(0x007C);
if (uc == 0x223C) utf8proc_decompose_lump(0x007E);
if ((options & UTF8PROC_NLF2LS) && (options & UTF8PROC_NLF2PS)) {
if (category == UTF8PROC_CATEGORY_ZL ||
category == UTF8PROC_CATEGORY_ZP)
utf8proc_decompose_lump(0x000A);
}
}
if (options & UTF8PROC_STRIPMARK) {
if (category == UTF8PROC_CATEGORY_MN ||
category == UTF8PROC_CATEGORY_MC ||
category == UTF8PROC_CATEGORY_ME) return 0;
}
if (options & UTF8PROC_CASEFOLD) {
if (property->casefold_seqindex != UINT16_MAX) {
return seqindex_write_char_decomposed(property->casefold_seqindex, dst, bufsize, options, last_boundclass);
}
}
if (options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) {
if (property->decomp_seqindex != UINT16_MAX &&
(!property->decomp_type || (options & UTF8PROC_COMPAT))) {
return seqindex_write_char_decomposed(property->decomp_seqindex, dst, bufsize, options, last_boundclass);
}
}
if (options & UTF8PROC_CHARBOUND) {
utf8proc_bool boundary;
int tbc = property->boundclass;
boundary = grapheme_break_extended(*last_boundclass, tbc, last_boundclass);
if (boundary) {
if (bufsize >= 1) dst[0] = -1; /* sentinel value for grapheme break */
if (bufsize >= 2) dst[1] = uc;
return 2;
}
}
if (bufsize >= 1) *dst = uc;
return 1;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options
) {
return utf8proc_decompose_custom(str, strlen, buffer, bufsize, options, NULL, NULL);
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options,
utf8proc_custom_func custom_func, void *custom_data
) {
/* strlen will be ignored, if UTF8PROC_NULLTERM is set in options */
utf8proc_ssize_t wpos = 0;
if ((options & UTF8PROC_COMPOSE) && (options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS;
if ((options & UTF8PROC_STRIPMARK) &&
!(options & UTF8PROC_COMPOSE) && !(options & UTF8PROC_DECOMPOSE))
return UTF8PROC_ERROR_INVALIDOPTS;
{
utf8proc_int32_t uc;
utf8proc_ssize_t rpos = 0;
utf8proc_ssize_t decomp_result;
int boundclass = UTF8PROC_BOUNDCLASS_START;
while (1) {
if (options & UTF8PROC_NULLTERM) {
rpos += utf8proc_iterate(str + rpos, -1, &uc);
/* checking of return value is not necessary,
as 'uc' is < 0 in case of error */
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8;
if (rpos < 0) return UTF8PROC_ERROR_OVERFLOW;
if (uc == 0) break;
} else {
if (rpos >= strlen) break;
rpos += utf8proc_iterate(str + rpos, strlen - rpos, &uc);
if (uc < 0) return UTF8PROC_ERROR_INVALIDUTF8;
}
if (custom_func != NULL) {
uc = custom_func(uc, custom_data); /* user-specified custom mapping */
}
decomp_result = utf8proc_decompose_char(
uc, buffer + wpos, (bufsize > wpos) ? (bufsize - wpos) : 0, options,
&boundclass
);
if (decomp_result < 0) return decomp_result;
wpos += decomp_result;
/* prohibiting integer overflows due to too long strings: */
if (wpos < 0 ||
wpos > (utf8proc_ssize_t)(SSIZE_MAX/sizeof(utf8proc_int32_t)/2))
return UTF8PROC_ERROR_OVERFLOW;
}
}
if ((options & (UTF8PROC_COMPOSE|UTF8PROC_DECOMPOSE)) && bufsize >= wpos) {
utf8proc_ssize_t pos = 0;
while (pos < wpos-1) {
utf8proc_int32_t uc1, uc2;
const utf8proc_property_t *property1, *property2;
uc1 = buffer[pos];
uc2 = buffer[pos+1];
property1 = unsafe_get_property(uc1);
property2 = unsafe_get_property(uc2);
if (property1->combining_class > property2->combining_class &&
property2->combining_class > 0) {
buffer[pos] = uc2;
buffer[pos+1] = uc1;
if (pos > 0) pos--; else pos++;
} else {
pos++;
}
}
}
return wpos;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) {
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored */
if (options & (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS | UTF8PROC_STRIPCC)) {
utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0;
utf8proc_int32_t uc;
for (rpos = 0; rpos < length; rpos++) {
uc = buffer[rpos];
if (uc == 0x000D && rpos < length-1 && buffer[rpos+1] == 0x000A) rpos++;
if (uc == 0x000A || uc == 0x000D || uc == 0x0085 ||
((options & UTF8PROC_STRIPCC) && (uc == 0x000B || uc == 0x000C))) {
if (options & UTF8PROC_NLF2LS) {
if (options & UTF8PROC_NLF2PS) {
buffer[wpos++] = 0x000A;
} else {
buffer[wpos++] = 0x2028;
}
} else {
if (options & UTF8PROC_NLF2PS) {
buffer[wpos++] = 0x2029;
} else {
buffer[wpos++] = 0x0020;
}
}
} else if ((options & UTF8PROC_STRIPCC) &&
(uc < 0x0020 || (uc >= 0x007F && uc < 0x00A0))) {
if (uc == 0x0009) buffer[wpos++] = 0x0020;
} else {
buffer[wpos++] = uc;
}
}
length = wpos;
}
if (options & UTF8PROC_COMPOSE) {
utf8proc_int32_t *starter = NULL;
utf8proc_int32_t current_char;
const utf8proc_property_t *starter_property = NULL, *current_property;
utf8proc_propval_t max_combining_class = -1;
utf8proc_ssize_t rpos;
utf8proc_ssize_t wpos = 0;
utf8proc_int32_t composition;
for (rpos = 0; rpos < length; rpos++) {
current_char = buffer[rpos];
current_property = unsafe_get_property(current_char);
if (starter && current_property->combining_class > max_combining_class) {
/* combination perhaps possible */
utf8proc_int32_t hangul_lindex;
utf8proc_int32_t hangul_sindex;
hangul_lindex = *starter - UTF8PROC_HANGUL_LBASE;
if (hangul_lindex >= 0 && hangul_lindex < UTF8PROC_HANGUL_LCOUNT) {
utf8proc_int32_t hangul_vindex;
hangul_vindex = current_char - UTF8PROC_HANGUL_VBASE;
if (hangul_vindex >= 0 && hangul_vindex < UTF8PROC_HANGUL_VCOUNT) {
*starter = UTF8PROC_HANGUL_SBASE +
(hangul_lindex * UTF8PROC_HANGUL_VCOUNT + hangul_vindex) *
UTF8PROC_HANGUL_TCOUNT;
starter_property = NULL;
continue;
}
}
hangul_sindex = *starter - UTF8PROC_HANGUL_SBASE;
if (hangul_sindex >= 0 && hangul_sindex < UTF8PROC_HANGUL_SCOUNT &&
(hangul_sindex % UTF8PROC_HANGUL_TCOUNT) == 0) {
utf8proc_int32_t hangul_tindex;
hangul_tindex = current_char - UTF8PROC_HANGUL_TBASE;
if (hangul_tindex >= 0 && hangul_tindex < UTF8PROC_HANGUL_TCOUNT) {
*starter += hangul_tindex;
starter_property = NULL;
continue;
}
}
if (!starter_property) {
starter_property = unsafe_get_property(*starter);
}
if (starter_property->comb_index < 0x8000 &&
current_property->comb_index != UINT16_MAX &&
current_property->comb_index >= 0x8000) {
int sidx = starter_property->comb_index;
int idx = current_property->comb_index & 0x3FFF;
if (idx >= utf8proc_combinations[sidx] && idx <= utf8proc_combinations[sidx + 1] ) {
idx += sidx + 2 - utf8proc_combinations[sidx];
if (current_property->comb_index & 0x4000) {
composition = (utf8proc_combinations[idx] << 16) | utf8proc_combinations[idx+1];
} else
composition = utf8proc_combinations[idx];
if (composition > 0 && (!(options & UTF8PROC_STABLE) ||
!(unsafe_get_property(composition)->comp_exclusion))) {
*starter = composition;
starter_property = NULL;
continue;
}
}
}
}
buffer[wpos] = current_char;
if (current_property->combining_class) {
if (current_property->combining_class > max_combining_class) {
max_combining_class = current_property->combining_class;
}
} else {
starter = buffer + wpos;
starter_property = NULL;
max_combining_class = -1;
}
wpos++;
}
length = wpos;
}
return length;
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options) {
/* UTF8PROC_NULLTERM option will be ignored, 'length' is never ignored
ASSERT: 'buffer' has one spare byte of free space at the end! */
length = utf8proc_normalize_utf32(buffer, length, options);
if (length < 0) return length;
{
utf8proc_ssize_t rpos, wpos = 0;
utf8proc_int32_t uc;
if (options & UTF8PROC_CHARBOUND) {
for (rpos = 0; rpos < length; rpos++) {
uc = buffer[rpos];
wpos += charbound_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos);
}
} else {
for (rpos = 0; rpos < length; rpos++) {
uc = buffer[rpos];
wpos += utf8proc_encode_char(uc, ((utf8proc_uint8_t *)buffer) + wpos);
}
}
((utf8proc_uint8_t *)buffer)[wpos] = 0;
return wpos;
}
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options
) {
return utf8proc_map_custom(str, strlen, dstptr, options, NULL, NULL);
}
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options,
utf8proc_custom_func custom_func, void *custom_data
) {
utf8proc_int32_t *buffer;
utf8proc_ssize_t result;
*dstptr = NULL;
result = utf8proc_decompose_custom(str, strlen, NULL, 0, options, custom_func, custom_data);
if (result < 0) return result;
buffer = (utf8proc_int32_t *) malloc(((utf8proc_size_t)result) * sizeof(utf8proc_int32_t) + 1);
if (!buffer) return UTF8PROC_ERROR_NOMEM;
result = utf8proc_decompose_custom(str, strlen, buffer, result, options, custom_func, custom_data);
if (result < 0) {
free(buffer);
return result;
}
result = utf8proc_reencode(buffer, result, options);
if (result < 0) {
free(buffer);
return result;
}
{
utf8proc_int32_t *newptr;
newptr = (utf8proc_int32_t *) realloc(buffer, (size_t)result+1);
if (newptr) buffer = newptr;
}
*dstptr = (utf8proc_uint8_t *)buffer;
return result;
}
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str) {
utf8proc_uint8_t *retval;
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE |
UTF8PROC_DECOMPOSE);
return retval;
}
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str) {
utf8proc_uint8_t *retval;
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE |
UTF8PROC_COMPOSE);
return retval;
}
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str) {
utf8proc_uint8_t *retval;
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE |
UTF8PROC_DECOMPOSE | UTF8PROC_COMPAT);
return retval;
}
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str) {
utf8proc_uint8_t *retval;
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE |
UTF8PROC_COMPOSE | UTF8PROC_COMPAT);
return retval;
}
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str) {
utf8proc_uint8_t *retval;
utf8proc_map(str, 0, &retval, UTF8PROC_NULLTERM | UTF8PROC_STABLE |
UTF8PROC_COMPOSE | UTF8PROC_COMPAT | UTF8PROC_CASEFOLD | UTF8PROC_IGNORE);
return retval;
}
/*
* Copyright (c) 2014-2021 Steven G. Johnson, Jiahao Chen, Peter Colberg, Tony Kelman, Scott P. Jones, and other contributors.
* Copyright (c) 2009 Public Software Group e. V., Berlin, Germany
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
/**
* @mainpage
*
* utf8proc is a free/open-source (MIT/expat licensed) C library
* providing Unicode normalization, case-folding, and other operations
* for strings in the UTF-8 encoding, supporting up-to-date Unicode versions.
* See the utf8proc home page (http://julialang.org/utf8proc/)
* for downloads and other information, or the source code on github
* (https://github.com/JuliaLang/utf8proc).
*
* For the utf8proc API documentation, see: @ref utf8proc.h
*
* The features of utf8proc include:
*
* - Transformation of strings (@ref utf8proc_map) to:
* - decompose (@ref UTF8PROC_DECOMPOSE) or compose (@ref UTF8PROC_COMPOSE) Unicode combining characters (http://en.wikipedia.org/wiki/Combining_character)
* - canonicalize Unicode compatibility characters (@ref UTF8PROC_COMPAT)
* - strip "ignorable" (@ref UTF8PROC_IGNORE) characters, control characters (@ref UTF8PROC_STRIPCC), or combining characters such as accents (@ref UTF8PROC_STRIPMARK)
* - case-folding (@ref UTF8PROC_CASEFOLD)
* - Unicode normalization: @ref utf8proc_NFD, @ref utf8proc_NFC, @ref utf8proc_NFKD, @ref utf8proc_NFKC
* - Detecting grapheme boundaries (@ref utf8proc_grapheme_break and @ref UTF8PROC_CHARBOUND)
* - Character-width computation: @ref utf8proc_charwidth
* - Classification of characters by Unicode category: @ref utf8proc_category and @ref utf8proc_category_string
* - Encode (@ref utf8proc_encode_char) and decode (@ref utf8proc_iterate) Unicode codepoints to/from UTF-8.
*/
/** @file */
#ifndef UTF8PROC_H
#define UTF8PROC_H
/** @name API version
*
* The utf8proc API version MAJOR.MINOR.PATCH, following
* semantic-versioning rules (http://semver.org) based on API
* compatibility.
*
* This is also returned at runtime by @ref utf8proc_version; however, the
* runtime version may append a string like "-dev" to the version number
* for prerelease versions.
*
* @note The shared-library version number in the Makefile
* (and CMakeLists.txt, and MANIFEST) may be different,
* being based on ABI compatibility rather than API compatibility.
*/
/** @{ */
/** The MAJOR version number (increased when backwards API compatibility is broken). */
#define UTF8PROC_VERSION_MAJOR 2
/** The MINOR version number (increased when new functionality is added in a backwards-compatible manner). */
#define UTF8PROC_VERSION_MINOR 8
/** The PATCH version (increased for fixes that do not change the API). */
#define UTF8PROC_VERSION_PATCH 0
/** @} */
#include <stdlib.h>
#if defined(_MSC_VER) && _MSC_VER < 1800
// MSVC prior to 2013 lacked stdbool.h and inttypes.h
typedef signed char utf8proc_int8_t;
typedef unsigned char utf8proc_uint8_t;
typedef short utf8proc_int16_t;
typedef unsigned short utf8proc_uint16_t;
typedef int utf8proc_int32_t;
typedef unsigned int utf8proc_uint32_t;
# ifdef _WIN64
typedef __int64 utf8proc_ssize_t;
typedef unsigned __int64 utf8proc_size_t;
# else
typedef int utf8proc_ssize_t;
typedef unsigned int utf8proc_size_t;
# endif
# ifndef __cplusplus
// emulate C99 bool
typedef unsigned char utf8proc_bool;
# ifndef __bool_true_false_are_defined
# define false 0
# define true 1
# define __bool_true_false_are_defined 1
# endif
# else
typedef bool utf8proc_bool;
# endif
#else
# include <stddef.h>
# include <stdbool.h>
# include <inttypes.h>
typedef int8_t utf8proc_int8_t;
typedef uint8_t utf8proc_uint8_t;
typedef int16_t utf8proc_int16_t;
typedef uint16_t utf8proc_uint16_t;
typedef int32_t utf8proc_int32_t;
typedef uint32_t utf8proc_uint32_t;
typedef size_t utf8proc_size_t;
typedef ptrdiff_t utf8proc_ssize_t;
typedef bool utf8proc_bool;
#endif
#include <limits.h>
#ifdef UTF8PROC_STATIC
# define UTF8PROC_DLLEXPORT
#else
# ifdef _WIN32
# ifdef UTF8PROC_EXPORTS
# define UTF8PROC_DLLEXPORT __declspec(dllexport)
# else
# define UTF8PROC_DLLEXPORT __declspec(dllimport)
# endif
# elif __GNUC__ >= 4
# define UTF8PROC_DLLEXPORT __attribute__ ((visibility("default")))
# else
# define UTF8PROC_DLLEXPORT
# endif
#endif
#ifdef __cplusplus
extern "C" {
#endif
/**
* Option flags used by several functions in the library.
*/
typedef enum {
/** The given UTF-8 input is NULL terminated. */
UTF8PROC_NULLTERM = (1<<0),
/** Unicode Versioning Stability has to be respected. */
UTF8PROC_STABLE = (1<<1),
/** Compatibility decomposition (i.e. formatting information is lost). */
UTF8PROC_COMPAT = (1<<2),
/** Return a result with decomposed characters. */
UTF8PROC_COMPOSE = (1<<3),
/** Return a result with decomposed characters. */
UTF8PROC_DECOMPOSE = (1<<4),
/** Strip "default ignorable characters" such as SOFT-HYPHEN or ZERO-WIDTH-SPACE. */
UTF8PROC_IGNORE = (1<<5),
/** Return an error, if the input contains unassigned codepoints. */
UTF8PROC_REJECTNA = (1<<6),
/**
* Indicating that NLF-sequences (LF, CRLF, CR, NEL) are representing a
* line break, and should be converted to the codepoint for line
* separation (LS).
*/
UTF8PROC_NLF2LS = (1<<7),
/**
* Indicating that NLF-sequences are representing a paragraph break, and
* should be converted to the codepoint for paragraph separation
* (PS).
*/
UTF8PROC_NLF2PS = (1<<8),
/** Indicating that the meaning of NLF-sequences is unknown. */
UTF8PROC_NLF2LF = (UTF8PROC_NLF2LS | UTF8PROC_NLF2PS),
/** Strips and/or convers control characters.
*
* NLF-sequences are transformed into space, except if one of the
* NLF2LS/PS/LF options is given. HorizontalTab (HT) and FormFeed (FF)
* are treated as a NLF-sequence in this case. All other control
* characters are simply removed.
*/
UTF8PROC_STRIPCC = (1<<9),
/**
* Performs unicode case folding, to be able to do a case-insensitive
* string comparison.
*/
UTF8PROC_CASEFOLD = (1<<10),
/**
* Inserts 0xFF bytes at the beginning of each sequence which is
* representing a single grapheme cluster (see UAX#29).
*/
UTF8PROC_CHARBOUND = (1<<11),
/** Lumps certain characters together.
*
* E.g. HYPHEN U+2010 and MINUS U+2212 to ASCII "-". See lump.md for details.
*
* If NLF2LF is set, this includes a transformation of paragraph and
* line separators to ASCII line-feed (LF).
*/
UTF8PROC_LUMP = (1<<12),
/** Strips all character markings.
*
* This includes non-spacing, spacing and enclosing (i.e. accents).
* @note This option works only with @ref UTF8PROC_COMPOSE or
* @ref UTF8PROC_DECOMPOSE
*/
UTF8PROC_STRIPMARK = (1<<13),
/**
* Strip unassigned codepoints.
*/
UTF8PROC_STRIPNA = (1<<14),
} utf8proc_option_t;
/** @name Error codes
* Error codes being returned by almost all functions.
*/
/** @{ */
/** Memory could not be allocated. */
#define UTF8PROC_ERROR_NOMEM -1
/** The given string is too long to be processed. */
#define UTF8PROC_ERROR_OVERFLOW -2
/** The given string is not a legal UTF-8 string. */
#define UTF8PROC_ERROR_INVALIDUTF8 -3
/** The @ref UTF8PROC_REJECTNA flag was set and an unassigned codepoint was found. */
#define UTF8PROC_ERROR_NOTASSIGNED -4
/** Invalid options have been used. */
#define UTF8PROC_ERROR_INVALIDOPTS -5
/** @} */
/* @name Types */
/** Holds the value of a property. */
typedef utf8proc_int16_t utf8proc_propval_t;
/** Struct containing information about a codepoint. */
typedef struct utf8proc_property_struct {
/**
* Unicode category.
* @see utf8proc_category_t.
*/
utf8proc_propval_t category;
utf8proc_propval_t combining_class;
/**
* Bidirectional class.
* @see utf8proc_bidi_class_t.
*/
utf8proc_propval_t bidi_class;
/**
* @anchor Decomposition type.
* @see utf8proc_decomp_type_t.
*/
utf8proc_propval_t decomp_type;
utf8proc_uint16_t decomp_seqindex;
utf8proc_uint16_t casefold_seqindex;
utf8proc_uint16_t uppercase_seqindex;
utf8proc_uint16_t lowercase_seqindex;
utf8proc_uint16_t titlecase_seqindex;
utf8proc_uint16_t comb_index;
unsigned bidi_mirrored:1;
unsigned comp_exclusion:1;
/**
* Can this codepoint be ignored?
*
* Used by @ref utf8proc_decompose_char when @ref UTF8PROC_IGNORE is
* passed as an option.
*/
unsigned ignorable:1;
unsigned control_boundary:1;
/** The width of the codepoint. */
unsigned charwidth:2;
unsigned pad:2;
/**
* Boundclass.
* @see utf8proc_boundclass_t.
*/
unsigned boundclass:8;
} utf8proc_property_t;
/** Unicode categories. */
typedef enum {
UTF8PROC_CATEGORY_CN = 0, /**< Other, not assigned */
UTF8PROC_CATEGORY_LU = 1, /**< Letter, uppercase */
UTF8PROC_CATEGORY_LL = 2, /**< Letter, lowercase */
UTF8PROC_CATEGORY_LT = 3, /**< Letter, titlecase */
UTF8PROC_CATEGORY_LM = 4, /**< Letter, modifier */
UTF8PROC_CATEGORY_LO = 5, /**< Letter, other */
UTF8PROC_CATEGORY_MN = 6, /**< Mark, nonspacing */
UTF8PROC_CATEGORY_MC = 7, /**< Mark, spacing combining */
UTF8PROC_CATEGORY_ME = 8, /**< Mark, enclosing */
UTF8PROC_CATEGORY_ND = 9, /**< Number, decimal digit */
UTF8PROC_CATEGORY_NL = 10, /**< Number, letter */
UTF8PROC_CATEGORY_NO = 11, /**< Number, other */
UTF8PROC_CATEGORY_PC = 12, /**< Punctuation, connector */
UTF8PROC_CATEGORY_PD = 13, /**< Punctuation, dash */
UTF8PROC_CATEGORY_PS = 14, /**< Punctuation, open */
UTF8PROC_CATEGORY_PE = 15, /**< Punctuation, close */
UTF8PROC_CATEGORY_PI = 16, /**< Punctuation, initial quote */
UTF8PROC_CATEGORY_PF = 17, /**< Punctuation, final quote */
UTF8PROC_CATEGORY_PO = 18, /**< Punctuation, other */
UTF8PROC_CATEGORY_SM = 19, /**< Symbol, math */
UTF8PROC_CATEGORY_SC = 20, /**< Symbol, currency */
UTF8PROC_CATEGORY_SK = 21, /**< Symbol, modifier */
UTF8PROC_CATEGORY_SO = 22, /**< Symbol, other */
UTF8PROC_CATEGORY_ZS = 23, /**< Separator, space */
UTF8PROC_CATEGORY_ZL = 24, /**< Separator, line */
UTF8PROC_CATEGORY_ZP = 25, /**< Separator, paragraph */
UTF8PROC_CATEGORY_CC = 26, /**< Other, control */
UTF8PROC_CATEGORY_CF = 27, /**< Other, format */
UTF8PROC_CATEGORY_CS = 28, /**< Other, surrogate */
UTF8PROC_CATEGORY_CO = 29, /**< Other, private use */
} utf8proc_category_t;
/** Bidirectional character classes. */
typedef enum {
UTF8PROC_BIDI_CLASS_L = 1, /**< Left-to-Right */
UTF8PROC_BIDI_CLASS_LRE = 2, /**< Left-to-Right Embedding */
UTF8PROC_BIDI_CLASS_LRO = 3, /**< Left-to-Right Override */
UTF8PROC_BIDI_CLASS_R = 4, /**< Right-to-Left */
UTF8PROC_BIDI_CLASS_AL = 5, /**< Right-to-Left Arabic */
UTF8PROC_BIDI_CLASS_RLE = 6, /**< Right-to-Left Embedding */
UTF8PROC_BIDI_CLASS_RLO = 7, /**< Right-to-Left Override */
UTF8PROC_BIDI_CLASS_PDF = 8, /**< Pop Directional Format */
UTF8PROC_BIDI_CLASS_EN = 9, /**< European Number */
UTF8PROC_BIDI_CLASS_ES = 10, /**< European Separator */
UTF8PROC_BIDI_CLASS_ET = 11, /**< European Number Terminator */
UTF8PROC_BIDI_CLASS_AN = 12, /**< Arabic Number */
UTF8PROC_BIDI_CLASS_CS = 13, /**< Common Number Separator */
UTF8PROC_BIDI_CLASS_NSM = 14, /**< Nonspacing Mark */
UTF8PROC_BIDI_CLASS_BN = 15, /**< Boundary Neutral */
UTF8PROC_BIDI_CLASS_B = 16, /**< Paragraph Separator */
UTF8PROC_BIDI_CLASS_S = 17, /**< Segment Separator */
UTF8PROC_BIDI_CLASS_WS = 18, /**< Whitespace */
UTF8PROC_BIDI_CLASS_ON = 19, /**< Other Neutrals */
UTF8PROC_BIDI_CLASS_LRI = 20, /**< Left-to-Right Isolate */
UTF8PROC_BIDI_CLASS_RLI = 21, /**< Right-to-Left Isolate */
UTF8PROC_BIDI_CLASS_FSI = 22, /**< First Strong Isolate */
UTF8PROC_BIDI_CLASS_PDI = 23, /**< Pop Directional Isolate */
} utf8proc_bidi_class_t;
/** Decomposition type. */
typedef enum {
UTF8PROC_DECOMP_TYPE_FONT = 1, /**< Font */
UTF8PROC_DECOMP_TYPE_NOBREAK = 2, /**< Nobreak */
UTF8PROC_DECOMP_TYPE_INITIAL = 3, /**< Initial */
UTF8PROC_DECOMP_TYPE_MEDIAL = 4, /**< Medial */
UTF8PROC_DECOMP_TYPE_FINAL = 5, /**< Final */
UTF8PROC_DECOMP_TYPE_ISOLATED = 6, /**< Isolated */
UTF8PROC_DECOMP_TYPE_CIRCLE = 7, /**< Circle */
UTF8PROC_DECOMP_TYPE_SUPER = 8, /**< Super */
UTF8PROC_DECOMP_TYPE_SUB = 9, /**< Sub */
UTF8PROC_DECOMP_TYPE_VERTICAL = 10, /**< Vertical */
UTF8PROC_DECOMP_TYPE_WIDE = 11, /**< Wide */
UTF8PROC_DECOMP_TYPE_NARROW = 12, /**< Narrow */
UTF8PROC_DECOMP_TYPE_SMALL = 13, /**< Small */
UTF8PROC_DECOMP_TYPE_SQUARE = 14, /**< Square */
UTF8PROC_DECOMP_TYPE_FRACTION = 15, /**< Fraction */
UTF8PROC_DECOMP_TYPE_COMPAT = 16, /**< Compat */
} utf8proc_decomp_type_t;
/** Boundclass property. (TR29) */
typedef enum {
UTF8PROC_BOUNDCLASS_START = 0, /**< Start */
UTF8PROC_BOUNDCLASS_OTHER = 1, /**< Other */
UTF8PROC_BOUNDCLASS_CR = 2, /**< Cr */
UTF8PROC_BOUNDCLASS_LF = 3, /**< Lf */
UTF8PROC_BOUNDCLASS_CONTROL = 4, /**< Control */
UTF8PROC_BOUNDCLASS_EXTEND = 5, /**< Extend */
UTF8PROC_BOUNDCLASS_L = 6, /**< L */
UTF8PROC_BOUNDCLASS_V = 7, /**< V */
UTF8PROC_BOUNDCLASS_T = 8, /**< T */
UTF8PROC_BOUNDCLASS_LV = 9, /**< Lv */
UTF8PROC_BOUNDCLASS_LVT = 10, /**< Lvt */
UTF8PROC_BOUNDCLASS_REGIONAL_INDICATOR = 11, /**< Regional indicator */
UTF8PROC_BOUNDCLASS_SPACINGMARK = 12, /**< Spacingmark */
UTF8PROC_BOUNDCLASS_PREPEND = 13, /**< Prepend */
UTF8PROC_BOUNDCLASS_ZWJ = 14, /**< Zero Width Joiner */
/* the following are no longer used in Unicode 11, but we keep
the constants here for backward compatibility */
UTF8PROC_BOUNDCLASS_E_BASE = 15, /**< Emoji Base */
UTF8PROC_BOUNDCLASS_E_MODIFIER = 16, /**< Emoji Modifier */
UTF8PROC_BOUNDCLASS_GLUE_AFTER_ZWJ = 17, /**< Glue_After_ZWJ */
UTF8PROC_BOUNDCLASS_E_BASE_GAZ = 18, /**< E_BASE + GLUE_AFTER_ZJW */
/* the Extended_Pictographic property is used in the Unicode 11
grapheme-boundary rules, so we store it in the boundclass field */
UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC = 19,
UTF8PROC_BOUNDCLASS_E_ZWG = 20, /* UTF8PROC_BOUNDCLASS_EXTENDED_PICTOGRAPHIC + ZWJ */
} utf8proc_boundclass_t;
/**
* Function pointer type passed to @ref utf8proc_map_custom and
* @ref utf8proc_decompose_custom, which is used to specify a user-defined
* mapping of codepoints to be applied in conjunction with other mappings.
*/
typedef utf8proc_int32_t (*utf8proc_custom_func)(utf8proc_int32_t codepoint, void *data);
/**
* Array containing the byte lengths of a UTF-8 encoded codepoint based
* on the first byte.
*/
UTF8PROC_DLLEXPORT extern const utf8proc_int8_t utf8proc_utf8class[256];
/**
* Returns the utf8proc API version as a string MAJOR.MINOR.PATCH
* (http://semver.org format), possibly with a "-dev" suffix for
* development versions.
*/
UTF8PROC_DLLEXPORT const char *utf8proc_version(void);
/**
* Returns the utf8proc supported Unicode version as a string MAJOR.MINOR.PATCH.
*/
UTF8PROC_DLLEXPORT const char *utf8proc_unicode_version(void);
/**
* Returns an informative error string for the given utf8proc error code
* (e.g. the error codes returned by @ref utf8proc_map).
*/
UTF8PROC_DLLEXPORT const char *utf8proc_errmsg(utf8proc_ssize_t errcode);
/**
* Reads a single codepoint from the UTF-8 sequence being pointed to by `str`.
* The maximum number of bytes read is `strlen`, unless `strlen` is
* negative (in which case up to 4 bytes are read).
*
* If a valid codepoint could be read, it is stored in the variable
* pointed to by `codepoint_ref`, otherwise that variable will be set to -1.
* In case of success, the number of bytes read is returned; otherwise, a
* negative error code is returned.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_iterate(const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_int32_t *codepoint_ref);
/**
* Check if a codepoint is valid (regardless of whether it has been
* assigned a value by the current Unicode standard).
*
* @return 1 if the given `codepoint` is valid and otherwise return 0.
*/
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_codepoint_valid(utf8proc_int32_t codepoint);
/**
* Encodes the codepoint as an UTF-8 string in the byte array pointed
* to by `dst`. This array must be at least 4 bytes long.
*
* In case of success the number of bytes written is returned, and
* otherwise 0 is returned.
*
* This function does not check whether `codepoint` is valid Unicode.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_encode_char(utf8proc_int32_t codepoint, utf8proc_uint8_t *dst);
/**
* Look up the properties for a given codepoint.
*
* @param codepoint The Unicode codepoint.
*
* @returns
* A pointer to a (constant) struct containing information about
* the codepoint.
* @par
* If the codepoint is unassigned or invalid, a pointer to a special struct is
* returned in which `category` is 0 (@ref UTF8PROC_CATEGORY_CN).
*/
UTF8PROC_DLLEXPORT const utf8proc_property_t *utf8proc_get_property(utf8proc_int32_t codepoint);
/** Decompose a codepoint into an array of codepoints.
*
* @param codepoint the codepoint.
* @param dst the destination buffer.
* @param bufsize the size of the destination buffer.
* @param options one or more of the following flags:
* - @ref UTF8PROC_REJECTNA - return an error `codepoint` is unassigned
* - @ref UTF8PROC_IGNORE - strip "default ignorable" codepoints
* - @ref UTF8PROC_CASEFOLD - apply Unicode casefolding
* - @ref UTF8PROC_COMPAT - replace certain codepoints with their
* compatibility decomposition
* - @ref UTF8PROC_CHARBOUND - insert 0xFF bytes before each grapheme cluster
* - @ref UTF8PROC_LUMP - lump certain different codepoints together
* - @ref UTF8PROC_STRIPMARK - remove all character marks
* - @ref UTF8PROC_STRIPNA - remove unassigned codepoints
* @param last_boundclass
* Pointer to an integer variable containing
* the previous codepoint's boundary class if the @ref UTF8PROC_CHARBOUND
* option is used. Otherwise, this parameter is ignored.
*
* @return
* In case of success, the number of codepoints written is returned; in case
* of an error, a negative error code is returned (@ref utf8proc_errmsg).
* @par
* If the number of written codepoints would be bigger than `bufsize`, the
* required buffer size is returned, while the buffer will be overwritten with
* undefined data.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_char(
utf8proc_int32_t codepoint, utf8proc_int32_t *dst, utf8proc_ssize_t bufsize,
utf8proc_option_t options, int *last_boundclass
);
/**
* The same as @ref utf8proc_decompose_char, but acts on a whole UTF-8
* string and orders the decomposed sequences correctly.
*
* If the @ref UTF8PROC_NULLTERM flag in `options` is set, processing
* will be stopped, when a NULL byte is encountered, otherwise `strlen`
* bytes are processed. The result (in the form of 32-bit unicode
* codepoints) is written into the buffer being pointed to by
* `buffer` (which must contain at least `bufsize` entries). In case of
* success, the number of codepoints written is returned; in case of an
* error, a negative error code is returned (@ref utf8proc_errmsg).
* See @ref utf8proc_decompose_custom to supply additional transformations.
*
* If the number of written codepoints would be bigger than `bufsize`, the
* required buffer size is returned, while the buffer will be overwritten with
* undefined data.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options
);
/**
* The same as @ref utf8proc_decompose, but also takes a `custom_func` mapping function
* that is called on each codepoint in `str` before any other transformations
* (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`. See also @ref utf8proc_map_custom.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_decompose_custom(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen,
utf8proc_int32_t *buffer, utf8proc_ssize_t bufsize, utf8proc_option_t options,
utf8proc_custom_func custom_func, void *custom_data
);
/**
* Normalizes the sequence of `length` codepoints pointed to by `buffer`
* in-place (i.e., the result is also stored in `buffer`).
*
* @param buffer the (native-endian UTF-32) unicode codepoints to re-encode.
* @param length the length (in codepoints) of the buffer.
* @param options a bitwise or (`|`) of one or more of the following flags:
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
* the unicode versioning stability
*
* @return
* In case of success, the length (in codepoints) of the normalized UTF-32 string is
* returned; otherwise, a negative error code is returned (@ref utf8proc_errmsg).
*
* @warning The entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_normalize_utf32(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options);
/**
* Reencodes the sequence of `length` codepoints pointed to by `buffer`
* UTF-8 data in-place (i.e., the result is also stored in `buffer`).
* Can optionally normalize the UTF-32 sequence prior to UTF-8 conversion.
*
* @param buffer the (native-endian UTF-32) unicode codepoints to re-encode.
* @param length the length (in codepoints) of the buffer.
* @param options a bitwise or (`|`) of one or more of the following flags:
* - @ref UTF8PROC_NLF2LS - convert LF, CRLF, CR and NEL into LS
* - @ref UTF8PROC_NLF2PS - convert LF, CRLF, CR and NEL into PS
* - @ref UTF8PROC_NLF2LF - convert LF, CRLF, CR and NEL into LF
* - @ref UTF8PROC_STRIPCC - strip or convert all non-affected control characters
* - @ref UTF8PROC_COMPOSE - try to combine decomposed codepoints into composite
* codepoints
* - @ref UTF8PROC_STABLE - prohibit combining characters that would violate
* the unicode versioning stability
* - @ref UTF8PROC_CHARBOUND - insert 0xFF bytes before each grapheme cluster
*
* @return
* In case of success, the length (in bytes) of the resulting nul-terminated
* UTF-8 string is returned; otherwise, a negative error code is returned
* (@ref utf8proc_errmsg).
*
* @warning The amount of free space pointed to by `buffer` must
* exceed the amount of the input data by one byte, and the
* entries of the array pointed to by `str` have to be in the
* range `0x0000` to `0x10FFFF`. Otherwise, the program might crash!
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_reencode(utf8proc_int32_t *buffer, utf8proc_ssize_t length, utf8proc_option_t options);
/**
* Given a pair of consecutive codepoints, return whether a grapheme break is
* permitted between them (as defined by the extended grapheme clusters in UAX#29).
*
* @param codepoint1 The first codepoint.
* @param codepoint2 The second codepoint, occurring consecutively after `codepoint1`.
* @param state Beginning with Version 29 (Unicode 9.0.0), this algorithm requires
* state to break graphemes. This state can be passed in as a pointer
* in the `state` argument and should initially be set to 0. If the
* state is not passed in (i.e. a null pointer is passed), UAX#29 rules
* GB10/12/13 which require this state will not be applied, essentially
* matching the rules in Unicode 8.0.0.
*
* @warning If the state parameter is used, `utf8proc_grapheme_break_stateful` must
* be called IN ORDER on ALL potential breaks in a string. However, it
* is safe to reset the state to zero after a grapheme break.
*/
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break_stateful(
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2, utf8proc_int32_t *state);
/**
* Same as @ref utf8proc_grapheme_break_stateful, except without support for the
* Unicode 9 additions to the algorithm. Supported for legacy reasons.
*/
UTF8PROC_DLLEXPORT utf8proc_bool utf8proc_grapheme_break(
utf8proc_int32_t codepoint1, utf8proc_int32_t codepoint2);
/**
* Given a codepoint `c`, return the codepoint of the corresponding
* lower-case character, if any; otherwise (if there is no lower-case
* variant, or if `c` is not a valid codepoint) return `c`.
*/
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_tolower(utf8proc_int32_t c);
/**
* Given a codepoint `c`, return the codepoint of the corresponding
* upper-case character, if any; otherwise (if there is no upper-case
* variant, or if `c` is not a valid codepoint) return `c`.
*/
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_toupper(utf8proc_int32_t c);
/**
* Given a codepoint `c`, return the codepoint of the corresponding
* title-case character, if any; otherwise (if there is no title-case
* variant, or if `c` is not a valid codepoint) return `c`.
*/
UTF8PROC_DLLEXPORT utf8proc_int32_t utf8proc_totitle(utf8proc_int32_t c);
/**
* Given a codepoint `c`, return `1` if the codepoint corresponds to a lower-case character
* and `0` otherwise.
*/
UTF8PROC_DLLEXPORT int utf8proc_islower(utf8proc_int32_t c);
/**
* Given a codepoint `c`, return `1` if the codepoint corresponds to an upper-case character
* and `0` otherwise.
*/
UTF8PROC_DLLEXPORT int utf8proc_isupper(utf8proc_int32_t c);
/**
* Given a codepoint, return a character width analogous to `wcwidth(codepoint)`,
* except that a width of 0 is returned for non-printable codepoints
* instead of -1 as in `wcwidth`.
*
* @note
* If you want to check for particular types of non-printable characters,
* (analogous to `isprint` or `iscntrl`), use @ref utf8proc_category. */
UTF8PROC_DLLEXPORT int utf8proc_charwidth(utf8proc_int32_t codepoint);
/**
* Return the Unicode category for the codepoint (one of the
* @ref utf8proc_category_t constants.)
*/
UTF8PROC_DLLEXPORT utf8proc_category_t utf8proc_category(utf8proc_int32_t codepoint);
/**
* Return the two-letter (nul-terminated) Unicode category string for
* the codepoint (e.g. `"Lu"` or `"Co"`).
*/
UTF8PROC_DLLEXPORT const char *utf8proc_category_string(utf8proc_int32_t codepoint);
/**
* Maps the given UTF-8 string pointed to by `str` to a new UTF-8
* string, allocated dynamically by `malloc` and returned via `dstptr`.
*
* If the @ref UTF8PROC_NULLTERM flag in the `options` field is set,
* the length is determined by a NULL terminator, otherwise the
* parameter `strlen` is evaluated to determine the string length, but
* in any case the result will be NULL terminated (though it might
* contain NULL characters with the string if `str` contained NULL
* characters). Other flags in the `options` field are passed to the
* functions defined above, and regarded as described. See also
* @ref utf8proc_map_custom to supply a custom codepoint transformation.
*
* In case of success the length of the new string is returned,
* otherwise a negative error code is returned.
*
* @note The memory of the new UTF-8 string will have been allocated
* with `malloc`, and should therefore be deallocated with `free`.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options
);
/**
* Like @ref utf8proc_map, but also takes a `custom_func` mapping function
* that is called on each codepoint in `str` before any other transformations
* (along with a `custom_data` pointer that is passed through to `custom_func`).
* The `custom_func` argument is ignored if it is `NULL`.
*/
UTF8PROC_DLLEXPORT utf8proc_ssize_t utf8proc_map_custom(
const utf8proc_uint8_t *str, utf8proc_ssize_t strlen, utf8proc_uint8_t **dstptr, utf8proc_option_t options,
utf8proc_custom_func custom_func, void *custom_data
);
/** @name Unicode normalization
*
* Returns a pointer to newly allocated memory of a NFD, NFC, NFKD, NFKC or
* NFKC_Casefold normalized version of the null-terminated string `str`. These
* are shortcuts to calling @ref utf8proc_map with @ref UTF8PROC_NULLTERM
* combined with @ref UTF8PROC_STABLE and flags indicating the normalization.
*/
/** @{ */
/** NFD normalization (@ref UTF8PROC_DECOMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFD(const utf8proc_uint8_t *str);
/** NFC normalization (@ref UTF8PROC_COMPOSE). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFC(const utf8proc_uint8_t *str);
/** NFKD normalization (@ref UTF8PROC_DECOMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKD(const utf8proc_uint8_t *str);
/** NFKC normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT). */
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC(const utf8proc_uint8_t *str);
/**
* NFKC_Casefold normalization (@ref UTF8PROC_COMPOSE and @ref UTF8PROC_COMPAT
* and @ref UTF8PROC_CASEFOLD and @ref UTF8PROC_IGNORE).
**/
UTF8PROC_DLLEXPORT utf8proc_uint8_t *utf8proc_NFKC_Casefold(const utf8proc_uint8_t *str);
/** @} */
#ifdef __cplusplus
}
#endif
#endif
This source diff could not be displayed because it is too large. You can view the blob instead.
#include <Sample.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <Bert.h>
#include <tokenization.h>
#include <fstream>
using namespace std;
using namespace migraphx;
using namespace migraphxSamples;
void Sample_Bert()
{
// 加载Bert模型
Bert bert;
InitializationParameterOfNLP initParamOfNLPBert;
initParamOfNLPBert.parentPath = "";
initParamOfNLPBert.configFilePath = CONFIG_FILE;
initParamOfNLPBert.logName = "";
ErrorCode errorCode = bert.Initialize(initParamOfNLPBert);
if (errorCode != SUCCESS)
{
LOG_ERROR(stdout, "fail to initialize Bert!\n");
exit(-1);
}
LOG_INFO(stdout, "succeed to initialize Bert\n");
int max_seq_length = 256; // 滑动窗口的长度
int max_query_length = 64; // 问题的最大长度
int batch_size = 1; // batch_size值
int n_best_size = 20; // 索引数量
int max_answer_length = 30; // 答案的最大长度
// 上下文文本数据
const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." };
char question[100];
std::vector<std::vector<long unsigned int>> input_ids;
std::vector<std::vector<long unsigned int>> input_masks;
std::vector<std::vector<long unsigned int>> segment_ids;
std::vector<float> start_position;
std::vector<float> end_position;
std::string answer = {};
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/Models/NLP/Bert/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
while (true)
{
// 数据前处理
std::cout << "question: ";
cin.getline(question, 100);
bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids);
// 推理
double time1 = getTickCount();
bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
double time2 = getTickCount();
double elapsedTime = (time2 - time1) * 1000 / getTickFrequency();
// 数据后处理
bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer);
// 打印输出预测结果
std::cout << "answer: " << answer << std::endl;
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
// 清除数据
input_ids.clear();
input_masks.clear();
segment_ids.clear();
start_position.clear();
end_position.clear();
answer = {};
}
}
\ No newline at end of file
// 示例程序
#ifndef __SAMPLE_H__
#define __SAMPLE_H__
// Bert sample
void Sample_Bert();
#endif
\ No newline at end of file
// 常用数据类型和宏定义
#ifndef __COMMON_DEFINITION_H__
#define __COMMON_DEFINITION_H__
#include <string>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
namespace migraphxSamples
{
// 路径分隔符(Linux:‘/’,Windows:’\\’)
#ifdef _WIN32
#define PATH_SEPARATOR '\\'
#else
#define PATH_SEPARATOR '/'
#endif
#define CONFIG_FILE "../Resource/Configuration.xml"
typedef struct __Time
{
string year;
string month;
string day;
string hour;
string minute;
string second;
string millisecond; // ms
string microsecond; // us
string weekDay;
}_Time;
typedef enum _ErrorCode
{
SUCCESS=0, // 0
MODEL_NOT_EXIST, // 模型不存在
CONFIG_FILE_NOT_EXIST, // 配置文件不存在
FAIL_TO_LOAD_MODEL, // 加载模型失败
FAIL_TO_OPEN_CONFIG_FILE, // 加载配置文件失败
IMAGE_ERROR, // 图像错误
}ErrorCode;
typedef struct _ResultOfPrediction
{
float confidence;
int label;
_ResultOfPrediction():confidence(0.0f),label(0){}
}ResultOfPrediction;
typedef struct _ResultOfDetection
{
Rect boundingBox;
float confidence;
int classID;
string className;
bool exist;
_ResultOfDetection():confidence(0.0f),classID(0),exist(true){}
}ResultOfDetection;
typedef struct _InitializationParameterOfNLP
{
std::string parentPath;
std::string configFilePath;
cv::Size inputSize;
std::string logName;
}InitializationParameterOfNLP;
}
#endif
#include <CommonUtility.h>
#include <assert.h>
#include <ctype.h>
#include <time.h>
#include <stdlib.h>
#include <algorithm>
#include <sstream>
#include <vector>
#ifdef _WIN32
#include <io.h>
#include <direct.h>
#include <Windows.h>
#else
#include <unistd.h>
#include <dirent.h>
#include <sys/stat.h>
#include <sys/time.h>
#endif
#include <SimpleLog.h>
namespace migraphxSamples
{
_Time GetCurrentTime3()
{
_Time currentTime;
#if (defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime;
GetLocalTime(&systemTime);
char temp[8] = { 0 };
sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp);
sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp);
sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp);
sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp);
sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp);
sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp);
#else
struct timeval tv;
struct tm *p;
gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec);
char temp[8]={0};
sprintf(temp,"%04d",1900+p->tm_year);
currentTime.year=string(temp);
sprintf(temp,"%02d",1+p->tm_mon);
currentTime.month=string(temp);
sprintf(temp,"%02d",p->tm_mday);
currentTime.day=string(temp);
sprintf(temp,"%02d",p->tm_hour);
currentTime.hour=string(temp);
sprintf(temp,"%02d",p->tm_min);
currentTime.minute=string(temp);
sprintf(temp,"%02d",p->tm_sec);
currentTime.second=string(temp);
sprintf(temp,"%03d",tv.tv_usec/1000);
currentTime.millisecond = string(temp);
sprintf(temp, "%03d", tv.tv_usec % 1000);
currentTime.microsecond = string(temp);
sprintf(temp, "%d", p->tm_wday);
currentTime.weekDay = string(temp);
#endif
return currentTime;
}
std::vector<std::string> SplitString(std::string str, std::string separator)
{
std::string::size_type pos;
std::vector<std::string> result;
str+=separator;//扩展字符串以方便操作
int size=str.size();
for(int i=0; i<size; i++)
{
pos=str.find(separator,i);
if(pos<size)
{
std::string s=str.substr(i,pos-i);
result.push_back(s);
i=pos+separator.size()-1;
}
}
return result;
}
bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R)
{
return L.confidence > R.confidence;
}
bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R)
{
return L.boundingBox.area() > R.boundingBox.area();
}
void NMS(vector<ResultOfDetection> &detections, float IOUThreshold)
{
// sort
std::sort(detections.begin(), detections.end(), CompareConfidence);
for (int i = 0; i<detections.size(); ++i)
{
if (detections[i].exist)
{
for (int j = i + 1; j<detections.size(); ++j)
{
if (detections[j].exist)
{
// compute IOU
float intersectionArea = (detections[i].boundingBox & detections[j].boundingBox).area();
float intersectionRate = intersectionArea / (detections[i].boundingBox.area() + detections[j].boundingBox.area() - intersectionArea);
if (intersectionRate>IOUThreshold)
{
detections[j].exist = false;
}
}
}
}
}
}
}
// 常用工具
#ifndef __COMMON_UTILITY_H__
#define __COMMON_UTILITY_H__
#include <mutex>
#include <string>
#include <vector>
#include <CommonDefinition.h>
using namespace std;
namespace migraphxSamples
{
// 分割字符串
std::vector<std::string> SplitString(std::string str,std::string separator);
// 排序规则: 按照置信度或者按照面积排序
bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R);
bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R);
void NMS(std::vector<ResultOfDetection> &detections, float IOUThreshold);
}
#endif
#include <Filesystem.h>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#include <fstream>
#ifdef _WIN32
#include <io.h>
#include <direct.h>
#include <Windows.h>
#else
#include <unistd.h>
#include <dirent.h>
#endif
#include <CommonUtility.h>
#include <opencv2/opencv.hpp>
#include <SimpleLog.h>
using namespace cv;
// 路径分隔符(Linux:‘/’,Windows:’\\’)
#ifdef _WIN32
#define PATH_SEPARATOR '\\'
#else
#define PATH_SEPARATOR '/'
#endif
namespace migraphxSamples
{
#if defined _WIN32 || defined WINCE
const char dir_separators[] = "/\\";
struct dirent
{
const char* d_name;
};
struct DIR
{
#ifdef WINRT
WIN32_FIND_DATAW data;
#else
WIN32_FIND_DATAA data;
#endif
HANDLE handle;
dirent ent;
#ifdef WINRT
DIR() { }
~DIR()
{
if (ent.d_name)
delete[] ent.d_name;
}
#endif
};
DIR* opendir(const char* path)
{
DIR* dir = new DIR;
dir->ent.d_name = 0;
#ifdef WINRT
string full_path = string(path) + "\\*";
wchar_t wfull_path[MAX_PATH];
size_t copied = mbstowcs(wfull_path, full_path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
dir->handle = ::FindFirstFileExW(wfull_path, FindExInfoStandard,
&dir->data, FindExSearchNameMatch, NULL, 0);
#else
dir->handle = ::FindFirstFileExA((string(path) + "\\*").c_str(),
FindExInfoStandard, &dir->data, FindExSearchNameMatch, NULL, 0);
#endif
if (dir->handle == INVALID_HANDLE_VALUE)
{
/*closedir will do all cleanup*/
delete dir;
return 0;
}
return dir;
}
dirent* readdir(DIR* dir)
{
#ifdef WINRT
if (dir->ent.d_name != 0)
{
if (::FindNextFileW(dir->handle, &dir->data) != TRUE)
return 0;
}
size_t asize = wcstombs(NULL, dir->data.cFileName, 0);
CV_Assert((asize != 0) && (asize != (size_t)-1));
char* aname = new char[asize + 1];
aname[asize] = 0;
wcstombs(aname, dir->data.cFileName, asize);
dir->ent.d_name = aname;
#else
if (dir->ent.d_name != 0)
{
if (::FindNextFileA(dir->handle, &dir->data) != TRUE)
return 0;
}
dir->ent.d_name = dir->data.cFileName;
#endif
return &dir->ent;
}
void closedir(DIR* dir)
{
::FindClose(dir->handle);
delete dir;
}
#else
# include <dirent.h>
# include <sys/stat.h>
const char dir_separators[] = "/";
#endif
static bool isDir(const string &path, DIR* dir)
{
#if defined _WIN32 || defined WINCE
DWORD attributes;
BOOL status = TRUE;
if (dir)
attributes = dir->data.dwFileAttributes;
else
{
WIN32_FILE_ATTRIBUTE_DATA all_attrs;
#ifdef WINRT
wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs);
#else
status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs);
#endif
attributes = all_attrs.dwFileAttributes;
}
return status && ((attributes & FILE_ATTRIBUTE_DIRECTORY) != 0);
#else
(void)dir;
struct stat stat_buf;
if (0 != stat(path.c_str(), &stat_buf))
return false;
int is_dir = S_ISDIR(stat_buf.st_mode);
return is_dir != 0;
#endif
}
bool IsDirectory(const string &path)
{
return isDir(path, NULL);
}
bool Exists(const string& path)
{
#if defined _WIN32 || defined WINCE
BOOL status = TRUE;
{
WIN32_FILE_ATTRIBUTE_DATA all_attrs;
#ifdef WINRT
wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
status = ::GetFileAttributesExW(wpath, GetFileExInfoStandard, &all_attrs);
#else
status = ::GetFileAttributesExA(path.c_str(), GetFileExInfoStandard, &all_attrs);
#endif
}
return !!status;
#else
struct stat stat_buf;
return (0 == stat(path.c_str(), &stat_buf));
#endif
}
bool IsPathSeparator(char c)
{
return c == '/' || c == '\\';
}
string JoinPath(const string& base, const string& path)
{
if (base.empty())
return path;
if (path.empty())
return base;
bool baseSep = IsPathSeparator(base[base.size() - 1]);
bool pathSep = IsPathSeparator(path[0]);
string result;
if (baseSep && pathSep)
{
result = base + path.substr(1);
}
else if (!baseSep && !pathSep)
{
result = base + PATH_SEPARATOR + path;
}
else
{
result = base + path;
}
return result;
}
static bool wildcmp(const char *string, const char *wild)
{
const char *cp = 0, *mp = 0;
while ((*string) && (*wild != '*'))
{
if ((*wild != *string) && (*wild != '?'))
{
return false;
}
wild++;
string++;
}
while (*string)
{
if (*wild == '*')
{
if (!*++wild)
{
return true;
}
mp = wild;
cp = string + 1;
}
else if ((*wild == *string) || (*wild == '?'))
{
wild++;
string++;
}
else
{
wild = mp;
string = cp++;
}
}
while (*wild == '*')
{
wild++;
}
return *wild == 0;
}
static void glob_rec(const string &directory, const string& wildchart, std::vector<string>& result,
bool recursive, bool includeDirectories, const string& pathPrefix)
{
DIR *dir;
if ((dir = opendir(directory.c_str())) != 0)
{
/* find all the files and directories within directory */
try
{
struct dirent *ent;
while ((ent = readdir(dir)) != 0)
{
const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0))
continue;
string path = JoinPath(directory, name);
string entry = JoinPath(pathPrefix, name);
if (isDir(path, dir))
{
if (recursive)
glob_rec(path, wildchart, result, recursive, includeDirectories, entry);
if (!includeDirectories)
continue;
}
if (wildchart.empty() || wildcmp(name, wildchart.c_str()))
result.push_back(entry);
}
}
catch (...)
{
closedir(dir);
throw;
}
closedir(dir);
}
else
{
LOG_INFO(stdout, "could not open directory: %s", directory.c_str());
}
}
void GetFileNameList(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath)
{
// split pattern
vector<string> patterns=SplitString(pattern,",");
result.clear();
for(int i=0;i<patterns.size();++i)
{
string eachPattern=patterns[i];
std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for(int j=0;j<eachResult.size();++j)
{
if (IsDirectory(eachResult[j]))
continue;
if(addPath)
{
result.push_back(eachResult[j]);
}
else
{
result.push_back(GetFileName(eachResult[j]));
}
}
}
std::sort(result.begin(), result.end());
}
void GetFileNameList2(const string &directory, const string &pattern, std::vector<string>& result, bool recursive, bool addPath)
{
// split pattern
vector<string> patterns = SplitString(pattern, ",");
result.clear();
for (int i = 0; i<patterns.size(); ++i)
{
string eachPattern = patterns[i];
std::vector<string> eachResult;
glob_rec(directory, eachPattern, eachResult, recursive, true, directory);
for (int j = 0; j<eachResult.size(); ++j)
{
string filePath = eachResult[j];
if (IsDirectory(filePath))
{
filePath = filePath + "/";
for (int k = 0; k < filePath.size(); ++k)
{
if (IsPathSeparator(filePath[k]))
{
filePath[k] = '/';
}
}
}
if (addPath)
{
result.push_back(filePath);
}
else
{
if (!IsDirectory(filePath))
{
result.push_back(GetFileName(filePath));
}
}
}
}
std::sort(result.begin(), result.end());
}
void RemoveAll(const string& path)
{
if (!Exists(path))
return;
if (IsDirectory(path))
{
std::vector<string> entries;
GetFileNameList2(path, string(), entries, false, true);
for (size_t i = 0; i < entries.size(); i++)
{
const string& e = entries[i];
RemoveAll(e);
}
#ifdef _MSC_VER
bool result = _rmdir(path.c_str()) == 0;
#else
bool result = rmdir(path.c_str()) == 0;
#endif
if (!result)
{
LOG_INFO(stdout, "can't remove directory: %s\n", path.c_str());
}
}
else
{
#ifdef _MSC_VER
bool result = _unlink(path.c_str()) == 0;
#else
bool result = unlink(path.c_str()) == 0;
#endif
if (!result)
{
LOG_INFO(stdout, "can't remove file: %s\n", path.c_str());
}
}
}
void Remove(const string &directory, const string &extension)
{
DIR *dir;
static int numberOfFiles = 0;
if ((dir = opendir(directory.c_str())) != 0)
{
/* find all the files and directories within directory */
try
{
struct dirent *ent;
while ((ent = readdir(dir)) != 0)
{
const char* name = ent->d_name;
if ((name[0] == 0) || (name[0] == '.' && name[1] == 0) || (name[0] == '.' && name[1] == '.' && name[2] == 0))
continue;
string path = JoinPath(directory, name);
if (isDir(path, dir))
{
Remove(path, extension);
}
// �ж���չ��
if (extension.empty() || wildcmp(name, extension.c_str()))
{
RemoveAll(path);
++numberOfFiles;
LOG_INFO(stdout, "%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
}
}
}
catch (...)
{
closedir(dir);
throw;
}
closedir(dir);
}
else
{
LOG_INFO(stdout, "could not open directory: %s", directory.c_str());
}
// ����RemoveAllɾ��Ŀ¼
RemoveAll(directory);
}
string GetFileName(const string &path)
{
string fileName;
int indexOfPathSeparator = -1;
for (int i = path.size() - 1; i >= 0; --i)
{
if (IsPathSeparator(path[i]))
{
fileName = path.substr(i + 1, path.size() - i - 1);
indexOfPathSeparator = i;
break;
}
}
if (indexOfPathSeparator == -1)
{
fileName = path;
}
return fileName;
}
string GetFileName_NoExtension(const string &path)
{
string fileName=GetFileName(path);
string fileName_NoExtension;
for(int i=fileName.size()-1;i>0;--i)
{
if(fileName[i]=='.')
{
fileName_NoExtension=fileName.substr(0,i);
break;
}
}
return fileName_NoExtension;
}
string GetExtension(const string &path)
{
string fileName;
for (int i = path.size() - 1; i >= 0; --i)
{
if (path[i]=='.')
{
fileName = path.substr(i, path.size() - i);
break;
}
}
return fileName;
}
string GetParentPath(const string &path)
{
string fileName;
for (int i = path.size() - 1; i >= 0; --i)
{
if (IsPathSeparator(path[i]))
{
fileName = path.substr(0, i+1);
break;
}
}
return fileName;
}
static bool CreateDirectory(const string &path)
{
#if defined WIN32 || defined _WIN32 || defined WINCE
#ifdef WINRT
wchar_t wpath[MAX_PATH];
size_t copied = mbstowcs(wpath, path.c_str(), MAX_PATH);
CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
int result = CreateDirectoryA(wpath, NULL) ? 0 : -1;
#else
int result = _mkdir(path.c_str());
#endif
#elif defined __linux__ || defined __APPLE__
int result = mkdir(path.c_str(), 0777);
#else
int result = -1;
#endif
if (result == -1)
{
return IsDirectory(path);
}
return true;
}
bool CreateDirectories(const string &directoryPath)
{
string path = directoryPath;
for (;;)
{
char last_char = path.empty() ? 0 : path[path.length() - 1];
if (IsPathSeparator(last_char))
{
path = path.substr(0, path.length() - 1);
continue;
}
break;
}
if (path.empty() || path == "./" || path == ".\\" || path == ".")
return true;
if (IsDirectory(path))
return true;
size_t pos = path.rfind('/');
if (pos == string::npos)
pos = path.rfind('\\');
if (pos != string::npos)
{
string parent_directory = path.substr(0, pos);
if (!parent_directory.empty())
{
if (!CreateDirectories(parent_directory))
return false;
}
}
return CreateDirectory(path);
}
bool CopyFile(const string srcPath, const string dstPath)
{
std::ifstream srcFile(srcPath,ios::binary);
std::ofstream dstFile(dstPath,ios::binary);
if(!srcFile.is_open())
{
LOG_ERROR(stdout,"can not open %s\n",srcPath.c_str());
return false;
}
if(!dstFile.is_open())
{
LOG_ERROR(stdout, "can not open %s\n", dstPath.c_str());
return false;
}
if(srcPath==dstPath)
{
LOG_ERROR(stdout, "src can not be same with dst\n");
return false;
}
char buffer[2048];
unsigned int numberOfBytes=0;
while(srcFile)
{
srcFile.read(buffer,2048);
dstFile.write(buffer,srcFile.gcount());
numberOfBytes+=srcFile.gcount();
}
srcFile.close();
dstFile.close();
return true;
}
bool CopyDirectories(string srcPath, const string dstPath)
{
if(srcPath==dstPath)
{
LOG_ERROR(stdout, "src can not be same with dst\n");
return false;
}
// ȥ������·���ָ���
srcPath = srcPath.substr(0, srcPath.size() - 1);
vector<string> fileNameList;
GetFileNameList2(srcPath, "", fileNameList, true, true);
string parentPathOfSrc=GetParentPath(srcPath);
int length=parentPathOfSrc.size();
// create all directories
for(int i=0;i<fileNameList.size();++i)
{
// create directory
string srcFilePath=fileNameList[i];
string subStr=srcFilePath.substr(length,srcFilePath.size()-length);
string dstFilePath=dstPath+subStr;
string parentPathOfDst=GetParentPath(dstFilePath);
CreateDirectories(parentPathOfDst);
}
// copy file
for(int i=0;i<fileNameList.size();++i)
{
string srcFilePath=fileNameList[i];
if (IsDirectory(srcFilePath))
{
continue;
}
string subStr=srcFilePath.substr(length,srcFilePath.size()-length);
string dstFilePath=dstPath+subStr;
// copy file
CopyFile(srcFilePath,dstFilePath);
// process
double process = (1.0*(i + 1) / fileNameList.size()) * 100;
LOG_INFO(stdout, "%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
}
LOG_INFO(stdout, "all done!(the number of files:%d)\n", fileNameList.size());
return true;
}
}
// 文件以及目录处理
#ifndef __FILE_SYSTEM_H__
#define __FILE_SYSTEM_H__
#include <vector>
#include <string>
using namespace std;
namespace migraphxSamples
{
// 路径是否存在
bool Exists(const std::string &path);
// 路径是否为目录
bool IsDirectory(const std::string &path);
// 是否是路径分隔符(Linux:‘/’,Windows:’\\’)
bool IsPathSeparator(char c);
// 路径拼接
string JoinPath(const std::string &base, const std::string &path);
// 创建多级目录,注意:创建多级目录的时候,目标目录是不能有文件存在的
bool CreateDirectories(const std::string &directoryPath);
/** 生成符合指定模式的文件名列表(支持递归遍历)
*
* pattern: 模式,比如"*.jpg","*.png","*.jpg,*.png"
* addPath:是否包含父路径
* 注意:
1. 多个模式使用","分割,比如"*.jpg,*.png"
2. 支持通配符'*','?' ,比如第一个字符是7的所有文件名:"7*.*", 以512结尾的所有jpg文件名:"*512.jpg"
3. 使用"*.jpg",而不是".jpg"
4. 空string表示返回所有结果
5. 不能返回子目录名
*
*/
void GetFileNameList(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath);
// 与GetFileNameList的区别在于如果有子目录,在addPath为true的时候会返回子目录路径(目录名最后有"/")
void GetFileNameList2(const std::string &directory, const std::string &pattern, std::vector<std::string> &result, bool recursive, bool addPath);
// 删除文件或者目录,支持递归删除
void Remove(const std::string &directory, const std::string &extension="");
/** 获取路径的文件名和扩展名
*
* 示例:path为D:/1/1.txt,则GetFileName()为1.txt,GetFileName_NoExtension()为1,GetExtension()为.txt,GetParentPath()为D:/1/
*/
string GetFileName(const std::string &path); // 1.txt
string GetFileName_NoExtension(const std::string &path); // 1
string GetExtension(const std::string &path);// .txt
string GetParentPath(const std::string &path);// D:/1/
// 拷贝文件:CopyFile("D:/1.txt","D:/2.txt");将1.txt拷贝为2.txt
bool CopyFile(const std::string srcPath,const std::string dstPath);
/** 拷贝目录
*
* 示例:CopyDirectories("D:/0/1/2/","E:/3/");实现把D:/0/1/2/目录拷贝到E:/3/目录中(即拷贝完成后的目录结构为E:/3/2/)
* 注意:
1.第一个参数的最后不能加”/”
2.不能拷贝隐藏文件
*/
bool CopyDirectories(std::string srcPath,const std::string dstPath);
}
#endif
// 简易日志
#ifndef __SIMPLE_LOG_H__
#define __SIMPLE_LOG_H__
#include <time.h>
#include <string>
#include <map>
#include <thread>
#include <mutex>
#if (defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>
#endif
using namespace std;
/** 简易日志
*
* 轻量级日志系统,不依赖于其他第三方库,只需要包含一个头文件就可以使用。提供了4种日志级别,包括INFO,DEBUG,WARN和ERROR。
*
* 示例1:
// 初始化日志,在./Log/目录下创建两个日志文件log1.log和log2.log(注意:目录./Log/需要存在,否则日志创建失败)
LogManager::GetInstance()->Initialize("./Log/","log1");
LogManager::GetInstance()->Initialize("./Log/","log2");
// 写日志
string log = "Hello World";
LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", log.c_str()); // 写入log1.log
LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n", log.c_str()); // 写入log2.log
// 关闭日志
LogManager::GetInstance()->Close("log1");
LogManager::GetInstance()->Close("log2");
* 示例2:
// 将日志输出到控制台
string log = "Hello World";
LOG_INFO(stdout, "%s\n", log.c_str());
* 注意:
1. 需要C++11
2. 多线程的时候需要加锁(打开#define LOG_MUTEX),否则会导致日志显示混乱
*/
// #define LOG_MUTEX // 加锁
class LogManager
{
private:
LogManager(){}
public:
~LogManager(){}
inline void Initialize(const string &parentPath,const string &logName)
{
// 日志名为空表示输出到控制台
if(logName.size()==0)
return;
// 查找该日志文件,如果没有则创建
std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
if (iter == logMap.end())
{
string pathOfLog = parentPath+ logName + ".log";
FILE *logFile = fopen(pathOfLog.c_str(), "a"); // w:覆盖原有文件,a:追加
if(logFile!=NULL)
{
logMap.insert(std::make_pair(logName, logFile));
}
}
}
inline FILE* GetLogFile(const string &logName)
{
std::map<string, FILE*>::const_iterator iter=logMap.find(logName);
if(iter==logMap.end())
{
return NULL;
}
return (*iter).second;
}
inline void Close(const string &logName)
{
std::map<string, FILE*>::const_iterator iter=logMap.find(logName);
if(iter==logMap.end())
{
return;
}
fclose((*iter).second);
logMap.erase(iter);
}
inline std::mutex &GetLogMutex()
{
return logMutex;
}
// Singleton
static LogManager* GetInstance()
{
static LogManager logManager;
return &logManager;
}
private:
std::map<string, FILE*> logMap;
std::mutex logMutex;
};
#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#endif
// log time
typedef struct _LogTime
{
string year;
string month;
string day;
string hour;
string minute;
string second;
string millisecond; // ms
string microsecond; // us
string weekDay;
}LogTime;
inline LogTime GetTime()
{
LogTime currentTime;
#if (defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime;
GetLocalTime(&systemTime);
char temp[8] = { 0 };
sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp);
sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp);
sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp);
sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp);
sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp);
sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp);
#else
struct timeval tv;
struct tm *p;
gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec);
char temp[8]={0};
sprintf(temp,"%04d",1900+p->tm_year);
currentTime.year=string(temp);
sprintf(temp,"%02d",1+p->tm_mon);
currentTime.month=string(temp);
sprintf(temp,"%02d",p->tm_mday);
currentTime.day=string(temp);
sprintf(temp,"%02d",p->tm_hour);
currentTime.hour=string(temp);
sprintf(temp,"%02d",p->tm_min);
currentTime.minute=string(temp);
sprintf(temp,"%02d",p->tm_sec);
currentTime.second=string(temp);
sprintf(temp,"%03d",(int)(tv.tv_usec/1000));
currentTime.millisecond = string(temp);
sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
currentTime.microsecond = string(temp);
sprintf(temp, "%d", p->tm_wday);
currentTime.weekDay = string(temp);
#endif
return currentTime;
}
#define LOG_TIME(logFile) \
do\
{\
LogTime currentTime=GetTime(); \
fprintf(((logFile == NULL) ? stdout : logFile), "%s-%s-%s %s:%s:%s.%s\t",currentTime.year.c_str(),currentTime.month.c_str(),currentTime.day.c_str(),currentTime.hour.c_str(),currentTime.minute.c_str(),currentTime.second.c_str(),currentTime.millisecond.c_str()); \
}while (0)
#define LOG_INFO(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile); \
fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
fprintf(((logFile == NULL) ? stdout : logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_DEBUG(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "DEBUG\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_ERROR(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "ERROR\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#define LOG_WARN(logFile,logInfo, ...) \
do\
{\
LOCK; \
LOG_TIME(logFile);\
fprintf(((logFile==NULL)?stdout:logFile), "WARN\t"); \
fprintf(((logFile==NULL)?stdout:logFile), "[%s:%d (%s) ]: ", __FILE__, __LINE__, __FUNCTION__); \
fprintf(((logFile==NULL)?stdout:logFile),logInfo, ## __VA_ARGS__); \
fflush(logFile); \
UNLOCK; \
} while (0)
#endif // __SIMPLE_LOG_H__
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <Sample.h>
void MIGraphXSamplesUsage(char* programName)
{
printf("Usage : %s <index> \n", programName);
printf("index:\n");
printf("\t 0) Bert sample.\n");
}
int main(int argc, char *argv[])
{
if (argc < 2 || argc > 2)
{
MIGraphXSamplesUsage(argv[0]);
return -1;
}
if (!strncmp(argv[1], "-h", 2))
{
MIGraphXSamplesUsage(argv[0]);
return 0;
}
switch (*argv[1])
{
case '0':
{
Sample_Bert();
break;
}
default :
{
MIGraphXSamplesUsage(argv[0]);
break;
}
}
return 0;
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment