// Bert.h - interface of the MIGraphX BERT sample (migraphxSamples::Bert).
#ifndef __BERT_H__
#define __BERT_H__

#include <cstdint>
#include <string>
#include <vector>

#include <migraphx/program.hpp>
#include <tokenization.h>

namespace migraphxSamples
{
    /// Status codes returned by the Bert member functions (Initialize,
    /// Inference, Preprocessing, Postprocessing).
    ///
    /// Declared as a plain (unscoped) enum so existing uses of `ErrorCode` and
    /// of the bare enumerators (e.g. `SUCCESS`) keep compiling. The former
    /// C-style tag `_ErrorCode` was dropped: identifiers beginning with an
    /// underscore followed by an uppercase letter are reserved to the
    /// implementation.
    enum ErrorCode
    {
        SUCCESS = 0,
        MODEL_NOT_EXIST,
        CONFIG_FILE_NOT_EXIST,
        FAIL_TO_LOAD_MODEL,
        FAIL_TO_OPEN_CONFIG_FILE,
    };

    /// An (index, value) pair.
    /// NOTE(review): presumably used to sort prediction values while keeping
    /// each value's original position - confirm against the .cpp.
    /// (The C-style tag `_Sort_st` was dropped: it is a reserved identifier.)
    struct Sort_st
    {
        int index;
        float value;
    };

    /// A candidate answer span: start/end indices plus the prediction value
    /// associated with each end.
    /// NOTE(review): the indices are presumably token positions within the
    /// encoded sequence - confirm against the .cpp.
    /// (The C-style tag `_ResultOfPredictions` was dropped: it is a reserved
    /// identifier.)
    struct ResultOfPredictions
    {
        int start_index;
        int end_index;
        float start_predictionvalue;
        float end_predictionvalue;
    };

/// BERT question-answering sample: encodes a (text, question) pair, runs it
/// through a MIGraphX program and extracts an answer string.
///
/// NOTE(review): the intended call order is presumably
/// Initialize() -> Preprocessing() -> Inference() -> Postprocessing();
/// confirm against the .cpp implementation.
class Bert
{
public:
    Bert();

    ~Bert();

    /// Prepares the instance for inference.
    /// NOTE(review): presumably loads the model and its config; the
    /// MODEL_NOT_EXIST / CONFIG_FILE_NOT_EXIST / FAIL_TO_LOAD_MODEL /
    /// FAIL_TO_OPEN_CONFIG_FILE codes hint at the failure modes - confirm.
    /// @return an ErrorCode status (presumably SUCCESS when it succeeds).
    ErrorCode Initialize();

    /// Runs the network on one batch of encoded inputs.
    /// @param input_ids      token ids (presumably [batch][seq_len] - confirm)
    /// @param input_masks    attention mask, same layout as input_ids
    /// @param segment_ids    segment (token type) ids, same layout as input_ids
    /// @param start_position output: presumably the answer-start scores
    /// @param end_position   output: presumably the answer-end scores
    /// @return an ErrorCode status.
    ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids, 
                        const std::vector<std::vector<long unsigned int>> &input_masks, 
                        const std::vector<std::vector<long unsigned int>> &segment_ids,
                        std::vector<float> &start_position,
                        std::vector<float> &end_position);

    /// Encodes `text` and `question` into the three model inputs.
    /// NOTE(review): `tokenizer` is taken by value (copied on every call) and
    /// `question` is a non-const char*; both could become const references /
    /// pointers, but only together with the definition in the .cpp.
    /// @param tokenizer      tokenizer used to split text and question
    /// @param batch_size     number of sequences to produce
    /// @param max_seq_length maximum encoded sequence length
    /// @param text           context passage (C string)
    /// @param question       question (C string)
    /// @param input_ids      output: token ids
    /// @param input_masks    output: attention mask
    /// @param segment_ids    output: segment ids
    /// @return an ErrorCode status.
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                             int batch_size,
                             int max_seq_length,
                             const char *text,
                             char *question,
                             std::vector<std::vector<long unsigned int>> &input_ids, 
                             std::vector<std::vector<long unsigned int>> &input_masks, 
                             std::vector<std::vector<long unsigned int>> &segment_ids);

    /// Turns the network outputs into an answer string.
    /// NOTE(review): presumably considers the n_best_size highest-scoring
    /// start/end candidates and rejects spans longer than max_answer_length.
    /// @param n_best_size       number of candidates to consider
    /// @param max_answer_length maximum allowed answer length
    /// @param start_position    start scores produced by Inference()
    /// @param end_position      end scores produced by Inference()
    /// @param answer            output: the extracted answer text
    /// @return an ErrorCode status.
    ErrorCode Postprocessing(int n_best_size, 
                             int max_answer_length, 
                             const std::vector<float> &start_position,
                             const std::vector<float> &end_position, 
                             std::string &answer);

private:
    // Token strings of the context text and of the question.
    // NOTE(review): presumably filled by Preprocessing() and read back by
    // Postprocessing() to rebuild the answer - confirm.
    std::vector<std::string> tokens_text;
    std::vector<std::string> tokens_question;

    // The MIGraphX program plus the names and shapes of its four inputs.
    migraphx::program net;
    std::string inputName1;
    std::string inputName2;
    std::string inputName3;
    std::string inputName4;
    migraphx::shape inputShape1;
    migraphx::shape inputShape2;
    migraphx::shape inputShape3;
    migraphx::shape inputShape4;

};

}

#endif