Commit d8f53cee authored by yangql

Initial commit

#include <cctype>
#include <cstdlib>
#include <stdexcept>
#include <algorithm>
#include <cstring>
#include <fstream>
#include "utf8proc.h"
#include "./tokenization.h"
namespace cuBERT {
    void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids) {
        for (size_t i = 0; i < tokens.size(); ++i) {
            ids[i] = convert_token_to_id(tokens[i]);
        }
    }

    // trim from start (in place)
    static inline void ltrim(std::string &s) {
        s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
            return !std::isspace(ch);
        }));
    }

    // trim from end (in place)
    static inline void rtrim(std::string &s) {
        s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) {
            return !std::isspace(ch);
        }).base(), s.end());
    }

    // trim from both ends (in place)
    static inline void trim(std::string &s) {
        ltrim(s);
        rtrim(s);
    }

    void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab) {
        std::ifstream file(vocab_file);
        if (!file) {
            throw std::invalid_argument("Unable to open vocab file");
        }

        // Each line of the vocab file is one token; its 0-based line number is its id.
        uint64_t index = 0;
        std::string line;
        while (std::getline(file, line)) {
            trim(line);
            (*vocab)[line] = index;
            index++;
        }
        file.close();
    }
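
    // Usage sketch (illustrative; the file name and the "[PAD]" token below are
    // assumptions, not part of this commit):
    //
    //   std::unordered_map<std::string, uint64_t> vocab;
    //   load_vocab("vocab.txt", &vocab);
    //   // If "[PAD]" is the first line of vocab.txt, then vocab["[PAD]"] == 0.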
    inline bool _is_whitespace(int c, const char *cat) {
        if (c == ' ' || c == '\t' || c == '\n' || c == '\r') {
            return true;
        }
        return cat[0] == 'Z' && cat[1] == 's';
    }

    inline bool _is_control(int c, const char *cat) {
        // These are technically control characters but we count them as whitespace characters.
        if (c == '\t' || c == '\n' || c == '\r') {
            return false;
        }
        return 'C' == *cat;
    }

    inline bool _is_punctuation(int cp, const char *cat) {
        // We treat all non-letter/number ASCII as punctuation.
        // Characters such as "^", "$", and "`" are not in the Unicode
        // Punctuation class but we treat them as punctuation anyways, for
        // consistency.
        if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) ||
            (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
            return true;
        }
        return 'P' == *cat;
    }

    bool _is_whitespace(int c) {
        return _is_whitespace(c, utf8proc_category_string(c));
    }

    bool _is_control(int c) {
        return _is_control(c, utf8proc_category_string(c));
    }

    bool _is_punctuation(int cp) {
        return _is_punctuation(cp, utf8proc_category_string(cp));
    }
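
    // Illustrative checks (codepoints chosen by the editor, not part of this commit):
    //
    //   _is_whitespace(' ')    == true   // ASCII space
    //   _is_whitespace(0x3000) == true   // ideographic space, Unicode category Zs
    //   _is_control('\n')      == false  // counted as whitespace, not control
    //   _is_punctuation('$')   == true   // non-letter/number ASCII is punctuation here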
    bool BasicTokenizer::_is_chinese_char(int cp) {
        // This defines a "chinese character" as anything in the CJK Unicode block:
        // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        //
        // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        // despite its name. The modern Korean Hangul alphabet is a different block,
        // as are Japanese Hiragana and Katakana. Those alphabets are used to write
        // space-separated words, so they are not treated specially and are handled
        // like all of the other languages.
        return (cp >= 0x4E00 && cp <= 0x9FFF) ||
               (cp >= 0x3400 && cp <= 0x4DBF) ||
               (cp >= 0x20000 && cp <= 0x2A6DF) ||
               (cp >= 0x2A700 && cp <= 0x2B73F) ||
               (cp >= 0x2B740 && cp <= 0x2B81F) ||
               (cp >= 0x2B820 && cp <= 0x2CEAF) ||
               (cp >= 0xF900 && cp <= 0xFAFF) ||
               (cp >= 0x2F800 && cp <= 0x2FA1F);
    }
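
    // For example (editor's illustration): 0x4F60 (你) falls in the 0x4E00-0x9FFF block
    // and is later emitted as a single-character token, while 0x0041 (A) is not.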
    void BasicTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
        // This was added on November 1st, 2018 for the multilingual and Chinese
        // models. This is also applied to the English models now, but it doesn't
        // matter since the English models were not trained on any Chinese data
        // and generally don't have any Chinese data in them (there are Chinese
        // characters in the vocabulary because Wikipedia does have some Chinese
        // words in the English Wikipedia).
        if (do_lower_case) {
            // NFD-normalize so that accents become separate combining marks (Mn),
            // which are dropped below. The returned buffer is malloc'd and freed at the end.
            text = (const char *) utf8proc_NFD((const utf8proc_uint8_t *) text);
            if (text == nullptr) {
                // NULL on invalid UTF-8 or allocation failure
                return;
            }
        }

        size_t word_bytes = std::strlen(text);
        bool new_token = true;
        size_t subpos = 0;
        int cp;
        char dst[4];
        while (word_bytes > 0) {
            int len = utf8proc_iterate((const utf8proc_uint8_t *) text + subpos, word_bytes, &cp);
            if (len < 0) {
                std::cerr << "UTF-8 decode error: " << text << std::endl;
                break;
            }

            if (do_lower_case) {
                cp = utf8proc_tolower(cp);
            }

            const char *cat = utf8proc_category_string(cp);
            if (cp == 0 || cp == 0xfffd || _is_control(cp, cat)) {
                // drop invalid and control characters
            } else if (do_lower_case && cat[0] == 'M' && cat[1] == 'n') {
                // drop combining marks (accent stripping after NFD)
            } else if (_is_whitespace(cp, cat)) {
                new_token = true;
            } else {
                size_t dst_len = len;
                const char *dst_ptr = text + subpos;
                if (do_lower_case) {
                    // re-encode the lower-cased codepoint
                    dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t *) dst);
                    dst_ptr = dst;
                }

                if (_is_punctuation(cp, cat) || _is_chinese_char(cp)) {
                    // punctuation and CJK characters always form single-character tokens
                    output_tokens->emplace_back(dst_ptr, dst_len);
                    new_token = true;
                } else {
                    if (new_token) {
                        output_tokens->emplace_back(dst_ptr, dst_len);
                        new_token = false;
                    } else {
                        output_tokens->back().append(dst_ptr, dst_len);
                    }
                }
            }

            word_bytes = word_bytes - len;
            subpos = subpos + len;

            // early terminate
            if (output_tokens->size() >= max_length) {
                break;
            }
        }

        if (do_lower_case) {
            std::free((void *) text);
        }
    }
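
    // Behavioral sketch (editor's illustration, assuming do_lower_case == true):
    //
    //   BasicTokenizer basic(true);
    //   std::vector<std::string> tokens;
    //   basic.tokenize("Héllo, 世界!", &tokens, 128);
    //   // tokens == {"hello", ",", "世", "界", "!"}
    //   // accents are stripped, punctuation is split off, and each CJK
    //   // character becomes its own token.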
    void WordpieceTokenizer::tokenize(const std::string &token, std::vector<std::string> *output_tokens) {
        if (token.size() > (size_t) max_input_chars_per_word) {  // FIXME: counts bytes, not unicode chars
            output_tokens->push_back(unk_token);
            return;
        }

        size_t output_tokens_len = output_tokens->size();
        for (size_t start = 0; start < token.size();) {
            bool is_bad = true;
            // greedy longest-match-first: try the longest remaining piece first
            // TODO: can be optimized by prefix-tree
            for (size_t end = token.size(); start < end; --end) {  // FIXME: slightly different
                std::string substr = start > 0
                                     ? "##" + token.substr(start, end - start)
                                     : token.substr(start, end - start);
                if (vocab->count(substr)) {
                    is_bad = false;
                    output_tokens->push_back(substr);
                    start = end;
                    break;
                }
            }

            if (is_bad) {
                // no piece matched: roll back anything emitted for this token and output [UNK]
                output_tokens->resize(output_tokens_len);
                output_tokens->push_back(unk_token);
                return;
            }
        }
    }
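
    // Usage sketch (editor's illustration; assumes the vocab contains
    // "un", "##aff" and "##able" but not "unaffable"):
    //
    //   std::unordered_map<std::string, uint64_t> vocab;
    //   load_vocab("vocab.txt", &vocab);
    //   WordpieceTokenizer wordpiece(&vocab);
    //   std::vector<std::string> pieces;
    //   wordpiece.tokenize("unaffable", &pieces);
    //   // pieces == {"un", "##aff", "##able"}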
    void FullTokenizer::tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length) {
        std::vector<std::string> tokens;
        tokens.reserve(max_length);
        basic_tokenizer->tokenize(text, &tokens, max_length);

        for (const auto &token : tokens) {
            wordpiece_tokenizer->tokenize(token, output_tokens);
            // early terminate
            if (output_tokens->size() >= max_length) {
                break;
            }
        }
    }
}
#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H
#include <cstdint>
#include <string>
#include <vector>
#include <unordered_map>
#include <iostream>
namespace cuBERT {
    void load_vocab(const char *vocab_file, std::unordered_map<std::string, uint64_t> *vocab);

    /**
     * Checks whether `c` is a whitespace character.
     * @param c unicode codepoint
     * @return true if `c` is whitespace
     */
    bool _is_whitespace(int c);

    /**
     * Checks whether `c` is a control character.
     * @param c unicode codepoint
     * @return true if `c` is a control character
     */
    bool _is_control(int c);

    /**
     * Checks whether `cp` is a punctuation character.
     * @param cp unicode codepoint
     * @return true if `cp` is punctuation
     */
    bool _is_punctuation(int cp);

    /**
     * Runs basic tokenization (punctuation splitting, lower casing, etc.).
     */
    class BasicTokenizer {
    public:
        /**
         * Constructs a BasicTokenizer.
         * @param do_lower_case Whether to lower case the input.
         */
        explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

        BasicTokenizer(const BasicTokenizer &other) = delete;

        virtual ~BasicTokenizer() = default;

        /**
         * Tokenizes a piece of text.
         *
         * Fuses the steps of the original Python implementation:
         *   to_lower
         *   _run_strip_accents      Strips accents from a piece of text.
         *   _clean_text             Performs invalid character removal and whitespace cleanup on text.
         *   _tokenize_chinese_chars Adds whitespace around any CJK character.
         *   _run_split_on_punc      Splits punctuation on a piece of text.
         *   whitespace_tokenize     Runs basic whitespace cleaning and splitting on a piece of text.
         *
         * @param text input text (UTF-8)
         * @param output_tokens output vector the tokens are appended to
         * @param max_length stop once this many tokens have been produced
         */
        void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);

    private:
        const bool do_lower_case;

        /**
         * Checks whether CP is the codepoint of a CJK character.
         * @param cp unicode codepoint
         * @return true if `cp` is in a CJK block
         */
        inline static bool _is_chinese_char(int cp);
    };
    /**
     * Runs WordPiece tokenization.
     */
    class WordpieceTokenizer {
    public:
        explicit WordpieceTokenizer(
                std::unordered_map<std::string, uint64_t> *vocab,
                std::string unk_token = "[UNK]",
                int max_input_chars_per_word = 200
        ) : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word) {}

        WordpieceTokenizer(const WordpieceTokenizer &other) = delete;

        virtual ~WordpieceTokenizer() = default;

        /**
         * Tokenizes a piece of text into its word pieces.
         *
         * This uses a greedy longest-match-first algorithm to perform tokenization
         * using the given vocabulary.
         *
         * For example:
         *   input  = "unaffable"
         *   output = ["un", "##aff", "##able"]
         *
         * @param text A single token or whitespace-separated tokens. This should have
         *             already been passed through `BasicTokenizer`.
         * @param output_tokens A list of wordpiece tokens.
         */
        void tokenize(const std::string &text, std::vector<std::string> *output_tokens);

    private:
        const std::unordered_map<std::string, uint64_t> *vocab;
        const std::string unk_token;
        const int max_input_chars_per_word;
    };
    /**
     * Runs end-to-end tokenization.
     */
    class FullTokenizer {
    public:
        explicit FullTokenizer(const char *vocab_file, bool do_lower_case = true) {
            vocab = new std::unordered_map<std::string, uint64_t>();
            load_vocab(vocab_file, vocab);
            basic_tokenizer = new BasicTokenizer(do_lower_case);
            wordpiece_tokenizer = new WordpieceTokenizer(vocab);
        }

        FullTokenizer(const FullTokenizer &other) = delete;

        ~FullTokenizer() {
            // release owned objects in reverse order of construction
            delete wordpiece_tokenizer;
            wordpiece_tokenizer = nullptr;
            delete basic_tokenizer;
            basic_tokenizer = nullptr;
            delete vocab;
            vocab = nullptr;
        }

        void tokenize(const char *text, std::vector<std::string> *output_tokens, size_t max_length);

        inline uint64_t convert_token_to_id(const std::string &token) {
            auto item = vocab->find(token);
            if (item == vocab->end()) {
                std::cerr << "vocab missing key: " << token << std::endl;
                return 0;
            } else {
                return item->second;
            }
        }

        void convert_tokens_to_ids(const std::vector<std::string> &tokens, uint64_t *ids);

    private:
        std::unordered_map<std::string, uint64_t> *vocab;
        BasicTokenizer *basic_tokenizer;
        WordpieceTokenizer *wordpiece_tokenizer;
    };
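
    // End-to-end usage sketch (editor's illustration; the vocab path and the
    // max length are placeholders):
    //
    //   cuBERT::FullTokenizer tokenizer("vocab.txt");
    //   std::vector<std::string> tokens;
    //   tokenizer.tokenize("He was unaffable.", &tokens, 128);
    //   std::vector<uint64_t> ids(tokens.size());
    //   tokenizer.convert_tokens_to_ids(tokens, ids.data());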
}
#endif //CUBERT_TOKENIZATION_H
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <GPT2.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h>

int main()
{
    // Load the GPT2 model
    ortSamples::GPT2 gpt2;
    ortSamples::ErrorCode errorCode = gpt2.Initialize();
    if (errorCode != ortSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize GPT2\n");

    // Load the vocabulary, used for both encoding and decoding
    cuBERT::FullTokenizer tokenizer("../Resource/vocab_shici.txt");
    std::ifstream infile;
    std::string buf;
    std::vector<std::string> output;
    infile.open("../Resource/vocab_shici.txt");
    while (std::getline(infile, buf))
    {
        output.push_back(buf);
    }

    std::vector<long unsigned int> input_id;
    char question[100];
    std::vector<long unsigned int> score;
    std::vector<std::string> result;

    std::cout << "Start a poem duel with GPT2; press CTRL + Z to exit" << std::endl;
    while (true)
    {
        // Preprocess the input
        std::cout << "question: ";
        if (!std::cin.getline(question, 100))
        {
            break;  // EOF (CTRL + Z) ends the session
        }
        gpt2.Preprocessing(tokenizer, question, input_id);

        // Inference: generate up to 50 tokens autoregressively
        for (int i = 0; i < 50; ++i)
        {
            long unsigned int outputs = gpt2.Inference(input_id);
            if (outputs == 102)  // presumably the end-of-sequence id in vocab_shici.txt
            {
                break;
            }
            input_id.push_back(outputs);
            score.push_back(outputs);
        }

        // Map the generated ids back to characters
        for (size_t i = 0; i < score.size(); ++i)
        {
            result.push_back(output[score[i]]);
        }

        // Print the result
        std::cout << "chatbot: ";
        std::cout << question;
        for (size_t j = 0; j < result.size(); ++j)
        {
            std::cout << result[j];
        }
        std::cout << std::endl;

        // Clear per-round buffers
        input_id.clear();
        result.clear();
        score.clear();
    }
    return 0;
}
# Model name
modelName=GPT2_Ort
# Model description
modelDescription=GPT2 is a second-generation generative pre-trained language model.
# Application scenarios
appScenario=Inference,NLP,ONNXRuntime
# Framework type
frameType=ONNXRuntime