Commit aec6280a authored by liucong

Apply formatting to the C++ code

parent ab78f8ec
#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <tokenization.h>

#include <algorithm>
#include <migraphx/gpu/target.hpp>
#include <migraphx/onnx.hpp>
#include <stdexcept>

namespace migraphxSamples {

Bert::Bert() {}

Bert::~Bert() {}
ErrorCode Bert::Initialize()
{
    // Locate the model file
    std::string modelPath = "../Resource/bertsquad-10.onnx";

    // Set the maximum input shapes
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["unique_ids_raw_output___9:0"] = {1};
    onnx_options.map_input_dims["input_ids:0"] = {1, 256};
    onnx_options.map_input_dims["input_mask:0"] = {1, 256};
    onnx_options.map_input_dims["segment_ids:0"] = {1, 256};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s does not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "successfully loaded model: %s\n", GetFileName(modelPath).c_str());

    // Query the input/output node information of the model
    std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName1 = "unique_ids_raw_output___9:0";
    inputShape1 = inputs.at(inputName1);
@@ -65,27 +56,27 @@ ErrorCode Bert::Initialize()
    // Compile the model
    migraphx::compile_options options;
    options.device_id = 0; // GPU device id; defaults to device 0
    options.offload_copy = true;
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "successfully compiled model: %s\n", GetFileName(modelPath).c_str());

    // warm up
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName1] = migraphx::argument(inputShape1);
    inputData[inputName2] = migraphx::argument(inputShape2);
    inputData[inputName3] = migraphx::argument(inputShape3);
    inputData[inputName4] = migraphx::argument(inputShape4);
    net.eval(inputData);

    return SUCCESS;
}
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
                          const std::vector<std::vector<long unsigned int>>& input_masks,
                          const std::vector<std::vector<long unsigned int>>& segment_ids,
                          std::vector<float>& start_position,
                          std::vector<float>& end_position)
{
    // Buffers for the preprocessed data
    int num = input_ids.size();
@@ -93,9 +84,9 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
    long unsigned int input_mask[num][256];
    long unsigned int segment_id[num][256];
    long unsigned int position_id[num][1];
    for(int i = 0; i < input_ids.size(); ++i)
    {
        for(int j = 0; j < input_ids[0].size(); ++j)
        {
            input_id[i][j] = input_ids[i][j];
            segment_id[i][j] = segment_ids[i][j];
@@ -111,25 +102,25 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
    float* start_data;
    float* end_data;
    for(int i = 0; i < input_ids.size(); ++i)
    {
        // Create the input data
        inputData[inputName1] = migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
        inputData[inputName2] = migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
        inputData[inputName3] = migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
        inputData[inputName4] = migraphx::argument{inputShape4, (long unsigned int*)input_id[i]};

        // Run inference
        results = net.eval(inputData);

        // Read the output nodes
        start_prediction = results[1];                // start position of the answer
        start_data = (float*)start_prediction.data(); // data pointer for the start logits
        end_prediction = results[0];                  // end position of the answer
        end_data = (float*)end_prediction.data();     // data pointer for the end logits

        // Save the inference results
        for(int j = 0; j < 256; ++j)
        {
            start_position.push_back(start_data[j]);
            end_position.push_back(end_data[j]);
@@ -142,11 +133,11 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer& tokenizer,
                              int batch_size,
                              int max_seq_length,
                              const char* text,
                              char* question,
                              std::vector<std::vector<long unsigned int>>& input_ids,
                              std::vector<std::vector<long unsigned int>>& input_masks,
                              std::vector<std::vector<long unsigned int>>& segment_ids)
{
    std::vector<long unsigned int> input_id(max_seq_length);
    std::vector<long unsigned int> input_mask(max_seq_length);
@@ -161,18 +152,18 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
    // When the combined length of the context text and the question exceeds the
    // configured maximum, process the context with a sliding window (e.g. with
    // max_seq_length = 256 and a 20-token question, each window holds
    // 256 - 5 - 20 = 231 context tokens).
    if(tokens_text.size() + tokens_question.size() > max_seq_length - 5)
    {
        int windows_len = max_seq_length - 5 - tokens_question.size();
        std::vector<std::string> tokens_text_window(windows_len);
        std::vector<std::vector<std::string>> tokens_text_windows;
        int start_offset = 0;
        int position = 0;
        int n;
        while(start_offset < tokens_text.size())
        {
            n = 0;
            if(start_offset + windows_len > tokens_text.size())
            {
                for(int i = start_offset; i < tokens_text.size(); ++i)
                {
                    tokens_text_window[n] = tokens_text[i];
                    ++n;
@@ -180,7 +171,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
            }
            else
            {
                for(int i = start_offset; i < start_offset + windows_len; ++i)
                {
                    tokens_text_window[n] = tokens_text[i];
                    ++n;
@@ -191,7 +182,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
            ++position;
        }

        for(int i = 0; i < position; ++i)
        {
            input_id[0] = tokenizer.convert_token_to_id("[CLS]");
            segment_id[0] = 0;
@@ -199,7 +190,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
            input_id[1] = tokenizer.convert_token_to_id("[CLS]");
            segment_id[1] = 0;

            for(int j = 0; j < tokens_question.size(); ++j)
            {
                input_id[j + 2] = tokenizer.convert_token_to_id(tokens_question[j]);
                segment_id[j + 2] = 0;
@@ -211,13 +202,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
            input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
            segment_id[tokens_question.size() + 3] = 0;

            // Copy this window's context tokens in after the question tokens
            for(int j = 0; j < tokens_text_windows[i].size(); ++j)
            {
                input_id[j + tokens_question.size() + 4] =
                    tokenizer.convert_token_to_id(tokens_text_windows[i][j]);
                segment_id[j + tokens_question.size() + 4] = 1;
            }
            input_id[tokens_question.size() + tokens_text_windows[i].size() + 4] =
                tokenizer.convert_token_to_id("[SEP]");
            segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1;

            // A mask value of 1 marks a real token; 0 marks padding.
@@ -240,7 +233,7 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
        input_id[1] = tokenizer.convert_token_to_id("[CLS]");
        segment_id[1] = 0;

        for(int i = 0; i < tokens_question.size(); ++i)
        {
            input_id[i + 2] = tokenizer.convert_token_to_id(tokens_question[i]);
            segment_id[i + 2] = 0;
@@ -252,13 +245,15 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
        input_id[tokens_question.size() + 3] = tokenizer.convert_token_to_id("[SEP]");
        segment_id[tokens_question.size() + 3] = 0;

        for(int i = 0; i < tokens_text.size(); ++i)
        {
            input_id[i + tokens_question.size() + 4] =
                tokenizer.convert_token_to_id(tokens_text[i]);
            segment_id[i + tokens_question.size() + 4] = 1;
        }
        input_id[tokens_question.size() + tokens_text.size() + 4] =
            tokenizer.convert_token_to_id("[SEP]");
        segment_id[tokens_question.size() + tokens_text.size() + 4] = 1;

        // A mask value of 1 marks a real token; 0 marks padding.
@@ -275,27 +270,25 @@ ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
    return SUCCESS;
}

static bool Compare(Sort_st a, Sort_st b) { return a.value > b.value; }

static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{
    return a.start_predictionvalue + a.end_predictionvalue >
           b.start_predictionvalue + b.end_predictionvalue;
}
ErrorCode Bert::Postprocessing(int n_best_size,
                               int max_answer_length,
                               const std::vector<float>& start_position,
                               const std::vector<float>& end_position,
                               std::string& answer)
{
    // Take the indices of the n_best_size highest-probability values
    std::vector<Sort_st> start_array(start_position.size());
    std::vector<Sort_st> end_array(end_position.size());
    for(int i = 0; i < start_position.size(); ++i)
    {
        start_array[i].index = i;
        start_array[i].value = start_position.at(i);
@@ -309,10 +302,10 @@ ErrorCode Bert::Postprocessing(int n_best_size,
    std::vector<ResultOfPredictions> resultsOfPredictions(400);
    int num = start_position.size() / 256;
    bool flag;
    int n = 0;
    for(int i = 0; i < n_best_size; ++i)
    {
        for(int j = 0; j < n_best_size; ++j)
        {
            flag = false;
            if(start_array[i].index > start_position.size())
@@ -325,15 +318,17 @@ ErrorCode Bert::Postprocessing(int n_best_size,
                continue;
            }

            for(int t = 0; t < num; ++t)
            {
                if(start_array[i].index > t * 256 &&
                   start_array[i].index < tokens_question.size() + 4 + t * 256)
                {
                    flag = true;
                    break;
                }
                if(end_array[j].index > t * 256 &&
                   end_array[j].index < tokens_question.size() + 4 + t * 256)
                {
                    flag = true;
                    break;
@@ -368,9 +363,10 @@ ErrorCode Bert::Postprocessing(int n_best_size,
    int start_index = 0;
    int end_index = 0;
    for(int i = 0; i < 400; ++i)
    {
        if(resultsOfPredictions[i].start_predictionvalue == 0 &&
           resultsOfPredictions[i].end_predictionvalue == 0)
        {
            continue;
        }
@@ -380,24 +376,24 @@ ErrorCode Bert::Postprocessing(int n_best_size,
    }

    // Map back to indices into the context text (current index - question length - 4)
    int answer_start_index = start_index - tokens_question.size() - 4;
    int answer_end_index = end_index - tokens_question.size() - 4 + 1;

    // Using the start and end indices, collect the tokens in that range
    int j = 0;
    for(int i = answer_start_index; i < answer_end_index; ++i)
    {
        if(tokens_text[i].find('#') != std::string::npos)
        {
            j = i - 1;
            break;
        }
    }
    for(int i = answer_start_index; i < answer_end_index; ++i)
    {
        answer += tokens_text[i];
        if(tokens_text[i].find('#') != std::string::npos || i == j)
        {
            continue;
        }
@@ -405,9 +401,9 @@ ErrorCode Bert::Postprocessing(int n_best_size,
    }

    size_t index = 0;
    while((index = answer.find('#', index)) != std::string::npos)
    {
        answer.erase(index, 1);
    }
    tokens_text.clear();
    tokens_question.clear();
@@ -415,5 +411,4 @@ ErrorCode Bert::Postprocessing(int n_best_size,
    return SUCCESS;
}

} // namespace migraphxSamples
#ifndef __BERT_H__
#define __BERT_H__

#include <tokenization.h>

#include <cstdint>
#include <migraphx/program.hpp>
#include <string>

namespace migraphxSamples {

typedef enum _ErrorCode
{
    SUCCESS = 0,
    MODEL_NOT_EXIST,
    CONFIG_FILE_NOT_EXIST,
    FAIL_TO_LOAD_MODEL,
    FAIL_TO_OPEN_CONFIG_FILE,
} ErrorCode;
typedef struct _Sort_st
{
    int index;
    float value;
} Sort_st;

typedef struct _ResultOfPredictions
{
    int start_index;
    int end_index;
    float start_predictionvalue;
    float end_predictionvalue;
} ResultOfPredictions;

class Bert
{
public:
    Bert();
    ~Bert();
    ErrorCode Initialize();
    ErrorCode Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
                        const std::vector<std::vector<long unsigned int>>& input_masks,
                        const std::vector<std::vector<long unsigned int>>& segment_ids,
                        std::vector<float>& start_position,
                        std::vector<float>& end_position);
    ErrorCode Preprocessing(cuBERT::FullTokenizer& tokenizer,
                            int batch_size,
                            int max_seq_length,
                            const char* text,
                            char* question,
                            std::vector<std::vector<long unsigned int>>& input_ids,
                            std::vector<std::vector<long unsigned int>>& input_masks,
                            std::vector<std::vector<long unsigned int>>& segment_ids);
    ErrorCode Postprocessing(int n_best_size,
                             int max_answer_length,
                             const std::vector<float>& start_position,
                             const std::vector<float>& end_position,
                             std::string& answer);

private:
    std::vector<std::string> tokens_text;
    std::vector<std::string> tokens_question;
@@ -74,9 +74,8 @@ private:
    migraphx::shape inputShape2;
    migraphx::shape inputShape3;
    migraphx::shape inputShape4;
};

} // namespace migraphxSamples

#endif
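For orientation, here is a minimal sketch of how the class above is driven, mirroring the sample's main() shown near the end of this commit. The question string and the n_best_size value of 20 are illustrative assumptions; the vocab path, batch size of 1, sequence length of 256 and answer length of 30 follow the sample's defaults:

    // Hypothetical driver for the Bert QA pipeline (a sketch, not part of the commit)
    cuBERT::FullTokenizer tokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt");
    migraphxSamples::Bert bert;
    if(bert.Initialize() != migraphxSamples::SUCCESS)
        return -1;

    const char text[] = u8"ROCm is the first open-source exascale-class platform ...";
    char question[100] = "What is ROCm?"; // hypothetical question

    std::vector<std::vector<long unsigned int>> input_ids, input_masks, segment_ids;
    std::vector<float> start_position, end_position;
    std::string answer;

    // Tokenize text + question, run the model, then decode the best span into text
    bert.Preprocessing(tokenizer, 1, 256, text, question, input_ids, input_masks, segment_ids);
    bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
    bert.Postprocessing(/*n_best_size=*/20, /*max_answer_length=*/30, start_position, end_position, answer);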
@@ -6,23 +6,22 @@
#include <string>
#include <vector>

namespace migraphxSamples {

// Does the path exist?
bool Exists(const std::string& path);

// Is the path a directory?
bool IsDirectory(const std::string& path);

// Is the character a path separator? (Linux: '/', Windows: '\\')
bool IsPathSeparator(char c);

// Join two path components
std::string JoinPath(const std::string& base, const std::string& path);

// Create nested directories. Note: creation fails if a file already exists at the target path
bool CreateDirectories(const std::string& directoryPath);
/** Build the list of file names matching the given pattern(s) (recursive traversal supported)
 *
@@ -30,31 +29,40 @@ bool CreateDirectories(const std::string &directoryPath);
 * addPath: whether to include the parent path
 * Notes:
    1. Multiple patterns are separated by ",", e.g. "*.jpg,*.png"
    2. The wildcards '*' and '?' are supported, e.g. all file names whose first
       character is 7: "7*.*", all jpg file names ending in 512: "*512.jpg"
    3. Use "*.jpg", not ".jpg"
    4. An empty string returns all results
    5. Subdirectory names are not returned
 *
 */
void GetFileNameList(const std::string& directory,
                     const std::string& pattern,
                     std::vector<std::string>& result,
                     bool recursive,
                     bool addPath);
// Unlike GetFileNameList, when addPath is true this also returns subdirectory
// paths if subdirectories exist (directory names end with "/")
void GetFileNameList2(const std::string& directory,
                      const std::string& pattern,
                      std::vector<std::string>& result,
                      bool recursive,
                      bool addPath);
// Delete a file or directory; recursive deletion is supported
void Remove(const std::string& directory, const std::string& extension = "");
/** Get the file name and extension parts of a path
 *
 * Example: for path D:/1/1.txt, GetFileName() is 1.txt, GetFileName_NoExtension() is 1,
 * GetExtension() is .txt, and GetParentPath() is D:/1/
 */
std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string& path);
// Copy a file
bool CopyFile(const std::string srcPath, const std::string dstPath);

/** Copy a directory
 *
@@ -63,8 +71,8 @@ bool CopyFile(const std::string srcPath,const std::string dstPath);
    1. The first argument must not end with "/"
    2. Hidden files are not copied
 */
bool CopyDirectories(std::string srcPath, const std::string dstPath);
} // namespace migraphxSamples

#endif
@@ -4,11 +4,13 @@
#define __SIMPLE_LOG_H__

#include <time.h>

#include <map>
#include <mutex>
#include <string>
#include <thread>

#if(defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>
@@ -16,20 +18,22 @@
using namespace std;

/** Simple logging
 *
 * No third-party dependencies; including this single header is enough. Four log
 * levels are provided: INFO, DEBUG, WARN and ERROR.
 *
 * Example 1:
    // Initialize the logs: create two log files, log1.log and log2.log, under ./Log/
    // (note: the directory ./Log/ must already exist, otherwise log creation fails)
    LogManager::GetInstance()->Initialize("./Log/","log1");
    LogManager::GetInstance()->Initialize("./Log/","log2");

    // Write logs
    string log = "Hello World";
    LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", log.c_str()); // goes to log1.log
    LOG_INFO(LogManager::GetInstance()->GetLogFile("log2"), "%s\n", log.c_str()); // goes to log2.log

    // Close the logs
    LogManager::GetInstance()->Close("log1");
@@ -50,44 +54,43 @@ using namespace std;
class LogManager
{
private:
    LogManager() {}

public:
    ~LogManager() {}

    inline void Initialize(const string& parentPath, const string& logName)
    {
        // An empty log name means logging to the console
        if(logName.size() == 0)
            return;

        // Look up the log file; create it if it does not exist yet
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            string pathOfLog = parentPath + logName + ".log";
            FILE* logFile = fopen(pathOfLog.c_str(), "a"); // w: overwrite, a: append
            if(logFile != NULL)
            {
                logMap.insert(std::make_pair(logName, logFile));
            }
        }
    }

    inline FILE* GetLogFile(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return NULL;
        }
        return (*iter).second;
    }

    inline void Close(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return;
        }
@@ -95,10 +98,7 @@ public:
        fclose((*iter).second);
        logMap.erase(iter);
    }

    inline std::mutex& GetLogMutex() { return logMutex; }

    // Singleton
    static LogManager* GetInstance()
@@ -106,17 +106,18 @@ public:
        static LogManager logManager;
        return &logManager;
    }

private:
    std::map<string, FILE*> logMap;
    std::mutex logMutex;
};
#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#endif
// log time
@@ -131,53 +132,53 @@ typedef struct _LogTime
    string millisecond; // ms
    string microsecond; // us
    string weekDay;
} LogTime;
inline LogTime GetTime()
{
    LogTime currentTime;

#if(defined WIN32 || defined _WIN32)
    SYSTEMTIME systemTime;
    GetLocalTime(&systemTime);

    char temp[8] = {0};
    sprintf(temp, "%04d", systemTime.wYear);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", systemTime.wMonth);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", systemTime.wDay);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", systemTime.wHour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", systemTime.wMinute);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", systemTime.wSecond);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", systemTime.wMilliseconds);
    currentTime.millisecond = string(temp);
    sprintf(temp, "%d", systemTime.wDayOfWeek);
    currentTime.weekDay = string(temp);
#else
    struct timeval tv;
    struct tm* p;
    gettimeofday(&tv, NULL);
    p = localtime(&tv.tv_sec);

    char temp[8] = {0};
    sprintf(temp, "%04d", 1900 + p->tm_year);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", 1 + p->tm_mon);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", p->tm_mday);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", p->tm_hour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", p->tm_min);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", p->tm_sec);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
    currentTime.millisecond = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
    currentTime.microsecond = string(temp);
@@ -188,60 +189,82 @@ inline LogTime GetTime()
}
#define LOG_TIME(logFile)                                                        \
    do                                                                           \
    {                                                                            \
        LogTime currentTime = GetTime();                                         \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "%s-%s-%s %s:%s:%s.%s\t",                                        \
                currentTime.year.c_str(),                                        \
                currentTime.month.c_str(),                                       \
                currentTime.day.c_str(),                                         \
                currentTime.hour.c_str(),                                        \
                currentTime.minute.c_str(),                                      \
                currentTime.second.c_str(),                                      \
                currentTime.millisecond.c_str());                                \
    } while(0)

#define LOG_INFO(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_DEBUG(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_ERROR(logFile, logInfo, ...)                                         \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t");              \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#define LOG_WARN(logFile, logInfo, ...)                                          \
    do                                                                           \
    {                                                                            \
        LOCK;                                                                    \
        LOG_TIME(logFile);                                                       \
        fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t");               \
        fprintf(((logFile == NULL) ? stdout : logFile),                          \
                "[%s:%d (%s) ]: ",                                               \
                __FILE__,                                                        \
                __LINE__,                                                        \
                __FUNCTION__);                                                   \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile);                                                         \
        UNLOCK;                                                                  \
    } while(0)

#endif // __SIMPLE_LOG_H__
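A quick sketch of the two output modes the macros above support (log name illustrative): passing stdout or NULL writes to the console, while a FILE* obtained from LogManager writes to the named log file.

    LogManager::GetInstance()->Initialize("./Log/", "run"); // creates ./Log/run.log (./Log/ must exist)
    LOG_INFO(stdout, "loaded %d items\n", 3);               // console output
    LOG_WARN(LogManager::GetInstance()->GetLogFile("run"), "low disk space\n"); // file output
    LogManager::GetInstance()->Close("run");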
#include <algorithm>
#include <cstring>
#include <fstream>
#include <stdexcept>

#include "./tokenization.h"
#include "utf8proc.h"

namespace cuBERT {

void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
    for(int i = 0; i < tokens.size(); ++i)
    {
        ids[i] = convert_token_to_id(tokens[i]);
    }
}
// trim from start (in place)
static inline void ltrim(std::string& s)
{
    s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string& s)
{
    s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
            s.end());
}

// trim from both ends (in place)
static inline void trim(std::string& s)
{
    ltrim(s);
    rtrim(s);
}
void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
    std::ifstream file(vocab_file);
    if(!file)
    {
        throw std::invalid_argument("Unable to open vocab file");
    }

    unsigned int index = 0;
    std::string line;
    while(std::getline(file, line))
    {
        trim(line);
        (*vocab)[line] = index;
        index++;
    }
    file.close();
}
inline bool _is_whitespace(int c, const char* cat)
{
    if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
    {
        return true;
    }
    return cat[0] == 'Z' && cat[1] == 's';
}

inline bool _is_control(int c, const char* cat)
{
    // These are technically control characters but we count them as whitespace
    // characters.
    if(c == '\t' || c == '\n' || c == '\r')
    {
        return false;
    }
    return 'C' == *cat;
}

inline bool _is_punctuation(int cp, const char* cat)
{
    // We treat all non-letter/number ASCII as punctuation.
    // Characters such as "^", "$", and "`" are not in the Unicode
    // Punctuation class but we treat them as punctuation anyways, for
    // consistency.
    if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
       (cp >= 123 && cp <= 126))
    {
        return true;
    }
    return 'P' == *cat;
}
bool _is_whitespace(int c) { return _is_whitespace(c, utf8proc_category_string(c)); }

bool _is_control(int c) { return _is_control(c, utf8proc_category_string(c)); }

bool _is_punctuation(int cp) { return _is_punctuation(cp, utf8proc_category_string(cp)); }

bool BasicTokenizer::_is_chinese_char(int cp)
{
    // This defines a "chinese character" as anything in the CJK Unicode block:
    // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    //
    // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    // despite its name. The modern Korean Hangul alphabet is a different block,
    // as is Japanese Hiragana and Katakana. Those alphabets are used to write
    // space-separated words, so they are not treated specially and handled
    // like all of the other languages.
    return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
           (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
           (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
           (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}
void BasicTokenizer::tokenize(const char* text,
                              std::vector<std::string>* output_tokens,
                              size_t max_length)
{
    // This was added on November 1st, 2018 for the multilingual and Chinese
    // models. This is also applied to the English models now, but it doesn't
    // matter since the English models were not trained on any Chinese data
    // and generally don't have any Chinese data in them (there are Chinese
    // characters in the vocabulary because Wikipedia does have some Chinese
    // words in the English Wikipedia.).
    if(do_lower_case)
    {
        text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
    }

    size_t word_bytes = std::strlen(text);
@@ -127,39 +133,56 @@ namespace cuBERT {
    int cp;
    char dst[4];
    while(word_bytes > 0)
    {
        int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
        if(len < 0)
        {
            std::cerr << "UTF-8 decode error: " << text << std::endl;
            break;
        }
        if(do_lower_case)
        {
            cp = utf8proc_tolower(cp);
        }

        const char* cat = utf8proc_category_string(cp);
        if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
        {
            // pass
        }
        else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
        {
            // pass
        }
        else if(_is_whitespace(cp, cat))
        {
            new_token = true;
        }
        else
        {
            size_t dst_len = len;
            const char* dst_ptr = text + subpos;
            if(do_lower_case)
            {
                dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
                dst_ptr = dst;
            }

            if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
            {
                output_tokens->emplace_back(dst_ptr, dst_len);
                new_token = true;
            }
            else
            {
                if(new_token)
                {
                    output_tokens->emplace_back(dst_ptr, dst_len);
                    new_token = false;
                }
                else
                {
                    output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
                }
            }
@@ -169,33 +192,38 @@ namespace cuBERT {
        subpos = subpos + len;

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }

    if(do_lower_case)
    {
        free((void*)text);
    }
}
void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
    if(token.size() > max_input_chars_per_word)
    { // FIXME: slightly different
        output_tokens->push_back(unk_token);
        return;
    }

    size_t output_tokens_len = output_tokens->size();
    for(size_t start = 0; start < token.size();)
    {
        bool is_bad = true;
        // TODO: can be optimized by prefix-tree
        for(size_t end = token.size(); start < end; --end)
        { // FIXME: slightly different
            std::string substr = start > 0 ? "##" + token.substr(start, end - start)
                                           : token.substr(start, end - start);
            if(vocab->count(substr))
            {
                is_bad = false;
                output_tokens->push_back(substr);
                start = end;
@@ -203,27 +231,32 @@ namespace cuBERT {
            }
        }

        if(is_bad)
        {
            output_tokens->resize(output_tokens_len);
            output_tokens->push_back(unk_token);
            return;
        }
    }
}
void FullTokenizer::tokenize(const char* text,
                             std::vector<std::string>* output_tokens,
                             size_t max_length)
{
    std::vector<std::string> tokens;
    tokens.reserve(max_length);
    basic_tokenizer->tokenize(text, &tokens, max_length);

    for(const auto& token : tokens)
    {
        wordpiece_tokenizer->tokenize(token, output_tokens);

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }
}

} // namespace cuBERT
#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

namespace cuBERT {

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
/**
 * Checks whether `chars` is a whitespace character.
 * @param c
 * @return
 */
bool _is_whitespace(int c);

/**
 * Checks whether `chars` is a control character.
 * @param c
 * @return
 */
bool _is_control(int c);

/**
 * Checks whether `chars` is a punctuation character.
 * @param cp
 * @return
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer
{
public:
    /**
     * Constructs a BasicTokenizer.
@@ -42,7 +43,7 @@ namespace cuBERT {
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer& other) = delete;

    virtual ~BasicTokenizer() = default;
@@ -51,15 +52,16 @@ namespace cuBERT {
     *
     * to_lower
     * _run_strip_accents      Strips accents from a piece of text.
     * _clean_text             Performs invalid character removal and whitespace cleanup on text.
     * _tokenize_chinese_chars Adds whitespace around any CJK character.
     * _run_split_on_punc      Splits punctuation on a piece of text.
     * whitespace_tokenize     Runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text
     * @param output_tokens
     */
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

private:
    const bool do_lower_case;
@@ -70,20 +72,22 @@
     * @return
     */
    inline static bool _is_chinese_char(int cp);
};
/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer
{
public:
    explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
                                std::string unk_token = "[UNK]",
                                int max_input_chars_per_word = 200)
        : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
    {
    }

    WordpieceTokenizer(const WordpieceTokenizer& other) = delete;

    virtual ~WordpieceTokenizer() = default;
@@ -97,67 +101,77 @@ namespace cuBERT {
     * input = "unaffable"
     * output = ["un", "##aff", "##able"]
     *
     * @param text A single token or whitespace separated tokens. This should have
     *             already been passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string& text, std::vector<std::string>* output_tokens);

private:
    const std::unordered_map<std::string, uint64_t>* vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};
/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer
{
public:
    FullTokenizer(const char* vocab_file, bool do_lower_case = true)
    {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    ~FullTokenizer()
    {
        // Release in reverse order of construction, deleting before clearing each pointer
        delete wordpiece_tokenizer;
        wordpiece_tokenizer = NULL;

        delete basic_tokenizer;
        basic_tokenizer = NULL;

        delete vocab;
        vocab = NULL;
    }
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    inline uint64_t convert_token_to_id(const std::string& token)
    {
        auto item = vocab->find(token);
        if(item == vocab->end())
        {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        }
        else
        {
            return item->second;
        }
    }
    void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);

private:
    std::unordered_map<std::string, uint64_t>* vocab;
    BasicTokenizer* basic_tokenizer;
    WordpieceTokenizer* wordpiece_tokenizer;
};

} // namespace cuBERT

#endif // CUBERT_TOKENIZATION_H
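A short end-to-end sketch of the tokenizer API above, assuming a vocab file on disk (path illustrative); the expected split of "unaffable" comes from the WordpieceTokenizer documentation:

    cuBERT::FullTokenizer tokenizer("vocab.txt", /*do_lower_case=*/true);
    std::vector<std::string> tokens;
    tokenizer.tokenize("unaffable", &tokens, 128); // e.g. ["un", "##aff", "##able"]
    std::vector<uint64_t> ids(tokens.size());
    tokenizer.convert_tokens_to_ids(tokens, ids.data());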
#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <tokenization.h>

int main(int argc, char* argv[])
{
    // Load the Bert model
    migraphxSamples::Bert bert;
    migraphxSamples::ErrorCode errorCode = bert.Initialize();
    if(errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "failed to initialize Bert!\n");
        exit(-1);
    }
@@ -25,7 +25,18 @@ int main(int argc, char *argv[])
    int max_answer_length = 30; // maximum length of the answer

    // Context text
    const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
                         u8"computing that’s also programming-language independent. It brings a "
                         u8"philosophy of choice, minimalism and modular software development to "
                         u8"GPU computing. You are free to choose or even develop tools and a "
                         u8"language run time for your application. ROCm is built for scale, it "
                         u8"supports multi-GPU computing and has a rich system run time with the "
                         u8"critical features that large-scale application, compiler and "
                         u8"language-run-time development requires. Since the ROCm ecosystem is "
                         u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
                         u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
                         u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
                         u8"platform is continually optimized for performance and extensibility."};

    char question[100];
    std::vector<std::vector<long unsigned int>> input_ids;
@@ -35,14 +46,22 @@ int main(int argc, char *argv[])
    std::vector<float> end_position;
    std::string answer = {};

    cuBERT::FullTokenizer tokenizer =
        cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // tokenizer

    while(true)
    {
        // Preprocess the input
        std::cout << "question: ";
        cin.getline(question, 100);
        bert.Preprocessing(tokenizer,
                           batch_size,
                           max_seq_length,
                           text,
                           question,
                           input_ids,
                           input_masks,
                           segment_ids);

        // Run inference
        bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
...