Commit 13de681d authored by liucong's avatar liucong
Browse files

精简代码

parent d5a26d95
#! /bin/sh
############### Ubuntu ###############
# 参考:https://docs.opencv.org/3.4.11/d7/d9f/tutorial_linux_install.html
# apt-get install build-essential -y
# apt-get install cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev -y
# apt-get install python-dev python-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev libjasper-dev libdc1394-22-dev -y # 处理图像所需的包,可选
############### CentOS ###############
yum install gcc gcc-c++ gtk2-devel gimp-devel gimp-devel-tools gimp-help-browser zlib-devel libtiff-devel libjpeg-devel libpng-devel gstreamer-devel libavc1394-devel libraw1394-devel libdc1394-devel jasper-devel jasper-utils swig python libtool nasm -y
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.5)
# 设置项目名 # 设置项目名
project(MIGraphX_Samples) project(Bert)
# 设置编译器 # 设置编译器
set(CMAKE_CXX_COMPILER g++) set(CMAKE_CXX_COMPILER g++)
...@@ -12,7 +12,6 @@ set(CMAKE_BUILD_TYPE release) ...@@ -12,7 +12,6 @@ set(CMAKE_BUILD_TYPE release)
# 添加头文件路径 # 添加头文件路径
set(INCLUDE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/Src/ set(INCLUDE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/Src/
${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/ ${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/
${CMAKE_CURRENT_SOURCE_DIR}/Src/NLP/Bert/
$ENV{DTKROOT}/include/ $ENV{DTKROOT}/include/
${CMAKE_CURRENT_SOURCE_DIR}/depend/include/) ${CMAKE_CURRENT_SOURCE_DIR}/depend/include/)
include_directories(${INCLUDE_PATH}) include_directories(${INCLUDE_PATH})
...@@ -23,11 +22,7 @@ set(LIBRARY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/depend/lib64/ ...@@ -23,11 +22,7 @@ set(LIBRARY_PATH ${CMAKE_CURRENT_SOURCE_DIR}/depend/lib64/
link_directories(${LIBRARY_PATH}) link_directories(${LIBRARY_PATH})
# 添加依赖库 # 添加依赖库
set(LIBRARY opencv_core set(LIBRARY migraphx_ref
opencv_imgproc
opencv_imgcodecs
opencv_dnn
migraphx_ref
migraphx migraphx
migraphx_c migraphx_c
migraphx_device migraphx_device
...@@ -37,12 +32,11 @@ link_libraries(${LIBRARY}) ...@@ -37,12 +32,11 @@ link_libraries(${LIBRARY})
# 添加源文件 # 添加源文件
set(SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/Src/main.cpp set(SOURCE_FILES ${CMAKE_CURRENT_SOURCE_DIR}/Src/main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Src/Sample.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Src/Bert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Src/NLP/Bert/Bert.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/tokenization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Src/NLP/Bert/tokenization.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/utf8proc.c
${CMAKE_CURRENT_SOURCE_DIR}/Src/NLP/Bert/utf8proc.c
${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/CommonUtility.cpp ${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/CommonUtility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/Filesystem.cpp) ${CMAKE_CURRENT_SOURCE_DIR}/Src/Utility/Filesystem.cpp)
# 添加可执行目标 # 添加可执行目标
add_executable(MIGraphX_Samples ${SOURCE_FILES}) add_executable(Bert ${SOURCE_FILES})
Doc/Images/Bert_06.png

26.7 KB | W: | H:

Doc/Images/Bert_06.png

17.9 KB | W: | H:

Doc/Images/Bert_06.png
Doc/Images/Bert_06.png
Doc/Images/Bert_06.png
Doc/Images/Bert_06.png
  • 2-up
  • Swipe
  • Onion skin
# Bert # Bert
本示例主要通过Bert模型说明如何使用MIGraphX C++ API进行自然语言处理模型的推理,包括参数设置、数据准备、预处理、模型推理以及数据后处理。 本示例主要通过Bert模型说明如何使用MIGraphX C++ API进行自然语言处理模型的推理,包括数据准备、预处理、模型推理以及数据后处理。
## 模型简介 ## 模型简介
自然语言处理(Natural Language Processing,NLP )是能够实现人与计算机之间用自然语言进行有效沟通的理论和方法,是计算机科学领域与人工智能领域中的一个重要方向。本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 下载bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource/Models/NLP/Bert文件夹下。整体模型结构如下图所示,也可以通过netron工具:https://netron.app/ 查看Bert模型结构。 自然语言处理(Natural Language Processing,NLP )是能够实现人与计算机之间用自然语言进行有效沟通的理论和方法,是计算机科学领域与人工智能领域中的一个重要方向。本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 下载bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource/文件夹下。整体模型结构如下图所示,也可以通过netron工具:https://netron.app/ 查看Bert模型结构。
<img src="../Images/Bert_01.png" style="zoom:100%;" align=middle> <img src="./Images/Bert_01.png" style="zoom:100%;" align=middle>
问题回答任务是指输入一段上下文文本的描述和一个问题,模型从给定的文本中预测出答案。例如: 问题回答任务是指输入一段上下文文本的描述和一个问题,模型从给定的文本中预测出答案。例如:
...@@ -16,17 +16,6 @@ ...@@ -16,17 +16,6 @@
3.答案:Li Ming 3.答案:Li Ming
``` ```
## 参数设置
在samples工程中的Resource/Configuration.xml文件的Bert节点表示Bert模型的参数,主要设置模型的读取路径。
```xml
<!--Bert-->
<Bert>
<ModelPath>"../Resource/Models/NLP/bertsquad-10.onnx"</ModelPath>
</Bert>
```
## 数据准备 ## 数据准备
在自然语言处理领域中,首先需要准备文本数据,如下所示,通常需要提供问题(question)和上下文文本(context),自己可以根据需求准备相应的问题和上下文文本作为输入数据,进行模型推理。 在自然语言处理领域中,首先需要准备文本数据,如下所示,通常需要提供问题(question)和上下文文本(context),自己可以根据需求准备相应的问题和上下文文本作为输入数据,进行模型推理。
...@@ -52,7 +41,7 @@ ...@@ -52,7 +41,7 @@
如下图所示,为滑动窗口的具体操作: 如下图所示,为滑动窗口的具体操作:
<img src="../Images/Bert_03.png" style="zoom:80%;" align=middle> <img src="./Images/Bert_03.png" style="zoom:80%;" align=middle>
从图中可以看出,通过指定窗口大小为256,进行滑动窗口处理可以将上下文文本分成多个子文本,用于后续的数据拼接。 从图中可以看出,通过指定窗口大小为256,进行滑动窗口处理可以将上下文文本分成多个子文本,用于后续的数据拼接。
...@@ -106,7 +95,7 @@ ErrorCode Bert::Preprocessing(...) ...@@ -106,7 +95,7 @@ ErrorCode Bert::Preprocessing(...)
### 数据拼接 ### 数据拼接
当获得指定的问题和上下文文本时,对问题和上下文文本进行拼接操作,具体过程如下图所示: 当获得指定的问题和上下文文本时,对问题和上下文文本进行拼接操作,具体过程如下图所示:
<img src="../Images/Bert_02.png" style="zoom:80%;" align=middle> <img src="./Images/Bert_02.png" style="zoom:80%;" align=middle>
从图中可以看出,是将问题和上下文文本拼接成一个序列,问题和上下文文本用[SEP]符号隔开,完成数据拼接后再输入到模型中进行特征提取。其中,“[CLS]”是一个分类标志,表示后面的内容属于问题文本,“[SEP]”字符是一个分割标志,用来将问题和上下文文本分开。 从图中可以看出,是将问题和上下文文本拼接成一个序列,问题和上下文文本用[SEP]符号隔开,完成数据拼接后再输入到模型中进行特征提取。其中,“[CLS]”是一个分类标志,表示后面的内容属于问题文本,“[SEP]”字符是一个分割标志,用来将问题和上下文文本分开。
...@@ -207,7 +196,7 @@ ErrorCode Bert::Inference(...) ...@@ -207,7 +196,7 @@ ErrorCode Bert::Inference(...)
获得模型的推理结果后,并不能直接作为问题回答任务的结果显示,如下图所示,还需要进一步数据处理,得到最终的预测结果。 获得模型的推理结果后,并不能直接作为问题回答任务的结果显示,如下图所示,还需要进一步数据处理,得到最终的预测结果。
<img src="../Images/Bert_04.png" style="zoom:80%;" align=middle> <img src="./Images/Bert_04.png" style="zoom:80%;" align=middle>
从图中可以看出,数据后处理主要包含如下操作: 从图中可以看出,数据后处理主要包含如下操作:
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
## 模型简介 ## 模型简介
自然语言处理(Natural Language Processing,NLP )是能够实现人与计算机之间用自然语言进行有效沟通的理论和方法,是计算机科学领域与人工智能领域中的一个重要方向。本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource/Models/NLP/Bert文件夹下。整体模型结构如下图所示,也可以通过netron工具:https://netron.app/ 查看Bert模型结构。 自然语言处理(Natural Language Processing,NLP )是能够实现人与计算机之间用自然语言进行有效沟通的理论和方法,是计算机科学领域与人工智能领域中的一个重要方向。本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource/文件夹下。整体模型结构如下图所示,也可以通过netron工具:https://netron.app/ 查看Bert模型结构。
<img src="../Images/Bert_01.png" style="zoom:100%;" align=middle> <img src="./Images/Bert_01.png" style="zoom:100%;" align=middle>
问题回答任务是指输入一段上下文文本的描述和一个问题,模型从上下文文本中预测出答案。例如: 问题回答任务是指输入一段上下文文本的描述和一个问题,模型从上下文文本中预测出答案。例如:
...@@ -107,7 +107,7 @@ def read_squad_examples(input_file): ...@@ -107,7 +107,7 @@ def read_squad_examples(input_file):
1.滑动窗口操作,如下图所示,当问题加上下文文本超过256个字符时,采取滑动窗口的方法构建输入序列。 1.滑动窗口操作,如下图所示,当问题加上下文文本超过256个字符时,采取滑动窗口的方法构建输入序列。
<img src="../Images/Bert_03.png" style="zoom:80%;" align=middle> <img src="./Images/Bert_03.png" style="zoom:80%;" align=middle>
从图中可以看出,问题部分不参与滑动处理,只将上下文文本进行滑动窗口操作,裁切得到多个子文本,用于后续的数据拼接。 从图中可以看出,问题部分不参与滑动处理,只将上下文文本进行滑动窗口操作,裁切得到多个子文本,用于后续的数据拼接。
...@@ -198,7 +198,7 @@ for idx in range(0, n): ...@@ -198,7 +198,7 @@ for idx in range(0, n):
获得推理结果后,并不能直接作为问题回答任务的结果显示,如下图所示,还需要进一步数据处理,得到最终的预测结果。 获得推理结果后,并不能直接作为问题回答任务的结果显示,如下图所示,还需要进一步数据处理,得到最终的预测结果。
<img src="../Images/Bert_04.png" style="zoom:80%;" align=middle> <img src="./Images/Bert_04.png" style="zoom:80%;" align=middle>
从图中可以看出,数据后处理主要包含如下操作: 从图中可以看出,数据后处理主要包含如下操作:
......
...@@ -10,7 +10,7 @@ RawResult = collections.namedtuple("RawResult", ...@@ -10,7 +10,7 @@ RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"]) ["unique_id", "start_logits", "end_logits"])
# 数据前处理 # 数据前处理
input_file = '../../../Resource/Models/NLP/Bert/inputs_data.json' input_file = '../Resource/inputs_data.json'
# 使用run_onnx_squad中的read_squad_examples方法读取输入文件,进行数据处理,将文本拆分成一个个单词 # 使用run_onnx_squad中的read_squad_examples方法读取输入文件,进行数据处理,将文本拆分成一个个单词
eval_examples = read_squad_examples(input_file) eval_examples = read_squad_examples(input_file)
...@@ -23,7 +23,7 @@ n_best_size = 20 # 预选数量 ...@@ -23,7 +23,7 @@ n_best_size = 20 # 预选数量
max_answer_length = 30 # 问题的最大长度 max_answer_length = 30 # 问题的最大长度
# 分词工具 # 分词工具
vocab_file = os.path.join('../../../Resource/Models/NLP/Bert/uncased_L-12_H-768_A-12', 'vocab.txt') vocab_file = os.path.join('../Resource/uncased_L-12_H-768_A-12', 'vocab.txt')
tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file) tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)
# 使用run_onnx_squad中的convert_examples_to_features方法从输入中获取参数 # 使用run_onnx_squad中的convert_examples_to_features方法从输入中获取参数
...@@ -31,7 +31,7 @@ input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(ev ...@@ -31,7 +31,7 @@ input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(ev
# 编译 # 编译
print("INFO: Parsing and compiling the model...") print("INFO: Parsing and compiling the model...")
model = migraphx.parse_onnx("../../../Resource/Models/NLP/Bert/bertsquad-10.onnx") model = migraphx.parse_onnx("../Resource/bertsquad-10.onnx")
model.compile(migraphx.get_target("gpu"),device_id=0) model.compile(migraphx.get_target("gpu"),device_id=0)
n = len(input_ids) n = len(input_ids)
......
...@@ -6,7 +6,13 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一 ...@@ -6,7 +6,13 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一
## 模型结构 ## 模型结构
以往的预训练模型的结构会受到单向语言模型(从左到右或者从右到左)的限制,因而也限制了模型的表征能力,使其只能获取单方向的上下文信息。而BERT利用MLM进行预训练并且采用深层的双向Transformer组件(单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token)来构建整个模型,因此最终生成能融合左右上下文信息的深层双向语言表征。 以往的预训练模型的结构会受到单向语言模型(从左到右或者从右到左)的限制,因而也限制了模型的表征能力,使其只能获取单方向的上下文信息。而BERT利用MLM进行预训练并且采用深层的双向Transformer组件(单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token)来构建整个模型,因此最终生成能融合左右上下文信息的深层双向语言表征。
## 构建安装 ## python版本推理
下面介绍如何运行python代码示例,具体推理代码解析,在Doc/Tutorial_Python.md中有详细说明。
本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource/文件夹下。
### 构建安装
在光源可拉取推理的docker镜像,BERT模型推理的镜像如下: 在光源可拉取推理的docker镜像,BERT模型推理的镜像如下:
...@@ -14,13 +20,41 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一 ...@@ -14,13 +20,41 @@ BERT的全称为Bidirectional Encoder Representation from Transformers,是一
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort1.14.0_migraphx3.0.0-dtk22.10.1 docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort1.14.0_migraphx3.0.0-dtk22.10.1
``` ```
### 安装Opencv依赖 ### 推理示例
1.参考《MIGraphX教程》设置好PYTHONPATH
2.安装依赖:
```python
# 进入migraphx samples工程根目录
cd <path_to_migraphx_samples>
# 进入示例程序目录
cd ./Python/
# 安装依赖
pip install -r requirements.txt
```
3.在Python/NLP/Bert目录下执行如下命令运行该示例程序:
```python ```python
cd <path_to_migraphx_samples> python bert.py
sh ./3rdParty/InstallOpenCVDependences.sh
``` ```
输出结果为:
<img src="./Doc/Images/Bert_05.png" style="zoom:90%;" align=middle>
输出结果中,问题id对应预测概率值最大的答案。
## C++版本推理
下面介绍如何运行C++代码示例,具体推理代码解析,在Doc/Tutorial_Cpp.md目录中有详细说明。
参考Python版本推理中的构建安装,在光源中拉取推理的docker镜像。
### 修改CMakeLists.txt ### 修改CMakeLists.txt
- 如果使用ubuntu系统,需要修改CMakeLists.txt中依赖库路径: - 如果使用ubuntu系统,需要修改CMakeLists.txt中依赖库路径:
...@@ -29,7 +63,7 @@ sh ./3rdParty/InstallOpenCVDependences.sh ...@@ -29,7 +63,7 @@ sh ./3rdParty/InstallOpenCVDependences.sh
- **MIGraphX2.3.0及以上版本需要c++17** - **MIGraphX2.3.0及以上版本需要c++17**
### 安装OpenCV并构建工程 ### 构建工程
``` ```
rbuild build -d depend rbuild build -d depend
...@@ -57,55 +91,22 @@ export LD_LIBRARY_PATH=<path_to_migraphx_samples>/depend/lib/:$LD_LIBRARY_PATH ...@@ -57,55 +91,22 @@ export LD_LIBRARY_PATH=<path_to_migraphx_samples>/depend/lib/:$LD_LIBRARY_PATH
source ~/.bashrc source ~/.bashrc
``` ```
## 推理 ### 推理示例
本次采用经典的Bert模型完成问题回答任务,模型和分词文件下载链接:https://pan.baidu.com/s/1yc30IzM4ocOpTpfFuUMR0w, 提取码:8f1a, 将bertsquad-10.onnx文件和uncased_L-12_H-768_A-12分词文件保存在Resource\Models\NLP\Bert文件夹下。下面介绍如何运行python代码和C++代码示例,具体推理代码解析,在Doc目录中有详细说明。 运行Bert示例程序,具体执行如下命令:
### python版本推理
1.参考《MIGraphX教程》中的安装方法安装MIGraphX并设置好PYTHONPATH
2.安装依赖:
```python ```python
# 进入migraphx samples工程根目录 # 进入migraphx samples工程根目录
cd <path_to_migraphx_samples> cd <path_to_migraphx_samples>
# 进入示例程序目录 # 进入build目录
cd Python/NLP/Bert
# 安装依赖
pip install -r requirements.txt
```
3.在Python/NLP/Bert目录下执行如下命令运行该示例程序:
```python
python bert.py
```
输出结果为:
<img src="./Doc/Images/Bert_05.png" style="zoom:90%;" align=middle>
输出结果中,问题id对应预测概率值最大的答案。
### C++版本推理
切换到build目录中,执行如下命令:
```python
cd ./build/ cd ./build/
./MIGraphX_Samples
```
根据提示选择运行BERT模型的示例程序
```python # 执行示例程序
./MIGraphX_Samples 0 ./Bert
``` ```
如下所示,会在当前界面中提示输入问题,根据问题得到预测答案 如下所示,会在当前界面中提示输入问题,根据问题得到预测答案
<img src="./Doc/Images/Bert_06.png" style="zoom:100%;" align=middle> <img src="./Doc/Images/Bert_06.png" style="zoom:100%;" align=middle>
......
<?xml version="1.0" encoding="GB2312"?>
<opencv_storage>
<!--Bert-->
<Bert>
<ModelPath>"../Resource/Models/NLP/Bert/bertsquad-10.onnx"</ModelPath>
</Bert>
</opencv_storage>
#include <fstream> #include <Bert.h>
#include <sstream>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <CommonUtility.h>
#include <Filesystem.h> #include <Filesystem.h>
#include <SimpleLog.h> #include <SimpleLog.h>
#include <algorithm> #include <algorithm>
#include <string>
#include <vector>
#include <stdexcept> #include <stdexcept>
#include <Bert.h>
#include <tokenization.h> #include <tokenization.h>
namespace migraphxSamples namespace migraphxSamples
{ {
Bert::Bert():logFile(NULL) Bert::Bert()
{ {
} }
Bert::~Bert() Bert::~Bert()
{ {
configurationFile.release();
} }
ErrorCode Bert::Initialize(InitializationParameterOfNLP initParamOfNLPBert) ErrorCode Bert::Initialize()
{ {
// 初始化(获取日志文件,加载配置文件等)
ErrorCode errorCode=DoCommonInitialization(initParamOfNLPBert);
if(errorCode!=SUCCESS)
{
LOG_ERROR(logFile,"fail to DoCommonInitialization\n");
return errorCode;
}
LOG_INFO(logFile,"succeed to DoCommonInitialization\n");
// 获取配置文件参数 // 获取模型文件
FileNode netNode = configurationFile["Bert"]; std::string modelPath="../Resource/bertsquad-10.onnx";
std::string modelPath=initializationParameter.parentPath+(std::string)netNode["ModelPath"];
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(Exists(modelPath)==false)
{ {
LOG_ERROR(logFile,"%s not exist!\n",modelPath.c_str()); LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST; return MODEL_NOT_EXIST;
} }
net = migraphx::parse_onnx(modelPath); net = migraphx::parse_onnx(modelPath);
LOG_INFO(logFile,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入属性 // 获取模型输入属性
std::unordered_map<std::string, migraphx::shape> input = net.get_parameter_shapes(); std::unordered_map<std::string, migraphx::shape> input = net.get_parameter_shapes();
...@@ -72,55 +57,21 @@ ErrorCode Bert::Initialize(InitializationParameterOfNLP initParamOfNLPBert) ...@@ -72,55 +57,21 @@ ErrorCode Bert::Initialize(InitializationParameterOfNLP initParamOfNLPBert)
// 编译模型 // 编译模型
migraphx::compile_options options; migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备 options.device_id=0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; // 设置offload_copy options.offload_copy=true;
net.compile(gpuTarget,options); net.compile(gpuTarget,options);
LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
// Run once by itself // warm up
migraphx::parameter_map inputData; std::unordered_map<std::string, migraphx::argument> inputData;
inputData[inputName1]=migraphx::generate_argument(inputShape1); inputData[inputName1]=migraphx::argument(inputShape1);
inputData[inputName2]=migraphx::generate_argument(inputShape2); inputData[inputName2]=migraphx::argument(inputShape2);
inputData[inputName3]=migraphx::generate_argument(inputShape3); inputData[inputName3]=migraphx::argument(inputShape3);
inputData[inputName4]=migraphx::generate_argument(inputShape4); inputData[inputName4]=migraphx::argument(inputShape4);
net.eval(inputData); net.eval(inputData);
return SUCCESS; return SUCCESS;
} }
ErrorCode Bert::DoCommonInitialization(InitializationParameterOfNLP initParamOfNLPBert)
{
initializationParameter = initParamOfNLPBert;
// 获取日志文件
logFile=LogManager::GetInstance()->GetLogFile(initializationParameter.logName);
// 加载配置文件
std::string configFilePath=initializationParameter.configFilePath;
if(!Exists(configFilePath))
{
LOG_ERROR(logFile, "no configuration file!\n");
return CONFIG_FILE_NOT_EXIST;
}
if(!configurationFile.open(configFilePath, FileStorage::READ))
{
LOG_ERROR(logFile, "fail to open configuration file\n");
return FAIL_TO_OPEN_CONFIG_FILE;
}
LOG_INFO(logFile, "succeed to open configuration file\n");
// 修改父路径
std::string &parentPath = initializationParameter.parentPath;
if (!parentPath.empty())
{
if(!IsPathSeparator(parentPath[parentPath.size() - 1]))
{
parentPath+=PATH_SEPARATOR;
}
}
return SUCCESS;
}
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>> &input_masks,
const std::vector<std::vector<long unsigned int>> &segment_ids, const std::vector<std::vector<long unsigned int>> &segment_ids,
...@@ -144,7 +95,7 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -144,7 +95,7 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
} }
} }
migraphx::parameter_map inputData; std::unordered_map<std::string, migraphx::argument> inputData;
std::vector<migraphx::argument> results; std::vector<migraphx::argument> results;
migraphx::argument start_prediction; migraphx::argument start_prediction;
migraphx::argument end_prediction; migraphx::argument end_prediction;
...@@ -153,7 +104,7 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp ...@@ -153,7 +104,7 @@ ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>> &inp
for(int i=0;i<input_ids.size();++i) for(int i=0;i<input_ids.size();++i)
{ {
// 输入数据 // 创建输入数据
inputData[inputName1]=migraphx::argument{inputShape1, (long unsigned int*)position_id[i]}; inputData[inputName1]=migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
inputData[inputName2]=migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]}; inputData[inputName2]=migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
inputData[inputName3]=migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]}; inputData[inputName3]=migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <CommonDefinition.h> #include <CommonDefinition.h>
#include <tokenization.h> #include <tokenization.h>
using namespace cuBERT;
namespace migraphxSamples namespace migraphxSamples
{ {
...@@ -31,7 +30,7 @@ public: ...@@ -31,7 +30,7 @@ public:
~Bert(); ~Bert();
ErrorCode Initialize(InitializationParameterOfNLP initParamOfNLPBert); ErrorCode Initialize();
ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids, ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids,
const std::vector<std::vector<long unsigned int>> &input_masks, const std::vector<std::vector<long unsigned int>> &input_masks,
...@@ -55,13 +54,6 @@ public: ...@@ -55,13 +54,6 @@ public:
std::string &answer); std::string &answer);
private: private:
ErrorCode DoCommonInitialization(InitializationParameterOfNLP initParamOfNLPBert);
private:
FILE *logFile;
cv::FileStorage configurationFile;
InitializationParameterOfNLP initializationParameter;
std::vector<std::string> tokens_text; std::vector<std::string> tokens_text;
std::vector<std::string> tokens_question; std::vector<std::string> tokens_question;
......
#include <Sample.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <Bert.h>
#include <tokenization.h>
#include <fstream>
using namespace std;
using namespace migraphx;
using namespace migraphxSamples;
void Sample_Bert()
{
// 加载Bert模型
Bert bert;
InitializationParameterOfNLP initParamOfNLPBert;
initParamOfNLPBert.parentPath = "";
initParamOfNLPBert.configFilePath = CONFIG_FILE;
initParamOfNLPBert.logName = "";
ErrorCode errorCode = bert.Initialize(initParamOfNLPBert);
if (errorCode != SUCCESS)
{
LOG_ERROR(stdout, "fail to initialize Bert!\n");
exit(-1);
}
LOG_INFO(stdout, "succeed to initialize Bert\n");
int max_seq_length = 256; // 滑动窗口的长度
int max_query_length = 64; // 问题的最大长度
int batch_size = 1; // batch_size值
int n_best_size = 20; // 索引数量
int max_answer_length = 30; // 答案的最大长度
// 上下文文本数据
const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." };
char question[100];
std::vector<std::vector<long unsigned int>> input_ids;
std::vector<std::vector<long unsigned int>> input_masks;
std::vector<std::vector<long unsigned int>> segment_ids;
std::vector<float> start_position;
std::vector<float> end_position;
std::string answer = {};
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/Models/NLP/Bert/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
while (true)
{
// 数据前处理
std::cout << "question: ";
cin.getline(question, 100);
bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids);
// 推理
double time1 = getTickCount();
bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
double time2 = getTickCount();
double elapsedTime = (time2 - time1) * 1000 / getTickFrequency();
// 数据后处理
bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer);
// 打印输出预测结果
std::cout << "answer: " << answer << std::endl;
LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
// 清除数据
input_ids.clear();
input_masks.clear();
segment_ids.clear();
start_position.clear();
end_position.clear();
answer = {};
}
}
\ No newline at end of file
// 示例程序
#ifndef __SAMPLE_H__
#define __SAMPLE_H__
// Bert sample
void Sample_Bert();
#endif
\ No newline at end of file
// 常用数据类型和宏定义 // 常用定义
#ifndef __COMMON_DEFINITION_H__ #ifndef __COMMON_DEFINITION_H__
#define __COMMON_DEFINITION_H__ #define __COMMON_DEFINITION_H__
#include <string>
#include <opencv2/opencv.hpp> #include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
namespace migraphxSamples namespace migraphxSamples
{ {
...@@ -21,20 +17,7 @@ namespace migraphxSamples ...@@ -21,20 +17,7 @@ namespace migraphxSamples
#define CONFIG_FILE "../Resource/Configuration.xml" #define CONFIG_FILE "../Resource/Configuration.xml"
typedef struct __Time typedef enum _ErrorCode
{
string year;
string month;
string day;
string hour;
string minute;
string second;
string millisecond; // ms
string microsecond; // us
string weekDay;
}_Time;
typedef enum _ErrorCode
{ {
SUCCESS=0, // 0 SUCCESS=0, // 0
MODEL_NOT_EXIST, // 模型不存在 MODEL_NOT_EXIST, // 模型不存在
...@@ -44,7 +27,7 @@ typedef enum _ErrorCode ...@@ -44,7 +27,7 @@ typedef enum _ErrorCode
IMAGE_ERROR, // 图像错误 IMAGE_ERROR, // 图像错误
}ErrorCode; }ErrorCode;
typedef struct _ResultOfPrediction typedef struct _ResultOfPrediction
{ {
float confidence; float confidence;
int label; int label;
...@@ -52,24 +35,22 @@ typedef struct _ResultOfPrediction ...@@ -52,24 +35,22 @@ typedef struct _ResultOfPrediction
}ResultOfPrediction; }ResultOfPrediction;
typedef struct _ResultOfDetection typedef struct _ResultOfDetection
{ {
Rect boundingBox; cv::Rect boundingBox;
float confidence; float confidence;
int classID; int classID;
string className; std::string className;
bool exist; bool exist;
_ResultOfDetection():confidence(0.0f),classID(0),exist(true){} _ResultOfDetection():confidence(0.0f),classID(0),exist(true){}
}ResultOfDetection; }ResultOfDetection;
typedef struct _InitializationParameterOfNLP typedef struct _InitializationParameterOfNLP
{ {
std::string parentPath; std::string parentPath;
std::string configFilePath; std::string configFilePath;
cv::Size inputSize;
std::string logName;
}InitializationParameterOfNLP; }InitializationParameterOfNLP;
} }
......
#include <CommonUtility.h> #include <CommonUtility.h>
#include <assert.h>
#include <ctype.h>
#include <time.h>
#include <stdlib.h>
#include <algorithm>
#include <sstream>
#include <vector>
#ifdef _WIN32
#include <io.h>
#include <direct.h>
#include <Windows.h>
#else
#include <unistd.h>
#include <dirent.h>
#include <sys/stat.h>
#include <sys/time.h>
#endif
#include <SimpleLog.h>
namespace migraphxSamples namespace migraphxSamples
{ {
_Time GetCurrentTime3()
{
_Time currentTime;
#if (defined WIN32 || defined _WIN32)
SYSTEMTIME systemTime;
GetLocalTime(&systemTime);
char temp[8] = { 0 };
sprintf(temp, "%04d", systemTime.wYear);
currentTime.year=string(temp);
sprintf(temp, "%02d", systemTime.wMonth);
currentTime.month=string(temp);
sprintf(temp, "%02d", systemTime.wDay);
currentTime.day=string(temp);
sprintf(temp, "%02d", systemTime.wHour);
currentTime.hour=string(temp);
sprintf(temp, "%02d", systemTime.wMinute);
currentTime.minute=string(temp);
sprintf(temp, "%02d", systemTime.wSecond);
currentTime.second=string(temp);
sprintf(temp, "%03d", systemTime.wMilliseconds);
currentTime.millisecond=string(temp);
sprintf(temp, "%d", systemTime.wDayOfWeek);
currentTime.weekDay=string(temp);
#else
struct timeval tv;
struct tm *p;
gettimeofday(&tv, NULL);
p = localtime(&tv.tv_sec);
char temp[8]={0};
sprintf(temp,"%04d",1900+p->tm_year);
currentTime.year=string(temp);
sprintf(temp,"%02d",1+p->tm_mon);
currentTime.month=string(temp);
sprintf(temp,"%02d",p->tm_mday);
currentTime.day=string(temp);
sprintf(temp,"%02d",p->tm_hour);
currentTime.hour=string(temp);
sprintf(temp,"%02d",p->tm_min);
currentTime.minute=string(temp);
sprintf(temp,"%02d",p->tm_sec);
currentTime.second=string(temp);
sprintf(temp,"%03d",tv.tv_usec/1000);
currentTime.millisecond = string(temp);
sprintf(temp, "%03d", tv.tv_usec % 1000);
currentTime.microsecond = string(temp);
sprintf(temp, "%d", p->tm_wday);
currentTime.weekDay = string(temp);
#endif
return currentTime;
}
std::vector<std::string> SplitString(std::string str, std::string separator)
{
std::string::size_type pos;
std::vector<std::string> result;
str+=separator;//扩展字符串以方便操作
int size=str.size();
for(int i=0; i<size; i++)
{
pos=str.find(separator,i);
if(pos<size)
{
std::string s=str.substr(i,pos-i);
result.push_back(s);
i=pos+separator.size()-1;
}
}
return result;
}
bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R) bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R)
{ {
return L.confidence > R.confidence; return L.confidence > R.confidence;
...@@ -109,7 +13,7 @@ bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R) ...@@ -109,7 +13,7 @@ bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R)
return L.boundingBox.area() > R.boundingBox.area(); return L.boundingBox.area() > R.boundingBox.area();
} }
void NMS(vector<ResultOfDetection> &detections, float IOUThreshold) void NMS(std::vector<ResultOfDetection> &detections, float IOUThreshold)
{ {
// sort // sort
std::sort(detections.begin(), detections.end(), CompareConfidence); std::sort(detections.begin(), detections.end(), CompareConfidence);
......
...@@ -3,23 +3,16 @@ ...@@ -3,23 +3,16 @@
#ifndef __COMMON_UTILITY_H__ #ifndef __COMMON_UTILITY_H__
#define __COMMON_UTILITY_H__ #define __COMMON_UTILITY_H__
#include <mutex>
#include <string>
#include <vector>
#include <CommonDefinition.h> #include <CommonDefinition.h>
using namespace std;
namespace migraphxSamples namespace migraphxSamples
{ {
// 分割字符串
std::vector<std::string> SplitString(std::string str,std::string separator);
// 排序规则: 按照置信度或者按照面积排序 // 排序规则: 按照置信度或者按照面积排序
bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R); bool CompareConfidence(const ResultOfDetection &L,const ResultOfDetection &R);
bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R); bool CompareArea(const ResultOfDetection &L,const ResultOfDetection &R);
// 非极大抑制
void NMS(std::vector<ResultOfDetection> &detections, float IOUThreshold); void NMS(std::vector<ResultOfDetection> &detections, float IOUThreshold);
} }
......
...@@ -11,12 +11,7 @@ ...@@ -11,12 +11,7 @@
#include <unistd.h> #include <unistd.h>
#include <dirent.h> #include <dirent.h>
#endif #endif
#include <CommonUtility.h>
#include <opencv2/opencv.hpp>
#include <SimpleLog.h>
using namespace cv;
// 路径分隔符(Linux:‘/’,Windows:’\\’) // 路径分隔符(Linux:‘/’,Windows:’\\’)
#ifdef _WIN32 #ifdef _WIN32
#define PATH_SEPARATOR '\\' #define PATH_SEPARATOR '\\'
...@@ -24,9 +19,31 @@ using namespace cv; ...@@ -24,9 +19,31 @@ using namespace cv;
#define PATH_SEPARATOR '/' #define PATH_SEPARATOR '/'
#endif #endif
using namespace std;
namespace migraphxSamples namespace migraphxSamples
{ {
static std::vector<std::string> SplitString(std::string str, std::string separator)
{
std::string::size_type pos;
std::vector<std::string> result;
str+=separator;//扩展字符串以方便操作
int size=str.size();
for(int i=0; i<size; i++)
{
pos=str.find(separator,i);
if(pos<size)
{
std::string s=str.substr(i,pos-i);
result.push_back(s);
i=pos+separator.size()-1;
}
}
return result;
}
#if defined _WIN32 || defined WINCE #if defined _WIN32 || defined WINCE
const char dir_separators[] = "/\\"; const char dir_separators[] = "/\\";
...@@ -293,7 +310,7 @@ namespace migraphxSamples ...@@ -293,7 +310,7 @@ namespace migraphxSamples
} }
else else
{ {
LOG_INFO(stdout, "could not open directory: %s", directory.c_str()); printf("could not open directory: %s", directory.c_str());
} }
} }
...@@ -390,7 +407,7 @@ namespace migraphxSamples ...@@ -390,7 +407,7 @@ namespace migraphxSamples
#endif #endif
if (!result) if (!result)
{ {
LOG_INFO(stdout, "can't remove directory: %s\n", path.c_str()); printf("can't remove directory: %s\n", path.c_str());
} }
} }
else else
...@@ -402,7 +419,7 @@ namespace migraphxSamples ...@@ -402,7 +419,7 @@ namespace migraphxSamples
#endif #endif
if (!result) if (!result)
{ {
LOG_INFO(stdout, "can't remove file: %s\n", path.c_str()); printf("can't remove file: %s\n", path.c_str());
} }
} }
} }
...@@ -438,7 +455,7 @@ namespace migraphxSamples ...@@ -438,7 +455,7 @@ namespace migraphxSamples
{ {
RemoveAll(path); RemoveAll(path);
++numberOfFiles; ++numberOfFiles;
LOG_INFO(stdout, "%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles); printf("%s deleted! number of deleted files:%d\n", path.c_str(), numberOfFiles);
} }
} }
...@@ -452,7 +469,7 @@ namespace migraphxSamples ...@@ -452,7 +469,7 @@ namespace migraphxSamples
} }
else else
{ {
LOG_INFO(stdout, "could not open directory: %s", directory.c_str()); printf("could not open directory: %s", directory.c_str());
} }
// ����RemoveAllɾ��Ŀ¼ // ����RemoveAllɾ��Ŀ¼
...@@ -592,17 +609,17 @@ namespace migraphxSamples ...@@ -592,17 +609,17 @@ namespace migraphxSamples
if(!srcFile.is_open()) if(!srcFile.is_open())
{ {
LOG_ERROR(stdout,"can not open %s\n",srcPath.c_str()); printf("can not open %s\n",srcPath.c_str());
return false; return false;
} }
if(!dstFile.is_open()) if(!dstFile.is_open())
{ {
LOG_ERROR(stdout, "can not open %s\n", dstPath.c_str()); printf("can not open %s\n", dstPath.c_str());
return false; return false;
} }
if(srcPath==dstPath) if(srcPath==dstPath)
{ {
LOG_ERROR(stdout, "src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
} }
char buffer[2048]; char buffer[2048];
...@@ -622,7 +639,7 @@ namespace migraphxSamples ...@@ -622,7 +639,7 @@ namespace migraphxSamples
{ {
if(srcPath==dstPath) if(srcPath==dstPath)
{ {
LOG_ERROR(stdout, "src can not be same with dst\n"); printf("src can not be same with dst\n");
return false; return false;
} }
...@@ -662,9 +679,9 @@ namespace migraphxSamples ...@@ -662,9 +679,9 @@ namespace migraphxSamples
// process // process
double process = (1.0*(i + 1) / fileNameList.size()) * 100; double process = (1.0*(i + 1) / fileNameList.size()) * 100;
LOG_INFO(stdout, "%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process); printf("%s done! %f% \n", GetFileName(fileNameList[i]).c_str(), process);
} }
LOG_INFO(stdout, "all done!(the number of files:%d)\n", fileNameList.size()); printf("all done!(the number of files:%d)\n", fileNameList.size());
return true; return true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment