#ifndef __BERT_H__ #define __BERT_H__ #include #include #include #include namespace migraphxSamples { typedef enum _ErrorCode { SUCCESS=0, MODEL_NOT_EXIST, CONFIG_FILE_NOT_EXIST, FAIL_TO_LOAD_MODEL, FAIL_TO_OPEN_CONFIG_FILE, }ErrorCode; typedef struct _Sort_st { int index; float value; }Sort_st; typedef struct _ResultOfPredictions { int start_index; int end_index; float start_predictionvalue; float end_predictionvalue; }ResultOfPredictions; class Bert { public: Bert(); ~Bert(); ErrorCode Initialize(); ErrorCode Inference(const std::vector> &input_ids, const std::vector> &input_masks, const std::vector> &segment_ids, std::vector &start_position, std::vector &end_position); ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, int batch_size, int max_seq_length, const char *text, char *question, std::vector> &input_ids, std::vector> &input_masks, std::vector> &segment_ids); ErrorCode Postprocessing(int n_best_size, int max_answer_length, const std::vector &start_position, const std::vector &end_position, std::string &answer); private: std::vector tokens_text; std::vector tokens_question; migraphx::program net; std::string inputName1; std::string inputName2; std::string inputName3; std::string inputName4; migraphx::shape inputShape1; migraphx::shape inputShape2; migraphx::shape inputShape3; migraphx::shape inputShape4; }; } #endif