#include <algorithm>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
// Assumed header names for the MIGraphX API, the cuBERT tokenizer, and this sample's own
// declarations (GPT2, ErrorCode, Predictions, LOG_INFO/LOG_ERROR, Exists, GetFileName).
#include <migraphx/migraphx.hpp>
#include "tokenization.h"
#include "GPT2.h"

namespace migraphxSamples
{

GPT2::GPT2() {}

GPT2::~GPT2() {}

ErrorCode GPT2::Initialize()
{
    // Path to the model file
    std::string modelPath = "../Resource/GPT2_shici.onnx";

    // Set the maximum input shape
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["input"] = {1, 1000};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s does not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "successfully loaded model: %s\n", GetFileName(modelPath).c_str());

    // Query the model's input/output node information
    std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName  = inputs.begin()->first;
    inputShape = inputs.begin()->second;

    // Run the model on the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id    = 0;    // GPU device id, defaults to device 0
    options.offload_copy = true; // let the runtime handle host/device copies
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "successfully compiled model: %s\n", GetFileName(modelPath).c_str());

    return SUCCESS;
}

// Sort predicate: order predictions by descending score
static bool CompareM(const Predictions& a, const Predictions& b)
{
    return a.predictionvalue > b.predictionvalue;
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int>& input_id)
{
    // Copy the token ids into a contiguous input buffer
    std::vector<long unsigned int> input(input_id.begin(), input_id.end());

    // Set the input shape: one sequence of input_id.size() tokens
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1, input_id.size()});

    // Build the input argument
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName] =
        migraphx::argument{migraphx::shape(inputShape.type(), inputShapes[0]), input.data()};

    // Run inference
    std::vector<migraphx::argument> results = net.eval(inputData);

    // Query the output node's attributes
    migraphx::argument result   = results[0];
    migraphx::shape outputShape = result.get_shape();     // shape of the output node
    int numberOfOutput          = outputShape.elements(); // number of output elements
    float* data                 = (float*)result.data();  // pointer to the output data

    // Collect the scores of the last position (vocabulary size is 22557)
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(std::size_t i = (input_id.size() - 1) * 22557; i < input_id.size() * 22557; ++i)
    {
        resultsOfPredictions[n].index           = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

    // Give [UNK] (id 100) an extremely low score so the model can never predict [UNK]
    resultsOfPredictions[100].predictionvalue = -10000;

    // Sort by descending score and return the highest-scoring token id
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);

    return resultsOfPredictions[0].index;
}

ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer, char* question, std::vector<long unsigned int>& input_id)
{
    // Tokenize the question
    int max_seq_length = 1000;
    std::vector<std::string> tokens_question;
    tokens_question.reserve(max_seq_length);
    tokenizer.tokenize(question, &tokens_question, max_seq_length);

    // Store the encoded token ids, starting with [CLS]
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for(std::size_t i = 0; i < tokens_question.size(); ++i)
    {
        input_id.push_back(tokenizer.convert_token_to_id(tokens_question[i]));
    }

    return SUCCESS;
}

} // namespace migraphxSamples
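
// Usage sketch (not part of the sample): a minimal, hypothetical driver showing how
// Initialize, Preprocessing and Inference are typically chained for greedy autoregressive
// generation. Names such as `vocabPath`, `question` and `maxNewTokens`, and the tokenizer
// constructor arguments, are assumptions, not taken from this sample.
//
//     cuBERT::FullTokenizer tokenizer(vocabPath);            // vocab file path is an assumption
//     migraphxSamples::GPT2 gpt2;
//     if(gpt2.Initialize() != migraphxSamples::SUCCESS) { return -1; }
//
//     std::vector<long unsigned int> input_id;
//     gpt2.Preprocessing(tokenizer, question, input_id);     // encodes "[CLS]" + question tokens
//
//     for(int step = 0; step < maxNewTokens; ++step)
//     {
//         long unsigned int next = gpt2.Inference(input_id); // highest-scoring next token
//         input_id.push_back(next);                          // feed it back for the next step
//     }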