// Bert.h
#ifndef BERT_H
#define BERT_H

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include <migraphx/program.hpp>
#include <CommonDefinition.h>
#include <tokenization.h>
using namespace cuBERT;

namespace migraphxSamples
{
    /// Plain aggregate holding an integer `index` and a float `value`.
    /// Presumably used to order scores while remembering their original
    /// position - TODO confirm against the implementation file.
    /// NOTE(review): the tag `_Sort_st` starts with an underscore followed by
    /// an uppercase letter, which is reserved for the implementation; it is
    /// kept here only so existing references keep compiling.
    struct _Sort_st
    {
        int index;    ///< integer index field
        float value;  ///< float value field
    };
    using Sort_st = _Sort_st;

    /// Plain aggregate pairing a start index and an end index with one float
    /// prediction value for each. Presumably describes one candidate answer
    /// span - TODO confirm against the implementation file.
    /// NOTE(review): the tag `_ResultOfPredictions` is a reserved identifier
    /// (leading underscore + uppercase letter); kept for compatibility.
    struct _ResultOfPredictions
    {
        int start_index;              ///< integer start index
        int end_index;                ///< integer end index
        float start_predictionvalue;  ///< float value associated with start_index
        float end_predictionvalue;    ///< float value associated with end_index
    };
    using ResultOfPredictions = _ResultOfPredictions;

/**
 * Wrapper class that owns a migraphx::program together with the names and
 * shapes of four inputs, and exposes Initialize / Preprocessing / Inference /
 * Postprocessing entry points.
 *
 * Only declarations are visible in this header. Judging by the parameter
 * names (text, question, answer, start/end position) this looks like an
 * extractive question-answering pipeline - TODO confirm against the
 * implementation file.
 */
class Bert
{
public:
    Bert();

    ~Bert();

    /**
     * Initializes the object from the given parameters.
     * @param initParamOfNLPBert initialization parameters (taken by value).
     * @return an ErrorCode; the meaning of each value is defined elsewhere.
     */
    ErrorCode Initialize(InitializationParameterOfNLP initParamOfNLPBert);

    /**
     * Takes three read-only id matrices and fills two float vectors.
     * @param input_ids      read-only matrix of unsigned long values.
     * @param input_masks    read-only matrix of unsigned long values.
     * @param segment_ids    read-only matrix of unsigned long values.
     * @param start_position output vector of floats.
     * @param end_position   output vector of floats.
     * @return an ErrorCode.
     * NOTE(review): presumably each outer element is one batch item and the
     * outputs are per-token start/end scores - confirm in the .cpp.
     */
    ErrorCode Inference(const std::vector<std::vector<long unsigned int>> &input_ids,
                        const std::vector<std::vector<long unsigned int>> &input_masks,
                        const std::vector<std::vector<long unsigned int>> &segment_ids,
                        std::vector<float> &start_position,
                        std::vector<float> &end_position);

    /**
     * Takes a tokenizer, sizing parameters and two C strings, and fills the
     * three id matrices that Inference() accepts.
     * @param tokenizer      tokenizer object; passed by value, so it is copied
     *                       on every call.
     * @param batch_size     integer batch size.
     * @param max_seq_length integer maximum sequence length.
     * @param text           read-only C string.
     * @param question       C string; declared non-const, unlike `text`.
     * @param input_ids      output matrix.
     * @param input_masks    output matrix.
     * @param segment_ids    output matrix.
     * @return an ErrorCode.
     * NOTE(review): taking `tokenizer` by const reference and `question` as
     * `const char *` would likely be preferable, but that requires changing
     * the out-of-line definition as well, so the signature is left as is.
     */
    ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer,
                             int batch_size,
                             int max_seq_length,
                             const char *text,
                             char *question,
                             std::vector<std::vector<long unsigned int>> &input_ids,
                             std::vector<std::vector<long unsigned int>> &input_masks,
                             std::vector<std::vector<long unsigned int>> &segment_ids);

    /**
     * Takes the two float vectors produced by Inference() and writes a string.
     * @param n_best_size       integer; presumably the number of top
     *                          candidates considered - TODO confirm.
     * @param max_answer_length integer; presumably an upper bound on the
     *                          answer span length - TODO confirm.
     * @param start_position    read-only vector of floats.
     * @param end_position      read-only vector of floats.
     * @param answer            output string.
     * @return an ErrorCode.
     */
    ErrorCode Postprocessing(int n_best_size,
                             int max_answer_length,
                             const std::vector<float> &start_position,
                             const std::vector<float> &end_position,
                             std::string &answer);

private:
    /// Private initialization helper taking the same parameter type as
    /// Initialize(); its exact role is defined in the implementation file.
    ErrorCode DoCommonInitialization(InitializationParameterOfNLP initParamOfNLPBert);

private:
    // Default-initialized to nullptr so the pointer never holds an
    // indeterminate value, whatever the out-of-line constructor does.
    // Ownership of the FILE is not visible from this header.
    FILE *logFile = nullptr;
    cv::FileStorage configurationFile;
    InitializationParameterOfNLP initializationParameter;

    // Token lists; presumably filled by Preprocessing() and read back by
    // Postprocessing() - TODO confirm.
    std::vector<std::string> tokens_text;
    std::vector<std::string> tokens_question;

    // The MIGraphX program plus the name and shape of each of its four inputs.
    migraphx::program net;
    std::string inputName1;
    std::string inputName2;
    std::string inputName3;
    std::string inputName4;
    migraphx::shape inputShape1;
    migraphx::shape inputShape2;
    migraphx::shape inputShape3;
    migraphx::shape inputShape4;

};

}

#endif