GitLab · ModelZoo / GPT2_migraphx · Commits

Commit 0f9dc829
authored Nov 15, 2023 by liucong

Reformat the C++ code style

parent 824cfb81

Showing 10 changed files with 2235 additions and 1957 deletions (+2235 -1957)
Src/GPT2.cpp                  +30   -38
Src/GPT2.h                    +21   -21
Src/Utility/Filesystem.cpp    +576  -570
Src/Utility/Filesystem.h      +27   -19
Src/Utility/SimpleLog.h       +136  -117
Src/Utility/tokenization.cpp  +209  -176
Src/Utility/tokenization.h    +120  -106
Src/Utility/utf8proc.c        +802  -621
Src/Utility/utf8proc.h        +307  -282
Src/main.cpp                  +7    -7
Src/GPT2.cpp
@@ -12,90 +12,82 @@

namespace migraphxSamples
{
GPT2::GPT2()
{}

GPT2::~GPT2()
{
}

ErrorCode GPT2::Initialize()
{
    // Get the model file
    std::string modelPath = "../Resource/GPT2_shici.onnx";

    // Set the maximum input shape
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["input"] = {1, 1000};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Get the model's input/output node information
    std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName  = inputs.begin()->first;
    inputShape = inputs.begin()->second;

    // Set the model to GPU mode
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id    = 0;    // GPU device id, defaults to device 0
    options.offload_copy = true; // enable offload_copy
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());

    return SUCCESS;
}

static bool CompareM(Predictions a, Predictions b)
{
    return a.predictionvalue > b.predictionvalue;
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int>& input_id)
{
    long unsigned int input[1][input_id.size()];
    for(int j = 0; j < input_id.size(); ++j)
    {
        input[0][j] = input_id[j];
    }

    // Set the input shape
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1, input_id.size()});

    // Create the input data
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName] = migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[0]),
                                              (long unsigned int*)input};

    // Run inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    // Get the output node's properties
    migraphx::argument result   = results[0];
    migraphx::shape outputShape = result.get_shape();     // output node shape
    int numberOfOutput          = outputShape.elements(); // number of output elements
    float* data                 = (float*)result.data();  // output data pointer

    // Save the inference results
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i = (input_id.size() - 1) * 22557; i < input_id.size() * 22557; ++i)
    {
        resultsOfPredictions[n].index           = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

@@ -110,8 +102,8 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
}

ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                              char* question,
                              std::vector<long unsigned int>& input_id)
{
    // Tokenize the input
    int max_seq_length = 1000;

@@ -121,7 +113,7 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
    // Save the encoded tokens
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for(int i = 0; i < tokens_question.size(); ++i)
    {
        input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i]));
    }

@@ -129,4 +121,4 @@ ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
    return SUCCESS;
}
} // namespace migraphxSamples
\ No newline at end of file
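The Inference loop above copies the logits of the final input position into resultsOfPredictions (the stride 22557 is this model's vocabulary size) and then sorts them with CompareM. A minimal sketch of the same greedy selection using std::max_element instead of a full sort; the helper name PickNextToken is hypothetical, the layout assumption (seq_len x vocab_size logits) comes from the loop above:

#include <algorithm>
#include <cstddef>

// Hypothetical helper: argmax over the logits of the last input position.
// `data` points at [seq_len x vocab_size] floats, as returned by net.eval()
// in GPT2::Inference above; vocab_size is 22557 for this model.
static std::size_t PickNextToken(const float* data, std::size_t seq_len, std::size_t vocab_size = 22557)
{
    const float* last = data + (seq_len - 1) * vocab_size;    // logits of the final position
    return std::max_element(last, last + vocab_size) - last;  // index of the largest logit
}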
Src/GPT2.h
@@ -8,42 +8,42 @@

namespace migraphxSamples
{
typedef enum _ErrorCode
{
    SUCCESS = 0,
    MODEL_NOT_EXIST,
    CONFIG_FILE_NOT_EXIST,
    FAIL_TO_LOAD_MODEL,
    FAIL_TO_OPEN_CONFIG_FILE,
} ErrorCode;

typedef struct _Predictions
{
    long unsigned int index;
    float predictionvalue;
} Predictions;

class GPT2
{
public:
    GPT2();
    ~GPT2();
    ErrorCode Initialize();
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                            char* question,
                            std::vector<long unsigned int>& input_id);
    long unsigned int Inference(const std::vector<long unsigned int>& input_id);

private:
    migraphx::program net;
    std::string inputName;
    migraphx::shape inputShape;
};
} // namespace migraphxSamples
#endif
\ No newline at end of file
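For orientation, a condensed driver sketch assembled from the interface above and the Src/main.cpp hunks later in this diff. Feeding each predicted id back into input_id, and the include paths, are assumptions about parts of main.cpp that are elided here; the prompt string is illustrative:

#include <vector>
#include "GPT2.h"                  // illustrative include paths
#include "Utility/tokenization.h"

int main()
{
    migraphxSamples::GPT2 gpt2;
    if(gpt2.Initialize() != migraphxSamples::SUCCESS)
        return -1;

    // Same vocab file that Src/main.cpp reads.
    cuBERT::FullTokenizer tokenizer("../Resource/vocab_shici.txt");

    char question[] = "白日依山尽"; // illustrative prompt
    std::vector<long unsigned int> input_id;
    gpt2.Preprocessing(tokenizer, question, input_id);

    // Generate up to 50 tokens; main.cpp stops when Inference() returns 102.
    for(int i = 0; i < 50; ++i)
    {
        long unsigned int next = gpt2.Inference(input_id);
        if(next == 102)
            break;
        input_id.push_back(next); // assumption: the predicted id is appended to the context
    }
    return 0;
}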
Src/Utility/Filesystem.cpp (diff collapsed)
Src/Utility/Filesystem.h
@@ -5,27 +5,27 @@

#include <string>
#include <vector>

namespace migraphxSamples
{
// Whether the path exists
bool Exists(const std::string& path);

// Whether the path is a directory
bool IsDirectory(const std::string& path);

// Whether the character is a path separator (Linux: '/', Windows: '\\')
bool IsPathSeparator(char c);

// Join two paths
std::string JoinPath(const std::string& base, const std::string& path);

// Create a multi-level directory tree; note: the target directories must not contain files
bool CreateDirectories(const std::string& directoryPath);

/** Generate a list of file names matching the given pattern (supports recursive traversal)
 *
 * pattern: the pattern, e.g. "*.jpg", "*.png", "*.jpg,*.png"
 * addPath: whether to include the parent path
 * Notes:

@@ -36,35 +36,43 @@ bool CreateDirectories(const std::string &directoryPath);
 5. Does not return subdirectory names
 *
 */
void GetFileNameList(const std::string& directory,
                     const std::string& pattern,
                     std::vector<std::string>& result,
                     bool recursive,
                     bool addPath);

// Unlike GetFileNameList, when addPath is true this also returns subdirectory paths (directory names end with "/")
void GetFileNameList2(const std::string& directory,
                      const std::string& pattern,
                      std::vector<std::string>& result,
                      bool recursive,
                      bool addPath);

// Remove a file or directory; supports recursive removal
void Remove(const std::string& directory, const std::string& extension = "");

/** Get a path's file name and extension
 *
 * Example: if path is D:/1/1.txt, then GetFileName() is 1.txt, GetFileName_NoExtension() is 1,
 * GetExtension() is .txt, and GetParentPath() is D:/1/
 */
std::string GetFileName(const std::string& path);
std::string GetFileName_NoExtension(const std::string& path);
std::string GetExtension(const std::string& path);
std::string GetParentPath(const std::string& path);

// Copy a file
bool CopyFile(const std::string srcPath, const std::string dstPath);

/** Copy a directory
 *
 * Example: CopyDirectories("D:/0/1/2/","E:/3/"); copies the directory D:/0/1/2/ into E:/3/
 * (i.e. after the copy the layout is E:/3/2/)
 * Notes:
 *   1. The first argument must not end with "/"
 *   2. Hidden files are not copied
 */
bool CopyDirectories(std::string srcPath, const std::string dstPath);
} // namespace migraphxSamples
#endif
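A short usage sketch for the listing API above, built only from the declared signature; the directory and pattern are illustrative:

#include <string>
#include <vector>
#include "Filesystem.h" // illustrative include path

int main()
{
    std::vector<std::string> files;
    // Collect all ONNX models under ../Resource/, recursively, with parent paths included.
    migraphxSamples::GetFileNameList("../Resource/", "*.onnx", files, true, true);
    return 0;
}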
Src/Utility/SimpleLog.h
@@ -8,7 +8,7 @@

#include <map>
#include <thread>
#include <mutex>
#if(defined WIN32 || defined _WIN32)
#include <Windows.h>
#else
#include <sys/time.h>

@@ -16,13 +16,13 @@
using namespace std;

/** Simple logger
 *
 * Depends on no third-party libraries; including this single header is enough.
 * Provides four log levels: INFO, DEBUG, WARN and ERROR.
 *
 * Example 1:
 // Initialize the logger, creating two log files log1.log and log2.log under ./Log/
 // (note: the directory ./Log/ must already exist, otherwise creating the logs fails)
 LogManager::GetInstance()->Initialize("./Log/","log1");
 LogManager::GetInstance()->Initialize("./Log/","log2");

@@ -34,11 +34,11 @@ using namespace std;
 // Close the logs
 LogManager::GetInstance()->Close("log1");
 LogManager::GetInstance()->Close("log2");

 * Example 2:
 // Log to the console
 string log = "Hello World";
 LOG_INFO(stdout, "%s\n", log.c_str());

 * Notes:
 1. Requires C++11

@@ -50,44 +50,43 @@ using namespace std;
class LogManager
{
private:
    LogManager()
    {}

public:
    ~LogManager()
    {}

    inline void Initialize(const string& parentPath, const string& logName)
    {
        // An empty log name means logging to the console
        if(logName.size() == 0)
            return;

        // Look the log file up; create it if it does not exist yet
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            string pathOfLog = parentPath + logName + ".log";
            FILE* logFile    = fopen(pathOfLog.c_str(), "a"); // w: overwrite the file, a: append
            if(logFile != NULL)
            {
                logMap.insert(std::make_pair(logName, logFile));
            }
        }
    }

    inline FILE* GetLogFile(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return NULL;
        }
        return (*iter).second;
    }

    inline void Close(const string& logName)
    {
        std::map<string, FILE*>::const_iterator iter = logMap.find(logName);
        if(iter == logMap.end())
        {
            return;
        }

@@ -95,10 +94,7 @@ public:
        fclose((*iter).second);
        logMap.erase(iter);
    }

    inline std::mutex& GetLogMutex()
    {
        return logMutex;
    }

    // Singleton
    static LogManager* GetInstance()

@@ -106,21 +102,22 @@ public:
        static LogManager logManager;
        return &logManager;
    }

private:
    std::map<string, FILE*> logMap;
    std::mutex logMutex;
};

#ifdef LOG_MUTEX
#define LOCK LogManager::GetInstance()->GetLogMutex().lock()
#define UNLOCK LogManager::GetInstance()->GetLogMutex().unlock()
#else
#define LOCK
#define UNLOCK
#endif

// log time
typedef struct _LogTime
{
    string year;
    string month;

@@ -131,53 +128,53 @@ typedef struct _LogTime
    string millisecond; // ms
    string microsecond; // us
    string weekDay;
} LogTime;

inline LogTime GetTime()
{
    LogTime currentTime;
#if(defined WIN32 || defined _WIN32)
    SYSTEMTIME systemTime;
    GetLocalTime(&systemTime);
    char temp[8] = {0};
    sprintf(temp, "%04d", systemTime.wYear);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", systemTime.wMonth);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", systemTime.wDay);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", systemTime.wHour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", systemTime.wMinute);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", systemTime.wSecond);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", systemTime.wMilliseconds);
    currentTime.millisecond = string(temp);
    sprintf(temp, "%d", systemTime.wDayOfWeek);
    currentTime.weekDay = string(temp);
#else
    struct timeval tv;
    struct tm* p;
    gettimeofday(&tv, NULL);
    p = localtime(&tv.tv_sec);
    char temp[8] = {0};
    sprintf(temp, "%04d", 1900 + p->tm_year);
    currentTime.year = string(temp);
    sprintf(temp, "%02d", 1 + p->tm_mon);
    currentTime.month = string(temp);
    sprintf(temp, "%02d", p->tm_mday);
    currentTime.day = string(temp);
    sprintf(temp, "%02d", p->tm_hour);
    currentTime.hour = string(temp);
    sprintf(temp, "%02d", p->tm_min);
    currentTime.minute = string(temp);
    sprintf(temp, "%02d", p->tm_sec);
    currentTime.second = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec / 1000));
    currentTime.millisecond = string(temp);
    sprintf(temp, "%03d", (int)(tv.tv_usec % 1000));
    currentTime.microsecond = string(temp);

@@ -187,61 +184,83 @@ inline LogTime GetTime()
    return currentTime;
}

#define LOG_TIME(logFile) \
    do \
    { \
        LogTime currentTime = GetTime(); \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "%s-%s-%s %s:%s:%s.%s\t", \
                currentTime.year.c_str(), \
                currentTime.month.c_str(), \
                currentTime.day.c_str(), \
                currentTime.hour.c_str(), \
                currentTime.minute.c_str(), \
                currentTime.second.c_str(), \
                currentTime.millisecond.c_str()); \
    } while(0)

#define LOG_INFO(logFile, logInfo, ...) \
    do \
    { \
        LOCK; \
        LOG_TIME(logFile); \
        fprintf(((logFile == NULL) ? stdout : logFile), "INFO\t"); \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "[%s:%d (%s) ]: ", \
                __FILE__, \
                __LINE__, \
                __FUNCTION__); \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile); \
        UNLOCK; \
    } while(0)

#define LOG_DEBUG(logFile, logInfo, ...) \
    do \
    { \
        LOCK; \
        LOG_TIME(logFile); \
        fprintf(((logFile == NULL) ? stdout : logFile), "DEBUG\t"); \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "[%s:%d (%s) ]: ", \
                __FILE__, \
                __LINE__, \
                __FUNCTION__); \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile); \
        UNLOCK; \
    } while(0)

#define LOG_ERROR(logFile, logInfo, ...) \
    do \
    { \
        LOCK; \
        LOG_TIME(logFile); \
        fprintf(((logFile == NULL) ? stdout : logFile), "ERROR\t"); \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "[%s:%d (%s) ]: ", \
                __FILE__, \
                __LINE__, \
                __FUNCTION__); \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile); \
        UNLOCK; \
    } while(0)

#define LOG_WARN(logFile, logInfo, ...) \
    do \
    { \
        LOCK; \
        LOG_TIME(logFile); \
        fprintf(((logFile == NULL) ? stdout : logFile), "WARN\t"); \
        fprintf(((logFile == NULL) ? stdout : logFile), \
                "[%s:%d (%s) ]: ", \
                __FILE__, \
                __LINE__, \
                __FUNCTION__); \
        fprintf(((logFile == NULL) ? stdout : logFile), logInfo, ##__VA_ARGS__); \
        fflush(logFile); \
        UNLOCK; \
    } while(0)

#endif // __SIMPLE_LOG_H__
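The LOCK/UNLOCK pair wrapped around each macro body is a no-op unless LOG_MUTEX is defined before the header is included. A minimal sketch based on the two examples in the header comment above; the include path is illustrative:

#define LOG_MUTEX      // enable LOCK/UNLOCK around each log call
#include "SimpleLog.h" // illustrative include path

int main()
{
    LogManager::GetInstance()->Initialize("./Log/", "log1"); // ./Log/ must already exist
    LOG_INFO(LogManager::GetInstance()->GetLogFile("log1"), "%s\n", "Hello World");
    LogManager::GetInstance()->Close("log1");
    return 0;
}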
Src/Utility/tokenization.cpp
@@ -6,224 +6,257 @@

#include "./tokenization.h"

namespace cuBERT
{
void FullTokenizer::convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids)
{
    for(int i = 0; i < tokens.size(); ++i)
    {
        ids[i] = convert_token_to_id(tokens[i]);
    }
}

// trim from start (in place)
static inline void ltrim(std::string& s)
{
    s.erase(s.begin(),
            std::find_if(s.begin(), s.end(), [](int ch) { return !std::isspace(ch); }));
}

// trim from end (in place)
static inline void rtrim(std::string& s)
{
    s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) { return !std::isspace(ch); }).base(),
            s.end());
}

// trim from both ends (in place)
static inline void trim(std::string& s)
{
    ltrim(s);
    rtrim(s);
}

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab)
{
    std::ifstream file(vocab_file);
    if(!file)
    {
        throw std::invalid_argument("Unable to open vocab file");
    }

    unsigned int index = 0;
    std::string line;
    while(std::getline(file, line))
    {
        trim(line);
        (*vocab)[line] = index;
        index++;
    }

    file.close();
}

inline bool _is_whitespace(int c, const char* cat)
{
    if(c == ' ' || c == '\t' || c == '\n' || c == '\r')
    {
        return true;
    }
    return cat[0] == 'Z' && cat[1] == 's';
}

inline bool _is_control(int c, const char* cat)
{
    // These are technically control characters but we count them as whitespace characters.
    if(c == '\t' || c == '\n' || c == '\r')
    {
        return false;
    }
    return 'C' == *cat;
}

inline bool _is_punctuation(int cp, const char* cat)
{
    // We treat all non-letter/number ASCII as punctuation.
    // Characters such as "^", "$", and "`" are not in the Unicode
    // Punctuation class but we treat them as punctuation anyways, for
    // consistency.
    if((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) ||
       (cp >= 123 && cp <= 126))
    {
        return true;
    }
    return 'P' == *cat;
}

bool _is_whitespace(int c)
{
    return _is_whitespace(c, utf8proc_category_string(c));
}

bool _is_control(int c)
{
    return _is_control(c, utf8proc_category_string(c));
}

bool _is_punctuation(int cp)
{
    return _is_punctuation(cp, utf8proc_category_string(cp));
}

bool BasicTokenizer::_is_chinese_char(int cp)
{
    // This defines a "chinese character" as anything in the CJK Unicode block:
    // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    //
    // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    // despite its name. The modern Korean Hangul alphabet is a different block,
    // as is Japanese Hiragana and Katakana. Those alphabets are used to write
    // space-separated words, so they are not treated specially and handled
    // like all of the other languages.
    return (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) ||
           (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) ||
           (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) ||
           (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F);
}

void BasicTokenizer::tokenize(const char* text,
                              std::vector<std::string>* output_tokens,
                              size_t max_length)
{
    // This was added on November 1st, 2018 for the multilingual and Chinese
    // models. This is also applied to the English models now, but it doesn't
    // matter since the English models were not trained on any Chinese data
    // and generally don't have any Chinese data in them (there are Chinese
    // characters in the vocabulary because Wikipedia does have some Chinese
    // words in the English Wikipedia.).
    if(do_lower_case)
    {
        text = (const char*)utf8proc_NFD((const utf8proc_uint8_t*)text);
    }

    size_t word_bytes = std::strlen(text);
    bool new_token    = true;
    size_t subpos     = 0;
    int cp;
    char dst[4];
    while(word_bytes > 0)
    {
        int len = utf8proc_iterate((const utf8proc_uint8_t*)text + subpos, word_bytes, &cp);
        if(len < 0)
        {
            std::cerr << "UTF-8 decode error: " << text << std::endl;
            break;
        }

        if(do_lower_case)
        {
            cp = utf8proc_tolower(cp);
        }

        const char* cat = utf8proc_category_string(cp);
        if(cp == 0 || cp == 0xfffd || _is_control(cp, cat))
        {
            // pass
        }
        else if(do_lower_case && cat[0] == 'M' && cat[1] == 'n')
        {
            // pass
        }
        else if(_is_whitespace(cp, cat))
        {
            new_token = true;
        }
        else
        {
            size_t dst_len      = len;
            const char* dst_ptr = text + subpos;
            if(do_lower_case)
            {
                dst_len = utf8proc_encode_char(cp, (utf8proc_uint8_t*)dst);
                dst_ptr = dst;
            }

            if(_is_punctuation(cp, cat) || _is_chinese_char(cp))
            {
                output_tokens->emplace_back(dst_ptr, dst_len);
                new_token = true;
            }
            else
            {
                if(new_token)
                {
                    output_tokens->emplace_back(dst_ptr, dst_len);
                    new_token = false;
                }
                else
                {
                    output_tokens->at(output_tokens->size() - 1).append(dst_ptr, dst_len);
                }
            }
        }

        word_bytes = word_bytes - len;
        subpos     = subpos + len;

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }

    if(do_lower_case)
    {
        free((void*)text);
    }
}

void WordpieceTokenizer::tokenize(const std::string& token, std::vector<std::string>* output_tokens)
{
    if(token.size() > max_input_chars_per_word)
    {
        // FIXME: slightly different
        output_tokens->push_back(unk_token);
        return;
    }

    size_t output_tokens_len = output_tokens->size();
    for(size_t start = 0; start < token.size();)
    {
        bool is_bad = true;

        // TODO: can be optimized by prefix-tree
        for(size_t end = token.size(); start < end; --end)
        {
            // FIXME: slightly different
            std::string substr = start > 0 ? "##" + token.substr(start, end - start)
                                           : token.substr(start, end - start);
            if(vocab->count(substr))
            {
                is_bad = false;
                output_tokens->push_back(substr);
                start = end;
                break;
            }
        }

        if(is_bad)
        {
            output_tokens->resize(output_tokens_len);
            output_tokens->push_back(unk_token);
            return;
        }
    }
}

void FullTokenizer::tokenize(const char* text,
                             std::vector<std::string>* output_tokens,
                             size_t max_length)
{
    std::vector<std::string> tokens;
    tokens.reserve(max_length);
    basic_tokenizer->tokenize(text, &tokens, max_length);
    for(const auto& token : tokens)
    {
        wordpiece_tokenizer->tokenize(token, output_tokens);

        // early terminate
        if(output_tokens->size() >= max_length)
        {
            break;
        }
    }
}
} // namespace cuBERT
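WordpieceTokenizer::tokenize above is a greedy longest-match-first search: at each position it tries the longest remaining substring first, prefixes non-initial pieces with "##", and falls back to the unknown token when nothing matches. A self-contained sketch of that loop over a toy vocabulary (the vocabulary and variable names are illustrative):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
    // Toy vocabulary mirroring the "unaffable" example in tokenization.h.
    std::unordered_map<std::string, int> vocab = {{"un", 0}, {"##aff", 1}, {"##able", 2}};

    std::string token = "unaffable";
    std::vector<std::string> pieces;

    for(size_t start = 0; start < token.size();)
    {
        bool matched = false;
        for(size_t end = token.size(); start < end; --end) // longest match first
        {
            std::string substr = (start > 0 ? "##" : "") + token.substr(start, end - start);
            if(vocab.count(substr))
            {
                pieces.push_back(substr);
                start   = end;
                matched = true;
                break;
            }
        }
        if(!matched)
        {
            pieces.assign(1, "[UNK]"); // no piece matched: the whole token becomes [UNK]
            break;
        }
    }

    for(const auto& p : pieces)
        std::cout << p << ' '; // prints: un ##aff ##able
    return 0;
}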
Src/Utility/tokenization.h
@@ -6,158 +6,172 @@

#include <unordered_map>
#include <iostream>

namespace cuBERT
{
void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);

/**
 * Checks whether `chars` is a whitespace character.
 * @param c
 * @return
 */
bool _is_whitespace(int c);

/**
 * Checks whether `chars` is a control character.
 * @param c
 * @return
 */
bool _is_control(int c);

/**
 * Checks whether `chars` is a punctuation character.
 * @param cp
 * @return
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer
{
public:
    /**
     * Constructs a BasicTokenizer.
     * @param do_lower_case Whether to lower case the input.
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer& other) = delete;

    virtual ~BasicTokenizer() = default;

    /**
     * Tokenizes a piece of text.
     *
     * to_lower
     * _run_strip_accents      Strips accents from a piece of text.
     * _clean_text             Performs invalid character removal and whitespace cleanup on text.
     * _tokenize_chinese_chars Adds whitespace around any CJK character.
     * _run_split_on_punc      Splits punctuation on a piece of text.
     * whitespace_tokenize     Runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text
     * @param output_tokens
     */
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

private:
    const bool do_lower_case;

    /**
     * Checks whether CP is the codepoint of a CJK character.
     * @param cp
     * @return
     */
    inline static bool _is_chinese_char(int cp);
};

/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer
{
public:
    explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
                                std::string unk_token        = "[UNK]",
                                int max_input_chars_per_word = 200)
        : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
    {
    }

    WordpieceTokenizer(const WordpieceTokenizer& other) = delete;

    virtual ~WordpieceTokenizer() = default;

    /**
     * Tokenizes a piece of text into its word pieces.
     *
     * This uses a greedy longest-match-first algorithm to perform tokenization
     * using the given vocabulary.
     *
     * For example:
     *   input  = "unaffable"
     *   output = ["un", "##aff", "##able"]
     *
     * @param text A single token or whitespace separated tokens. This should have already been
     * passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string& text, std::vector<std::string>* output_tokens);

private:
    const std::unordered_map<std::string, uint64_t>* vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};

/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer
{
public:
    FullTokenizer(const char* vocab_file, bool do_lower_case = true)
    {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer     = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    ~FullTokenizer()
    {
        if(wordpiece_tokenizer != NULL)
        {
            wordpiece_tokenizer = NULL;
        }
        delete wordpiece_tokenizer;

        if(basic_tokenizer != NULL)
        {
            basic_tokenizer = NULL;
        }
        delete basic_tokenizer;

        if(vocab != NULL)
        {
            vocab = NULL;
        }
        delete vocab;
    }

    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    inline uint64_t convert_token_to_id(const std::string& token)
    {
        auto item = vocab->find(token);
        if(item == vocab->end())
        {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        }
        else
        {
            return item->second;
        }
    }

    void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);

private:
    std::unordered_map<std::string, uint64_t>* vocab;
    BasicTokenizer* basic_tokenizer;
    WordpieceTokenizer* wordpiece_tokenizer;
};
} // namespace cuBERT
#endif // CUBERT_TOKENIZATION_H
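One detail the reformat leaves unchanged: ~FullTokenizer assigns NULL to each member before deleting it, so each delete operates on a null pointer and the owned objects leak. A sketch of an ownership-safe variant using std::unique_ptr; this is an observation and a hypothetical alternative, not part of the commit:

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include "tokenization.h" // illustrative include path

// Sketch: the same members as FullTokenizer, owned by unique_ptr so no
// manual delete (and no null-before-delete leak) is needed.
class FullTokenizerOwned
{
public:
    explicit FullTokenizerOwned(const char* vocab_file, bool do_lower_case = true)
        : vocab(new std::unordered_map<std::string, uint64_t>())
    {
        cuBERT::load_vocab(vocab_file, vocab.get());
        basic_tokenizer.reset(new cuBERT::BasicTokenizer(do_lower_case));
        wordpiece_tokenizer.reset(new cuBERT::WordpieceTokenizer(vocab.get()));
    }
    // No user-defined destructor: members are destroyed in reverse declaration
    // order, so the tokenizers go away before the vocab they point into.

private:
    std::unique_ptr<std::unordered_map<std::string, uint64_t>> vocab;
    std::unique_ptr<cuBERT::BasicTokenizer> basic_tokenizer;
    std::unique_ptr<cuBERT::WordpieceTokenizer> wordpiece_tokenizer;
};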
Src/Utility/utf8proc.c (diff collapsed)

Src/Utility/utf8proc.h (diff collapsed)
Src/main.cpp
@@ -12,7 +12,7 @@ int main()

    // Load the GPT2 model
    migraphxSamples::GPT2 gpt2;
    migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
    if(errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);

@@ -25,7 +25,7 @@ int main()
    std::string buf;
    std::vector<std::string> output;
    infile.open("../Resource/vocab_shici.txt");
    while(std::getline(infile, buf))
    {
        output.push_back(buf);
    }

@@ -37,7 +37,7 @@ int main()
    std::vector<std::string> result;
    std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl; // "Start matching verses with GPT2; enter CTRL + Z to exit"
    while(true)
    {
        // Preprocess the input
        std::cout << "question: ";

@@ -45,7 +45,7 @@ int main()
        gpt2.Preprocessing(tokenizer, question, input_id);

        // Inference
        for(int i = 0; i < 50; ++i)
        {
            long unsigned int outputs = gpt2.Inference(input_id);
            if(outputs == 102)

@@ -57,7 +57,7 @@ int main()
        }

        // Map the predicted ids back to characters
        for(int i = 0; i < score.size(); ++i)
        {
            result.push_back(output[score[i]]);
        }

@@ -65,12 +65,12 @@ int main()
        // Print the result
        std::cout << "chatbot: ";
        std::cout << question;
        for(int j = 0; j < result.size(); ++j)
        {
            std::cout << result[j];
        }
        std::cout << std::endl;

        // Clear the buffers
        input_id.clear();
        result.clear();
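main.cpp recovers tokens from predicted ids by indexing the vocab file's lines (output[score[i]]); this works because load_vocab in tokenization.cpp numbers entries in line order. A minimal sketch of that inverse mapping (the helper name is hypothetical):

#include <fstream>
#include <string>
#include <vector>

// Hypothetical helper: reads the vocab file into a vector so that an id
// (its line number) maps directly back to its token, mirroring load_vocab.
std::vector<std::string> LoadIdToToken(const char* vocab_file)
{
    std::vector<std::string> id_to_token;
    std::ifstream infile(vocab_file);
    std::string line;
    while(std::getline(infile, line))
        id_to_token.push_back(line); // id == line index, matching load_vocab's numbering
    return id_to_token;
}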