Commit c99ee00b authored by liucong's avatar liucong
Browse files

修改部分代码和文档

parent 9cffb74e
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <Sample.h>
#include <gpt2.h>
#include <fstream>
#include <SimpleLog.h>
#include <tokenization.h>
void MIGraphXSamplesUsage(char* programName)
int main()
{
printf("Usage : %s <index> \n", programName);
printf("index:\n");
printf("\t 0) GPT2 sample.\n");
}
int main(int argc, char *argv[])
{
if (argc < 2 || argc > 2)
// 加载GPT2模型
migraphxSamples::GPT2 gpt2;
migraphxSamples::ErrorCode errorCode = gpt2.Initialize();
if (errorCode != SUCCESS)
{
MIGraphXSamplesUsage(argv[0]);
return -1;
LOG_ERROR(stdout, "fail to initialize GPT2!\n");
exit(-1);
}
if (!strncmp(argv[1], "-h", 2))
LOG_INFO(stdout, "succeed to initialize GPT2\n");
// 加载词汇表,用于编码和解码
cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/Models/vocab_shici.txt");
std::ifstream infile;
std::string buf;
std::vector<std::string> output;
infile.open("../Resource/Models/vocab_shici.txt");
while (std::getline(infile,buf))
{
MIGraphXSamplesUsage(argv[0]);
return 0;
output.push_back(buf);
}
switch (*argv[1])
std::vector<long unsigned int> input_id;
char question[100];
std::vector<long unsigned int> score;
std::vector<std::string> result;
std::cout << "开始和GPT2对诗,输入CTRL + Z以退出" << std::endl;
while (true)
{
case '0':
{
Sample_GPT2();
break;
}
default :
// 数据预处理
std::cout << "question: ";
cin.getline(question, 100);
gpt2.Preprocessing(tokenizer, question, input_id);
// 推理
for(int i=0;i<50;++i)
{
long unsigned int outputs = gpt2.Inference(input_id);
if(outputs == 102)
{
MIGraphXSamplesUsage(argv[0]);
break;
}
input_id.push_back(outputs);
score.push_back(outputs);
}
// 将数值映射为字符
for(int i=0;i<score.size();++i)
{
result.push_back(output[score[i]]);
}
// 打印结果
std::cout << "chatbot: ";
std::cout << question;
for(int j=0; j<result.size();++j)
{
std::cout << result[j];
}
std::cout << std::endl;
// 清除数据
input_id.clear();
result.clear();
score.clear();
}
return 0;
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment