#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <Bert.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <tokenization.h>

/// Interactive question-answering demo.
///
/// Initializes the Bert model, then repeatedly reads a question from standard
/// input, runs preprocessing / inference / postprocessing against a fixed
/// context paragraph, and prints the predicted answer. The loop ends when
/// standard input reaches end-of-file or becomes unreadable.
///
/// @param argc Unused.
/// @param argv Unused.
/// @return 0 on normal termination (end of input); the process exits with -1
///         if the model fails to initialize.
int main(int argc, char* argv[])
{
    // Load the Bert model.
    migraphxSamples::Bert bert;
    migraphxSamples::ErrorCode errorCode = bert.Initialize();
    if(errorCode != migraphxSamples::SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize Bert!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize Bert\n");

    int max_seq_length    = 256; // length of the sliding window
    int batch_size        = 1;   // batch size
    int n_best_size       = 20;  // number of candidate indices
    int max_answer_length = 30;  // maximum length of an answer

    // Context paragraph that the questions are answered against.
    const char text[] = {u8"ROCm is the first open-source exascale-class platform for accelerated "
                         u8"computing that’s also programming-language independent. It brings a "
                         u8"philosophy of choice, minimalism and modular software development to "
                         u8"GPU computing. You are free to choose or even develop tools and a "
                         u8"language run time for your application. ROCm is built for scale, it "
                         u8"supports multi-GPU computing and has a rich system run time with the "
                         u8"critical features that large-scale application, compiler and "
                         u8"language-run-time development requires. Since the ROCm ecosystem is "
                         u8"comprised of open technologies: frameworks (Tensorflow / PyTorch), "
                         u8"libraries (MIOpen / Blas / RCCL), programming model (HIP), "
                         u8"inter-connect (OCD) and up streamed Linux® Kernel support – the "
                         u8"platform is continually optimized for performance and extensibility."};
    char question[100];

    std::vector<std::vector<long unsigned int>> input_ids;
    std::vector<std::vector<long unsigned int>> input_masks;
    std::vector<std::vector<long unsigned int>> segment_ids;
    std::vector<float> start_position;
    std::vector<float> end_position;
    std::string answer;

    // Tokenizer built from the vocabulary file.
    cuBERT::FullTokenizer tokenizer =
        cuBERT::FullTokenizer("../Resource/uncased_L-12_H-768_A-12/vocab.txt");

    while(true)
    {
        // Read one question per line.
        std::cout << "question: ";
        if(!std::cin.getline(question, sizeof(question)))
        {
            // End of input or an unrecoverable stream error: stop instead of
            // spinning forever on a stream that can no longer deliver data.
            if(std::cin.eof() || std::cin.bad())
            {
                break;
            }

            // Otherwise the line did not fit into the buffer. getline() has
            // stored a null-terminated prefix; reset the stream and drop the
            // rest of the line so the next read starts on a fresh line.
            std::cin.clear();
            std::cin.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
            std::cerr << "question too long, truncated to " << (sizeof(question) - 1)
                      << " characters" << std::endl;
        }

        // Preprocessing.
        bert.Preprocessing(tokenizer,
                           batch_size,
                           max_seq_length,
                           text,
                           question,
                           input_ids,
                           input_masks,
                           segment_ids);

        // Inference.
        bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);

        // Postprocessing.
        bert.Postprocessing(n_best_size, max_answer_length, start_position, end_position, answer);

        // Print the predicted answer.
        std::cout << "answer: " << answer << std::endl;

        // Reset the per-question buffers before the next iteration.
        input_ids.clear();
        input_masks.clear();
        segment_ids.clear();
        start_position.clear();
        end_position.clear();
        answer.clear();
    }

    return 0;
}