//
// Created by huangyuyang on 6/9/23.
//

#include "fstream"
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "chatglm.h"
#include <unistd.h>
//static factoryllm fllm;
//static int modeltype = 0;
//static char* modelpath = NULL;
//static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
//static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
//static fastllm::basellm* vicuna = fllm.createllm(LLM_TYPE_VICUNA);

// Command-line options for the ChatGLM batch benchmark (filled in by ParseArgs).
struct BenchmarkConfig {
    std::string path{"chatglm-6b-fp16.bin"}; // Path of the model file to load.
    int limit{-1};                           // Output token limit; any value < 0 means unlimited.
    int batch{-1};                           // Batch size; -1 means "use the number of lines in the input file".
    std::string file;                        // Input file with one prompt per line.
    std::string output;                      // Output file for results; when empty, results go to the screen.
    int runloop{0};                          // Set to 1 by --loop to repeat the benchmark indefinitely.
};

// Prints the benchmark's command-line help to stdout.
// Every option that ParseArgs accepts is listed here, including --loop.
void Usage() {
    std::cout << "Usage:" << std::endl;
    std::cout << "[-h|--help]:                  显示帮助" << std::endl;
    std::cout << "<-p|--path> <args>:           模型文件的路径" << std::endl;
    std::cout << "<-l|--limit> <args>:          输出token数限制，不设定则表示无限制" << std::endl;
    std::cout << "<-b|--batch> <args>:          batch数，不设定时使用文件中的行数作为batch"      << std::endl;
    std::cout << "<-f|--file> <args>:           输入文件，文件中每行一个prompt，如果行数不足batch则用之前的prompt补充"      << std::endl;
    std::cout << "<-o|--output> <args>:         输出结果写文件，如果不设定则输出到屏幕"      << std::endl;
    // --loop was accepted by the parser but missing from the help text.
    std::cout << "[--loop]:                     循环运行测试，不断重复执行batch推理"      << std::endl;
}

// Parses argv into `config`, overwriting only the fields whose options appear.
//
// -h/--help prints usage and exits with status 0. The following print usage and
// exit with status -1:
//   - an unknown option;
//   - a value-taking option given as the last argument (i.e. without a value).
// Numeric values are converted with atoi, so non-numeric text becomes 0.
void ParseArgs(int argc, char **argv, BenchmarkConfig &config) {
    const std::vector <std::string> sargv(argv, argv + argc);

    // Returns the argument following position i and advances i past it.
    // Refuses to read beyond the end of argv, which previously was undefined behavior.
    auto takeValue = [&](int &i) -> const std::string & {
        if (i + 1 >= argc) {
            Usage();
            exit(-1);
        }
        return sargv[++i];
    };

    for (int i = 1; i < argc; i++) {
        const std::string &arg = sargv[i];
        if (arg == "-h" || arg == "--help") {
            Usage();
            exit(0);
        } else if (arg == "-p" || arg == "--path") {
            config.path = takeValue(i);
        } else if (arg == "-l" || arg == "--limit") {
            config.limit = atoi(takeValue(i).c_str());
        } else if (arg == "-b" || arg == "--batch") {
            config.batch = atoi(takeValue(i).c_str());
        } else if (arg == "-f" || arg == "--file") {
            config.file = takeValue(i);
        } else if (arg == "-o" || arg == "--output") {
            config.output = takeValue(i);
        } else if (arg == "--loop") {
            config.runloop = 1;
        } else {
            Usage();
            exit(-1);
        }
    }
}

// Elapsed time from `time1` to `time2` in seconds.
// The interval is first truncated (toward zero) to whole microseconds.
// The result is negative when `time2` precedes `time1`.
static double GetSpan(std::chrono::system_clock::time_point time1, std::chrono::system_clock::time_point time2) {
    using Micros = std::chrono::microseconds;
    const Micros elapsed = std::chrono::duration_cast<Micros>(time2 - time1);
    const double ticks = static_cast<double>(elapsed.count());
    return ticks * Micros::period::num / Micros::period::den;
}

int main(int argc, char **argv) {

    BenchmarkConfig config;
    ParseArgs(argc, argv, config);
    fastllm::ChatGLMModel chatGlm;
    chatGlm.LoadFromFile(config.path);
    chatGlm.WarmUp();
    chatGlm.output_token_limit = config.limit;

    std::vector <std::string> inputs;
    if (config.file != "") {
        std::ifstream finputs(config.file, std::ios::in);
        while (true) {
            std::string input = "";
            std::getline(finputs, input);
            if (input == "") {
                break;
            } else {
                inputs.push_back(input);
            }
        }
    } else {
        inputs.push_back("Hello！");
    }
    if (config.batch < 0) {
        config.batch = inputs.size();
    }
    while (inputs.size() < config.batch) {
        inputs.push_back(inputs[rand() % inputs.size()]);
    }
    if (inputs.size() > config.batch && config.batch != -1) {
        inputs.resize(config.batch);
    }

    if(config.runloop == 1){
        while(true){
            std::vector <std::string> outputs;
            static int tokens = 0;
            auto st = std::chrono::system_clock::now();
            chatGlm.ResponseBatch(inputs, outputs, [](int index, std::vector <std::string> &contents) {
                if (index != -1) {
                    for (int i = 0; i < contents.size(); i++) {
                        tokens += (contents[i].size() > 0);
                    }
                }
            });
            float spend = GetSpan(st, std::chrono::system_clock::now());

            if (config.output != "") {
                FILE *fo = fopen(config.output.c_str(), "w");
                for (int i = 0; i < outputs.size(); i++) {
                    fprintf(fo, "[ user: \"%s\", model: \"%s\"]\n", inputs[i].c_str(), outputs[i].c_str());
                }
                fclose(fo);
            }
            pid_t pid = getpid();
//            printf("batch: %d\n", (int)inputs.size());
            printf("pid %d : output %d tokens\nuse %f s\nspeed = %f tokens / s\n", pid, tokens, spend, tokens / spend);
        }
    }
    else{
        std::vector <std::string> outputs;
        static int tokens = 0;
        auto st = std::chrono::system_clock::now();
        chatGlm.ResponseBatch(inputs, outputs, [](int index, std::vector <std::string> &contents) {
            if (index != -1) {
                for (int i = 0; i < contents.size(); i++) {
                    tokens += (contents[i].size() > 0);
                }
            }
        });
        float spend = GetSpan(st, std::chrono::system_clock::now());

        if (config.output != "") {
            FILE *fo = fopen(config.output.c_str(), "w");
            for (int i = 0; i < outputs.size(); i++) {
                fprintf(fo, "[ user: \"%s\", model: \"%s\"]\n", inputs[i].c_str(), outputs[i].c_str());
            }
            fclose(fo);
        } else {
            for (int i = 0; i < outputs.size(); i++) {
                printf("[ user: \"%s\", model: \"%s\"]\n", inputs[i].c_str(), outputs[i].c_str());
            }
        }

        printf("batch: %d\n", (int)inputs.size());
        printf("output %d tokens\nuse %f s\nspeed = %f tokens / s\n", tokens, spend, tokens / spend);
    }
    return 0;
}