#ifndef __BERT_H__ #define __BERT_H__ #include #include #include #include namespace migraphxSamples { typedef enum _ErrorCode { SUCCESS = 0, MODEL_NOT_EXIST, CONFIG_FILE_NOT_EXIST, FAIL_TO_LOAD_MODEL, FAIL_TO_OPEN_CONFIG_FILE, } ErrorCode; typedef struct _Sort_st { int index; float value; } Sort_st; typedef struct _ResultOfPredictions { int start_index; int end_index; float start_predictionvalue; float end_predictionvalue; } ResultOfPredictions; class Bert { public: Bert(); ~Bert(); ErrorCode Initialize(); ErrorCode Inference(const std::vector>& input_ids, const std::vector>& input_masks, const std::vector>& segment_ids, std::vector& start_position, std::vector& end_position); ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, int batch_size, int max_seq_length, const char* text, char* question, std::vector>& input_ids, std::vector>& input_masks, std::vector>& segment_ids); ErrorCode Postprocessing(int n_best_size, int max_answer_length, const std::vector& start_position, const std::vector& end_position, std::string& answer); private: std::vector tokens_text; std::vector tokens_question; migraphx::program net; std::string inputName1; std::string inputName2; std::string inputName3; std::string inputName4; migraphx::shape inputShape1; migraphx::shape inputShape2; migraphx::shape inputShape3; migraphx::shape inputShape4; }; } // namespace migraphxSamples #endif