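"""Extractive question answering with a BERT ONNX model on ROCm.

Runs the same question/context pair through an FP32 and an FP16 copy of the
model and prints per-run latencies and the decoded answers for comparison.
"""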
import time

import numpy as np
from transformers import AutoTokenizer
from onnxruntime import InferenceSession

def main():
    # Load the tokenizer from the current directory (expects the tokenizer
    # files to sit alongside this script).
    tokenizer = AutoTokenizer.from_pretrained('./')

    # context = 'ONNX is an open format to represent models. The benefits of using ONNX include interoperability of frameworks and hardware optimization.'
    # question = 'What are advantages of ONNX?'

    # context = '今天天气晴'        # "The weather is sunny today"
    # question = '今天天气怎么样?'  # "How is the weather today?"

    # context = '中国历史有5000年'    # "Chinese history spans 5000 years"
    # question = '中国历史有多少年?'  # "How many years of history does China have?"

    context = 'ROCM是AMD的一个软件平台,用来加速GPU计算'  # "ROCm is an AMD software platform used to accelerate GPU compute"
    question = 'ROCM用来干什么?'                          # "What is ROCm used for?"

    # Create one session per precision; prefer the ROCm execution provider on
    # device 4 and fall back to the CPU provider.
    session = InferenceSession("./model.onnx",
                               providers=[('ROCMExecutionProvider', {'device_id': '4'}), 'CPUExecutionProvider'])
    session_fp16 = InferenceSession("./model_fp16.onnx",
                                    providers=[('ROCMExecutionProvider', {'device_id': '4'}), 'CPUExecutionProvider'])
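    # Note: model_fp16.onnx is assumed to already exist next to model.onnx.
    # One way to produce it offline is a sketch like the following (assumes
    # the onnxconverter-common package; file names follow this script):
    #   import onnx
    #   from onnxconverter_common import float16
    #   model = onnx.load("./model.onnx")
    #   onnx.save(float16.convert_float_to_float16(model), "./model_fp16.onnx")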

    # Get the model's original input names and shapes.
    input_names = [inp.name for inp in session.get_inputs()]
    input_shapes = [inp.shape for inp in session.get_inputs()]
    print("input_names:", input_names)
    print("input_shapes:", input_shapes)

    # Get the model's output names.
    output_names = [out.name for out in session.get_outputs()]
    print("output_names:", output_names)

    # Tokenize the (question, context) pair. The ONNX model expects a fixed
    # sequence length of 384, so truncate to at most 384 tokens here and
    # zero-pad up to 384 below.
    inputs = tokenizer(question, context, truncation=True, max_length=384, return_tensors='np')
    print("inputs:", tokenizer.decode(inputs.input_ids[0]))

    seq_len = len(inputs.input_ids[0])
    input_ids_zeros = np.zeros((1, 384), np.int64)
    input_mask_zeros = np.zeros((1, 384), np.int64)
    segment_ids_zeros = np.zeros((1, 384), np.int64)
    input_ids_zeros[0, :seq_len] = inputs.input_ids[0]
    input_mask_zeros[0, :seq_len] = inputs.attention_mask[0]
    segment_ids_zeros[0, :seq_len] = inputs.token_type_ids[0]

    # Map the padded arrays to the model's input names (ids, attention mask,
    # segment/token-type ids).
    onnx_input = {input_names[0]: input_ids_zeros,
                  input_names[1]: input_mask_zeros,
                  input_names[2]: segment_ids_zeros}
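    # The first timed iteration of each loop below typically includes
    # execution-provider warm-up (memory allocation, kernel selection), so the
    # later iterations are the more representative latency numbers.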

    # Time 10 FP32 runs, then decode the answer span from the last run's
    # start/end logits.
    for i in range(10):
        t1 = time.perf_counter()
        outputs = session.run(output_names=None, input_feed=onnx_input)
        t2 = time.perf_counter()
        print("fp32:", i, t2 - t1)
    answer_start_index = outputs[0].argmax()
    answer_end_index = outputs[1].argmax()
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    print("results fp32:", tokenizer.decode(predict_answer_tokens))

    # Same measurement for FP16 (integer inputs are unchanged by the
    # float16 conversion, so the same feed dict is reused).
    for i in range(10):
        t1 = time.perf_counter()
        outputs_fp16 = session_fp16.run(output_names=None, input_feed=onnx_input)
        t2 = time.perf_counter()
        print("fp16:", i, t2 - t1)
    answer_start_index_fp16 = outputs_fp16[0].argmax()
    answer_end_index_fp16 = outputs_fp16[1].argmax()
    predict_answer_tokens_fp16 = inputs.input_ids[0, answer_start_index_fp16 : answer_end_index_fp16 + 1]
    print("results fp16:", tokenizer.decode(predict_answer_tokens_fp16))

if __name__ == "__main__":
    main()