bert.py 3.33 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import json
import os.path
import tokenizers
import collections
from run_onnx_squad import read_squad_examples, write_predictions, convert_examples_to_features
import onnxruntime as ort

# Container for one feature's raw model outputs; consumed by write_predictions.
RawResult = collections.namedtuple(
    "RawResult", "unique_id start_logits end_logits"
)

# ---- Hyper-parameters -------------------------------------------------------
max_seq_length = 256    # maximum length of one model input sequence (tokens)
doc_stride = 256        # sliding-window stride used when splitting long contexts
max_query_length = 64   # maximum length of the question (tokens)
batch_size = 1          # number of features fed per inference call
n_best_size = 20        # number of candidate answers kept per example
max_answer_length = 30  # maximum length of a predicted answer (not the question)

# ---- Pre-processing ---------------------------------------------------------
input_file = '../Resource/inputs_data.json'

# Read the input file with run_onnx_squad.read_squad_examples, which splits
# the text into individual words.
eval_examples = read_squad_examples(input_file)

# WordPiece tokenizer built from the uncased BERT vocabulary file.
vocab_file = os.path.join('../Resource/uncased_L-12_H-768_A-12', 'vocab.txt')
tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)

# Turn the examples into model inputs (ids / mask / segment ids) plus the
# extra bookkeeping data later needed by write_predictions.
input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(
    eval_examples, tokenizer, max_seq_length, doc_stride, max_query_length
)

# ---- Model loading ----------------------------------------------------------
print("INFO: Parsing and compiling the model")

sess_options = ort.SessionOptions()
sess_options.enable_profiling = False  # profiling is turned off
# Only basic graph optimizations are applied.
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# Create the inference session on the ROCm execution provider.
dcu_session = ort.InferenceSession(
    "../Resource/bertsquad-10.onnx",
    sess_options=sess_options,
    providers=['ROCMExecutionProvider'],
)

# Name of the model's first input (kept for reference; not read afterwards).
input_name = dcu_session.get_inputs()[0].name

# ---- Inference --------------------------------------------------------------
# Run the model over every *feature* produced by pre-processing. A single
# example can produce several features when its context exceeds
# max_seq_length (sliding window), so features are indexed on their own and
# never used to index eval_examples.
n = len(input_ids)
bs = batch_size
all_results = []

# Request outputs by name so that start/end logits do not depend on the
# model's output ordering (names as in the reference BERT-SQuAD ONNX model).
output_names = ["unique_ids:0", "unstack:0", "unstack:1"]

for batch_start in range(0, n, bs):  # step by bs so batches never overlap
    batch_stop = min(batch_start + bs, n)
    feeds = {
        # Pass-through id; the global feature index is what write_predictions
        # matches results against (via RawResult.unique_id below).
        "unique_ids_raw_output___9:0":
        np.arange(batch_start, batch_stop, dtype=np.int64),
        "input_ids:0": input_ids[batch_start:batch_stop],      # token ids
        "input_mask:0": input_mask[batch_start:batch_stop],    # attention mask
        "segment_ids:0": segment_ids[batch_start:batch_stop],  # question/context segment ids
    }
    _, batch_start_logits, batch_end_logits = dcu_session.run(output_names, feeds)

    # One RawResult per row of the batch, each with only its own logits.
    for start_row, end_row in zip(batch_start_logits, batch_end_logits):
        all_results.append(RawResult(
            unique_id=len(all_results),  # equals the global feature index
            start_logits=np.ravel(start_row).tolist(),  # answer-start scores
            end_logits=np.ravel(end_row).tolist(),      # answer-end scores
        ))

# ---- Post-processing --------------------------------------------------------
# Turn the raw logits into answer predictions and write them to disk.
output_dir = 'predictions'
os.makedirs(output_dir, exist_ok=True)
output_prediction_file = os.path.join(output_dir, "predictions.json")
output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")

write_predictions(
    eval_examples,
    extra_data,
    all_results,
    n_best_size,
    max_answer_length,
    True,  # NOTE(review): presumably do_lower_case (uncased vocab) — confirm in run_onnx_squad
    output_prediction_file,
    output_nbest_file,
)

# Echo the final predictions as indented JSON.
with open(output_prediction_file) as json_file:
    test_data = json.load(json_file)
print(json.dumps(test_data, indent=2))