// main.cpp (1.99 KB)
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <GPT2.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h>

int main()
11
{
liucong's avatar
liucong committed
12
13
14
    // 加载GPT2模型
    migraphxSamples::GPT2 gpt2;
    migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
liucong's avatar
liucong committed
15
    if (errorCode != migraphxSamples::SUCCESS)
16
    {
liucong's avatar
liucong committed
17
18
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);
19
    }
liucong's avatar
liucong committed
20
21
22
    LOG_INFO(stdout, "succeed to initialize GPT2\n");

    // 加载词汇表,用于编码和解码
liucong's avatar
liucong committed
23
    cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/vocab_shici.txt");
liucong's avatar
liucong committed
24
25
26
    std::ifstream infile;
    std::string buf;
    std::vector<std::string> output;
liucong's avatar
liucong committed
27
    infile.open("../Resource/vocab_shici.txt");
liucong's avatar
liucong committed
28
    while (std::getline(infile,buf))
29
    {
liucong's avatar
liucong committed
30
        output.push_back(buf);
31
    }
liucong's avatar
liucong committed
32
33
34
35
36
37
38
39
40

    std::vector<long unsigned int> input_id;
    char question[100];

    std::vector<long unsigned int> score;
    std::vector<std::string> result;

    std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
    while (true) 
41
    {
liucong's avatar
liucong committed
42
43
44
45
46
47
48
49
50
51
        // 数据预处理
        std::cout << "question: ";
        cin.getline(question, 100);
        gpt2.Preprocessing(tokenizer, question, input_id);

        // 推理
        for(int i=0;i<50;++i)
        {
            long unsigned int outputs = gpt2.Inference(input_id);
            if(outputs == 102)
52
53
54
            {
                break;
            }
liucong's avatar
liucong committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
            input_id.push_back(outputs);
            score.push_back(outputs);
        }

        // 将数值映射为字符
        for(int i=0;i<score.size();++i)
        {
            result.push_back(output[score[i]]);
        }

        // 打印结果
        std::cout << "chatbot: ";
        std::cout << question;
        for(int j=0; j<result.size();++j)
        {
            std::cout << result[j];
        }
        std::cout << std::endl;
        
        // 清除数据
        input_id.clear();
        result.clear();
        score.clear();
78
    }
liucong's avatar
liucong committed
79

80
81
    return 0;
}