#include #include #include #include #include #include #include #include int main() { // 加载GPT2模型 migraphxSamples::GPT2 gpt2; migraphxSamples::ErrorCode errorCode = gpt2.Initialize(); if (errorCode != migraphxSamples::SUCCESS) { LOG_ERROR(stdout, "fail to initialize GPT2!\n"); exit(-1); } LOG_INFO(stdout, "succeed to initialize GPT2\n"); // 加载词汇表,用于编码和解码 cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/vocab_shici.txt"); std::ifstream infile; std::string buf; std::vector output; infile.open("../Resource/vocab_shici.txt"); while (std::getline(infile,buf)) { output.push_back(buf); } std::vector input_id; char question[100]; std::vector score; std::vector result; std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl; while (true) { // 数据预处理 std::cout << "question: "; cin.getline(question, 100); gpt2.Preprocessing(tokenizer, question, input_id); // 推理 for(int i=0;i<50;++i) { long unsigned int outputs = gpt2.Inference(input_id); if(outputs == 102) { break; } input_id.push_back(outputs); score.push_back(outputs); } // 将数值映射为字符 for(int i=0;i