#ifndef GPT2_H #define GPT2_H #include #include #include #include #include namespace migraphxSamples { typedef struct _Predictions { long unsigned int index; float predictionvalue; }Predictions; class GPT2 { public: GPT2(); ~GPT2(); ErrorCode Initialize(); ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, char *question, std::vector &input_id); long unsigned int Inference(const std::vector &input_id); private: migraphx::program net; std::string inputName; migraphx::shape inputShape; }; } #endif