#ifndef __GPT2_H__ #define __GPT2_H__ #include #include #include #include namespace ortSamples { typedef enum _ErrorCode { SUCCESS=0, MODEL_NOT_EXIST, CONFIG_FILE_NOT_EXIST, FAIL_TO_LOAD_MODEL, FAIL_TO_OPEN_CONFIG_FILE, }ErrorCode; typedef struct _Predictions { long unsigned int index; float predictionvalue; }Predictions; class GPT2 { public: GPT2(); ~GPT2(); ErrorCode Initialize(); ErrorCode Preprocessing(cuBERT::FullTokenizer tokenizer, char *question, std::vector &input_id); long unsigned int Inference(const std::vector &input_id); private: std::vector input_node_names; std::vector output_node_names; Ort::Session *session; Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXRuntime"); Ort::SessionOptions sessionOptions = Ort::SessionOptions(); }; } #endif