#include <Sample.h>
#include <SimpleLog.h>
#include <GPT2.h>
#include <tokenization.h>
#include <fstream>

using namespace std;
using namespace migraphx;
using namespace migraphxSamples;

/**
 * Interactive GPT2 poetry demo.
 *
 * Initializes the GPT2 model from CONFIG_FILE, loads the vocabulary file
 * line by line (the line index is used as the token id when decoding), then
 * repeatedly reads a prompt from standard input, generates up to a fixed
 * number of tokens and prints the prompt followed by the generated text.
 *
 * The function returns when standard input reaches end-of-file or becomes
 * unreadable. It terminates the process with exit(-1) if the model cannot be
 * initialized or the vocabulary cannot be loaded.
 */
void Sample_GPT2()
{
    // Load the GPT2 model
    GPT2 gpt2;
    InitializationParameterOfNLP initParamOfNLPGPT2;
    initParamOfNLPGPT2.parentPath = "";
    initParamOfNLPGPT2.configFilePath = CONFIG_FILE;
    initParamOfNLPGPT2.logName = "";
    ErrorCode errorCode = gpt2.Initialize(initParamOfNLPGPT2);
    if (errorCode != SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize GPT2!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize GPT2\n");

    // Load the vocabulary, used for both encoding (tokenizer) and decoding (id -> text)
    const char* const vocabPath = "../Resource/Models/NLP/GPT2/vocab_shici.txt";
    cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer(vocabPath);
    std::vector<std::string> vocab;
    {
        std::ifstream infile(vocabPath);
        if (!infile.is_open())
        {
            // Without the vocabulary every decode below would index an empty vector.
            LOG_ERROR(stdout, "fail to open vocabulary file!\n");
            exit(-1);
        }
        std::string buf;
        while (std::getline(infile, buf))
        {
            vocab.push_back(buf);
        }
    }
    if (vocab.empty())
    {
        LOG_ERROR(stdout, "vocabulary file is empty!\n");
        exit(-1);
    }

    // Upper bound on the number of tokens generated per prompt.
    const int maxNewTokens = 50;
    // Token id that stops generation.
    // NOTE(review): presumably the [SEP] entry of the vocabulary — confirm.
    const long unsigned int endTokenId = 102;

    std::vector<long unsigned int> input_id;
    std::string question;
    std::string reply;

    std::cout << "开始和GPT2对诗，输入CTRL + Z以退出" << std::endl;
    while (true)
    {
        // Preprocessing
        std::cout << "question: ";
        if (!std::getline(std::cin, question))
        {
            // EOF (Ctrl+Z on Windows consoles, Ctrl+D on Unix-like terminals) or a
            // read error: leave the loop instead of spinning on a failed stream.
            std::cout << std::endl;
            break;
        }

        // Reset per-prompt state up front so every iteration starts clean.
        input_id.clear();
        reply.clear();

        // Passed as a mutable, null-terminated char* exactly like the fixed-size
        // buffer used previously, but without the 99-byte input limit.
        gpt2.Preprocessing(tokenizer, &question[0], input_id);

        // Inference: feed each produced id back into the model input
        for (int step = 0; step < maxNewTokens; ++step)
        {
            const long unsigned int tokenId = gpt2.Inference(input_id);
            if (tokenId == endTokenId)
            {
                break;
            }
            if (tokenId >= vocab.size())
            {
                // An id outside the vocabulary cannot be decoded; stop generating
                // rather than read out of bounds.
                LOG_ERROR(stdout, "token id is out of vocabulary range!\n");
                break;
            }
            input_id.push_back(tokenId);
            // Map the id to its text (the vocabulary line with that index)
            reply += vocab[tokenId];
        }

        // Print the result: the prompt followed by the generated continuation
        std::cout << "chatbot: " << question << reply << std::endl;
    }
}
