main.cpp 3.3 KB
Newer Older
liucong's avatar
liucong committed
1
2
3
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <string>
#include <vector>

#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <tokenization.h>
8

liucong's avatar
liucong committed
9
int main(int argc, char* argv[])
10
{
liucong's avatar
liucong committed
11
12
13
    // 加载Bert模型
    migraphxSamples::Bert bert;
    migraphxSamples::ErrorCode errorCode = bert.Initialize();
liucong's avatar
liucong committed
14
    if(errorCode != migraphxSamples::SUCCESS)
15
    {
liucong's avatar
liucong committed
16
17
        LOG_ERROR(stdout, "fail to initialize Bert!\n");
        exit(-1);
18
    }
liucong's avatar
liucong committed
19
20
    LOG_INFO(stdout, "succeed to initialize Bert\n");

liucong's avatar
liucong committed
21
22
23
24
25
    int max_seq_length    = 256; // 滑动窗口的长度
    int max_query_length  = 64;  // 问题的最大长度
    int batch_size        = 1;   // batch_size值
    int n_best_size       = 20;  // 索引数量
    int max_answer_length = 30;  // 答案的最大长度
liucong's avatar
liucong committed
26
27

    // 上下文文本数据
liucong's avatar
liucong committed
28
29
30
31
32
33
34
35
36
37
38
39
    const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
                         u8"computing that’s also programming-language independent. It brings a "
                         u8"philosophy of choice, minimalism and modular software development to "
                         u8"GPU computing. You are free to choose or even develop tools and a "
                         u8"language run time for your application. ROCm is built for scale, it "
                         u8"supports multi-GPU computing and has a rich system run time with the "
                         u8"critical features that large-scale application, compiler and "
                         u8"language-run-time development requires. Since the ROCm ecosystem is "
                         u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
                         u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
                         u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
                         u8"platform is continually optimized for performance and extensibility."};
liucong's avatar
liucong committed
40
41
42
43
44
45
46
47
48
    char question[100];

    std::vector<std::vector<long unsigned int>> input_ids;
    std::vector<std::vector<long unsigned int>> input_masks;
    std::vector<std::vector<long unsigned int>> segment_ids;
    std::vector<float> start_position;
    std::vector<float> end_position;
    std::string answer = {};

liucong's avatar
liucong committed
49
50
    cuBERT::FullTokenizer tokenizer =
        cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具
liucong's avatar
liucong committed
51

liucong's avatar
liucong committed
52
    while(true)
53
    {
liucong's avatar
liucong committed
54
55
56
        // 数据前处理
        std::cout << "question: ";
        cin.getline(question, 100);
liucong's avatar
liucong committed
57
58
59
60
61
62
63
64
        bert.Preprocessing(tokenizer,
                           batch_size,
                           max_seq_length,
                           text,
                           question,
                           input_ids,
                           input_masks,
                           segment_ids);
liucong's avatar
liucong committed
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

        // 推理
        bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);

        // 数据后处理
        bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer);

        // 打印输出预测结果
        std::cout << "answer: " << answer << std::endl;

        // 清除数据
        input_ids.clear();
        input_masks.clear();
        segment_ids.clear();
        start_position.clear();
        end_position.clear();
        answer = {};
82
    }
liucong's avatar
liucong committed
83

84
85
    return 0;
}