"examples/community/stable_diffusion_controlnet_img2img.py" did not exist on "0c39f53cbb2724b9706a5a8397e8c5ace414aa96"
main.cpp 1.95 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <GPT2.h>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <tokenization.h>

// Maximum number of tokens generated for a single question.
static const int kMaxGeneratedTokens = 50;

// Token id that ends generation early.
// NOTE(review): presumably the [SEP] id of the BERT-style vocabulary in
// vocab_shici.txt — confirm against the vocabulary file.
static const long unsigned int kStopTokenId = 102;

// Vocabulary shared by the tokenizer (encoding) and the id->token table (decoding).
static const char* const kVocabPath = "../Resource/vocab_shici.txt";

// Reads the vocabulary file one token per line, so that line i is the text of
// token id i. Returns an empty vector if the file cannot be opened or is empty.
static std::vector<std::string> LoadVocabulary(const char* path)
{
    std::vector<std::string> vocab;
    std::ifstream infile(path);
    std::string line;
    while (std::getline(infile, line))
    {
        vocab.push_back(line);
    }
    return vocab;
}

// Reads one line from stdin into buf (NUL-terminated, at most size - 1 chars).
// An over-long line is truncated; the rest of it is discarded and the stream
// is reset so later reads keep working.
// Returns false on end of input (Ctrl+Z / Ctrl+D) or an unrecoverable stream error.
static bool ReadQuestion(char* buf, std::streamsize size)
{
    if (std::cin.getline(buf, size))
    {
        return true;
    }
    if (std::cin.eof() || std::cin.bad())
    {
        return false;
    }
    // failbit alone means the line did not fit into buf: keep the prefix.
    std::cin.clear();
    std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    return true;
}

// Interactive poem-continuation loop: reads a line, lets GPT2 extend it token by
// token, and prints the question followed by the generated text until EOF.
int main()
{
    // Load the GPT2 model
    ortSamples::GPT2 gpt2;
    ortSamples::ErrorCode errorCode = gpt2.Initialize();
    if (errorCode != ortSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        return -1;
    }
    LOG_INFO(stdout, "succeed to initialize GPT2\n");

    // The tokenizer encodes questions; the plain id->token table decodes the output.
    cuBERT::FullTokenizer tokenizer(kVocabPath);
    const std::vector<std::string> vocab = LoadVocabulary(kVocabPath);
    if (vocab.empty())
    {
        LOG_ERROR(stdout, "fail to load vocabulary!\n");
        return -1;
    }

    std::vector<long unsigned int> input_id;   // prompt ids plus every generated id
    std::vector<long unsigned int> generated;  // generated ids only
    char question[100];

    std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
    while (true)
    {
        // Data preprocessing: read the question and encode it into token ids
        std::cout << "question: ";
        if (!ReadQuestion(question, sizeof(question)))
        {
            break;  // end of input: leave instead of re-running a stale question forever
        }
        input_id.clear();
        generated.clear();
        gpt2.Preprocessing(tokenizer, question, input_id);

        // Inference: each call yields one token, which is appended to the input
        // for the next call, until the stop token or the length limit.
        for (int i = 0; i < kMaxGeneratedTokens; ++i)
        {
            long unsigned int token = gpt2.Inference(input_id);
            if (token == kStopTokenId)
            {
                break;
            }
            input_id.push_back(token);
            generated.push_back(token);
        }

        // Map ids back to text; ids outside the vocabulary are skipped rather
        // than read out of bounds.
        std::string answer;
        for (long unsigned int token : generated)
        {
            if (token < vocab.size())
            {
                answer += vocab[token];
            }
        }

        // Print the question followed by its continuation on a single line.
        std::cout << "chatbot: " << question << answer << std::endl;
    }
    return 0;
}