#include #include #include #include #include #include using namespace std; using namespace migraphx; using namespace migraphxSamples; void Sample_Bert() { // 加载Bert模型 Bert bert; InitializationParameterOfNLP initParamOfNLPBert; initParamOfNLPBert.parentPath = ""; initParamOfNLPBert.configFilePath = CONFIG_FILE; initParamOfNLPBert.logName = ""; ErrorCode errorCode = bert.Initialize(initParamOfNLPBert); if (errorCode != SUCCESS) { LOG_ERROR(stdout, "fail to initialize Bert!\n"); exit(-1); } LOG_INFO(stdout, "succeed to initialize Bert\n"); int max_seq_length = 256; // 滑动窗口的长度 int max_query_length = 64; // 问题的最大长度 int batch_size = 1; // batch_size值 int n_best_size = 20; // 索引数量 int max_answer_length = 30; // 答案的最大长度 // 上下文文本数据 const char text[] = { u8"ROCm is the first open-source exascale-class platform for accelerated computing that’s also programming-language independent. It brings a philosophy of choice, minimalism and modular software development to GPU computing. You are free to choose or even develop tools and a language run time for your application. ROCm is built for scale, it supports multi-GPU computing and has a rich system run time with the critical features that large-scale application, compiler and language-run-time development requires. Since the ROCm ecosystem is comprised of open technologies: frameworks (Tensorflow / PyTorch), libraries (MIOpen / Blas / RCCL), programming model (HIP), inter-connect (OCD) and up streamed Linux® Kernel support – the platform is continually optimized for performance and extensibility." }; char question[100]; std::vector> input_ids; std::vector> input_masks; std::vector> segment_ids; std::vector start_position; std::vector end_position; std::string answer = {}; cuBERT::FullTokenizer tokenizer = cuBERT::FullTokenizer("../Resource/Models/NLP/Bert/uncased_L-12_H-768_A-12/vocab.txt"); // 分词工具 while (true) { // 数据前处理 std::cout << "question: "; cin.getline(question, 100); bert.Preprocessing(tokenizer, batch_size, max_seq_length, text, question, input_ids, input_masks, segment_ids); // 推理 double time1 = getTickCount(); bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position); double time2 = getTickCount(); double elapsedTime = (time2 - time1) * 1000 / getTickFrequency(); // 数据后处理 bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer); // 打印输出预测结果 std::cout << "answer: " << answer << std::endl; LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime); // 清除数据 input_ids.clear(); input_masks.clear(); segment_ids.clear(); start_position.clear(); end_position.clear(); answer = {}; } }