#include #include #include #include #include #include #include #include namespace migraphxSamples { GPT2::GPT2() { } GPT2::~GPT2() { } ErrorCode GPT2::Initialize() { // 获取模型文件 std::string modelPath="../Resource/GPT2_shici.onnx"; // 设置最大输入shape migraphx::onnx_options onnx_options; onnx_options.map_input_dims["input"] = {1,1000}; // 加载模型 if(!Exists(modelPath)) { LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } net = migraphx::parse_onnx(modelPath, onnx_options); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入/输出节点信息 std::unordered_map inputs = net.get_inputs(); std::unordered_map outputs = net.get_outputs(); inputName=inputs.begin()->first; inputShape=inputs.begin()->second; // 设置模型为GPU模式 migraphx::target gpuTarget = migraphx::gpu::target{}; // 编译模型 migraphx::compile_options options; options.device_id = 0; // 设置GPU设备,默认为0号设备 options.offload_copy = true; // 设置offload_copy net.compile(gpuTarget,options); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); return SUCCESS; } static bool CompareM(Predictions a, Predictions b) { return a.predictionvalue > b.predictionvalue; } long unsigned int GPT2::Inference(const std::vector &input_id) { long unsigned int input[1][input_id.size()]; for (int j=0;j> inputShapes; inputShapes.push_back({1,input_id.size()}); // 创建输入数据 std::unordered_map inputData; inputData[inputName]=migraphx::argument{migraphx::shape(inputShape.type(),inputShapes[0]),(long unsigned int*)input}; // 推理 std::vector results = net.eval(inputData); // 获取输出节点的属性 migraphx::argument result = results[0]; migraphx::shape outputShape = result.get_shape(); // 输出节点的shape int numberOfOutput=outputShape.elements(); // 输出节点元素的个数 float *data = (float *)result.data(); // 输出节点数据指针 // 保存推理结果 long unsigned int n = 0; std::vector resultsOfPredictions(22557); for(int i=(input_id.size()-1)*22557; i &input_id) { // 分词操作 int max_seq_length = 1000; std::vector tokens_question; tokens_question.reserve(max_seq_length); tokenizer.tokenize(question, &tokens_question, max_seq_length); // 保存编码信息 input_id.push_back(tokenizer.convert_token_to_id("[CLS]")); for (int i=0;i