Bert.h 2.15 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#ifndef __BERT_H__
#define __BERT_H__

#include <cstdint>
#include <string>
#include <vector>
#include <onnxruntime/core/session/onnxruntime_cxx_api.h>
#include <tokenization.h>

namespace ortSamples
{
    /// Status codes returned by the Bert member functions.
    ///
    /// The enumerators are numbered sequentially from SUCCESS (0). Judging by
    /// their names, every other value denotes a model- or config-file failure.
    ///
    /// Declared as a plain unscoped enum: the enumerators remain visible
    /// directly in the enclosing namespace and still convert to int. The
    /// former C-style tag `_ErrorCode` was dropped because identifiers that
    /// start with an underscore followed by an uppercase letter are reserved
    /// to the implementation.
    enum ErrorCode
    {
        SUCCESS = 0,
        MODEL_NOT_EXIST,
        CONFIG_FILE_NOT_EXIST,
        FAIL_TO_LOAD_MODEL,
        FAIL_TO_OPEN_CONFIG_FILE
    };

    /// Pairs a float value with an integer index.
    ///
    /// NOTE(review): presumably used to remember an element's original
    /// position while ordering by `value` - confirm against the .cpp.
    ///
    /// Kept as a plain aggregate (no constructors or default initializers) so
    /// existing brace-initialization `{index, value}` keeps working. The
    /// former tag `_Sort_st` was dropped because identifiers beginning with an
    /// underscore and an uppercase letter are reserved to the implementation.
    struct Sort_st
    {
        int index;
        float value;
    };

    /// One candidate result: a start/end index pair together with the
    /// prediction value recorded for each of the two indices.
    ///
    /// NOTE(review): presumably the indices are positions in the model's
    /// start/end output vectors - confirm against the .cpp.
    ///
    /// Kept as a plain aggregate with the original member order so existing
    /// brace-initialization keeps working. The former tag
    /// `_ResultOfPredictions` was dropped because identifiers beginning with
    /// an underscore and an uppercase letter are reserved to the
    /// implementation.
    struct ResultOfPredictions
    {
        int start_index;
        int end_index;
        float start_predictionvalue;
        float end_predictionvalue;
    };

/// Sample wrapper that drives a BERT model through an ONNX Runtime session.
///
/// NOTE(review): only declarations are visible in this header. The per-method
/// descriptions below are inferred from names and signatures and should be
/// confirmed against the implementation file.
class Bert
{
public:
    Bert();

    ~Bert();

    // The class holds a raw `session` pointer and declares a destructor, so a
    // member-wise copy would leave two objects sharing that pointer. Copying
    // is therefore disabled explicitly.
    Bert(const Bert &) = delete;
    Bert &operator=(const Bert &) = delete;

    /// Prepares the object for use.
    /// NOTE(review): presumably creates `session` from a model file; the
    /// ErrorCode values suggest model/config-file checks - confirm.
    /// @return an ErrorCode describing the outcome.
    ErrorCode Initialize();

    /// Runs the model on already-encoded inputs.
    /// @param input_ids      per-sample id sequences (read-only).
    /// @param input_masks    per-sample mask sequences (read-only).
    /// @param segment_ids    per-sample segment-id sequences (read-only).
    /// @param start_position output vector written by the call.
    /// @param end_position   output vector written by the call.
    /// @return an ErrorCode describing the outcome.
    ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids, 
                        const std::vector<std::vector<long unsigned int>> &input_masks, 
                        const std::vector<std::vector<long unsigned int>> &segment_ids,
                        std::vector<float> &start_position,
                        std::vector<float> &end_position);

    /// Converts a text/question pair into the three model input arrays.
    /// NOTE(review): `tokenizer` is taken by value, so every call copies it;
    /// and `question` is non-const while `text` is const. Both are left as-is
    /// because changing them requires a matching change in the out-of-line
    /// definition.
    /// @param tokenizer      tokenizer used for the conversion (copied).
    /// @param batch_size     presumably the number of samples per batch - confirm.
    /// @param max_seq_length presumably the maximum encoded length per sample - confirm.
    /// @param text           context text.
    /// @param question       question text.
    /// @param input_ids      output: id sequences.
    /// @param input_masks    output: mask sequences.
    /// @param segment_ids    output: segment-id sequences.
    /// @return an ErrorCode describing the outcome.
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                             int batch_size,
                             int max_seq_length,
                             const char *text,
                             char *question,
                             std::vector<std::vector<long unsigned int>> &input_ids, 
                             std::vector<std::vector<long unsigned int>> &input_masks, 
                             std::vector<std::vector<long unsigned int>> &segment_ids);

    /// Turns the start/end outputs of Inference() into an answer string.
    /// @param n_best_size       presumably how many top candidates to consider - confirm.
    /// @param max_answer_length presumably the longest answer span allowed - confirm.
    /// @param start_position    start values (read-only).
    /// @param end_position      end values (read-only).
    /// @param answer            output: the selected answer text.
    /// @return an ErrorCode describing the outcome.
    ErrorCode Postprocessing(int n_best_size, 
                             int max_answer_length, 
                             const std::vector<float> &start_position,
                             const std::vector<float> &end_position, 
                             std::string &answer);

private:
    // NOTE(review): presumably filled by Preprocessing() and read back by
    // Postprocessing() to rebuild the answer text - confirm.
    std::vector<std::string> tokens_text;
    std::vector<std::string> tokens_question;

    // Raw pointer to the runtime session. Defaulted to nullptr so it is never
    // indeterminate: without this, reading or deleting it before it has been
    // assigned would be undefined behaviour, whereas `delete nullptr` is a
    // no-op.
    Ort::Session *session = nullptr;
    Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXRuntime");
    Ort::SessionOptions sessionOptions = Ort::SessionOptions();
};

}

#endif