#ifndef __GPT2_H__ #define __GPT2_H__ #include #include #include #include namespace migraphxSamples { typedef enum _ErrorCode { SUCCESS=0, MODEL_NOT_EXIST, CONFIG_FILE_NOT_EXIST, FAIL_TO_LOAD_MODEL, FAIL_TO_OPEN_CONFIG_FILE, }ErrorCode; typedef struct _Predictions { long unsigned int index; float predictionvalue; }Predictions; class GPT2 { public: GPT2(); ~GPT2(); ErrorCode Initialize(); ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, char *question, std::vector &input_id); long unsigned int Inference(const std::vector &input_id); private: migraphx::program net; std::string inputName; migraphx::shape inputShape; }; } #endif