main.cpp 1.97 KB
Newer Older
1
2
3
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
liucong's avatar
liucong committed
4
5
6
7
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <gpt2.h>
#include <SimpleLog.h>
#include <tokenization.h>
8

liucong's avatar
liucong committed
9
int main()
10
{
liucong's avatar
liucong committed
11
12
13
14
    // 加载GPT2模型
    migraphxSamples::GPT2 gpt2;
    migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
    if (errorCode != SUCCESS)
15
    {
liucong's avatar
liucong committed
16
17
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);
18
    }
liucong's avatar
liucong committed
19
20
21
22
23
24
25
26
27
    LOG_INFO(stdout, "succeed to initialize GPT2\n");

    // 加载词汇表,用于编码和解码
    cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/Models/vocab_shici.txt");
    std::ifstream infile;
    std::string buf;
    std::vector<std::string> output;
    infile.open("../Resource/Models/vocab_shici.txt");
    while (std::getline(infile,buf))
28
    {
liucong's avatar
liucong committed
29
        output.push_back(buf);
30
    }
liucong's avatar
liucong committed
31
32
33
34
35
36
37
38
39

    std::vector<long unsigned int> input_id;
    char question[100];

    std::vector<long unsigned int> score;
    std::vector<std::string> result;

    std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
    while (true) 
40
    {
liucong's avatar
liucong committed
41
42
43
44
45
46
47
48
49
50
        // 数据预处理
        std::cout << "question: ";
        cin.getline(question, 100);
        gpt2.Preprocessing(tokenizer, question, input_id);

        // 推理
        for(int i=0;i<50;++i)
        {
            long unsigned int outputs = gpt2.Inference(input_id);
            if(outputs == 102)
51
52
53
            {
                break;
            }
liucong's avatar
liucong committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
            input_id.push_back(outputs);
            score.push_back(outputs);
        }

        // 将数值映射为字符
        for(int i=0;i<score.size();++i)
        {
            result.push_back(output[score[i]]);
        }

        // 打印结果
        std::cout << "chatbot: ";
        std::cout << question;
        for(int j=0; j<result.size();++j)
        {
            std::cout << result[j];
        }
        std::cout << std::endl;
        
        // 清除数据
        input_id.clear();
        result.clear();
        score.clear();
77
    }
liucong's avatar
liucong committed
78

79
80
    return 0;
}