benchmark.cpp 4.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
//
// Created by huangyuyang on 6/9/23.
//

#include "fstream"
#include <chrono>
#include "chatglm.h"

//static factoryllm fllm;
//static int modeltype = 0;
//static char* modelpath = NULL;
//static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
//static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
//static fastllm::basellm* vicuna = fllm.createllm(LLM_TYPE_VICUNA);

// Command-line options for the benchmark; each default applies when its flag is omitted.
struct BenchmarkConfig {
    std::string path{"chatglm-6b-fp16.bin"}; // Path to the model file.
    int limit{-1};                           // Output token limit; any value < 0 means unlimited.
    int batch{-1};                           // Batch size; -1 means use the number of lines in the input file.
    std::string file;                        // Input file, one prompt per line.
    std::string output;                      // Output file; when unset, results are printed to the screen.
};

// Prints the command-line help text to stdout, one option per line.
void Usage() {
    static const char *const kHelpLines[] = {
        "Usage:",
        "[-h|--help]:                  显示帮助",
        "<-p|--path> <args>:           模型文件的路径",
        "<-l|--limit> <args>:          输出token数限制,不设定则表示无限制",
        "<-b|--batch> <args>:          batch数,不设定时使用文件中的行数作为batch",
        "<-f|--file> <args>:           输入文件,文件中每行一个prompt,如果行数不足batch则用之前的prompt补充",
        "<-o|--output> <args>:         输出结果写文件,如果不设定则输出到屏幕",
    };
    for (const char *line : kHelpLines) {
        std::cout << line << std::endl;
    }
}

// Parses command-line flags from argv into `config`.
// -h/--help prints usage and exits with status 0. An unknown flag, or a flag
// whose value is missing, prints usage and exits with status -1.
// Numeric values are parsed with atoi, so non-numeric text yields 0.
void ParseArgs(int argc, char **argv, BenchmarkConfig &config) {
    std::vector <std::string> sargv;
    for (int i = 0; i < argc; i++) {
        sargv.push_back(std::string(argv[i]));
    }
    // Returns the argument after the flag at index i and advances i onto it.
    // Guards against a flag given as the last argument (e.g. `benchmark -p`),
    // which would otherwise index past the end of sargv.
    auto nextValue = [&](int &i) -> const std::string & {
        if (i + 1 >= argc) {
            std::cerr << "Missing value for option " << sargv[i] << std::endl;
            Usage();
            exit(-1);
        }
        return sargv[++i];
    };
    for (int i = 1; i < argc; i++) {
        const std::string &arg = sargv[i];
        if (arg == "-h" || arg == "--help") {
            Usage();
            exit(0);
        } else if (arg == "-p" || arg == "--path") {
            config.path = nextValue(i);
        } else if (arg == "-l" || arg == "--limit") {
            config.limit = atoi(nextValue(i).c_str());
        } else if (arg == "-b" || arg == "--batch") {
            config.batch = atoi(nextValue(i).c_str());
        } else if (arg == "-f" || arg == "--file") {
            config.file = nextValue(i);
        } else if (arg == "-o" || arg == "--output") {
            config.output = nextValue(i);
        } else {
            Usage();
            exit(-1);
        }
    }
}

// Returns the elapsed time from time1 to time2 in seconds. The difference is
// first truncated to whole microseconds, then scaled to seconds.
static double GetSpan(std::chrono::system_clock::time_point time1, std::chrono::system_clock::time_point time2) {
    using Micro = std::chrono::microseconds;
    const Micro elapsed = std::chrono::duration_cast<Micro>(time2 - time1);
    const double ticks = static_cast<double>(elapsed.count());
    return ticks * Micro::period::num / Micro::period::den;
}

// Benchmark entry point. It parses options, builds the prompt batch and loads
// the ChatGLM model. It then runs one batched generation, writes each
// (prompt, response) pair to the output file or stdout, and reports token
// throughput. Returns -1 on an unreadable or empty input file, a zero batch
// size, or an unwritable output file.
int main(int argc, char **argv) {
    BenchmarkConfig config;
    ParseArgs(argc, argv, config);

    // Read prompts before loading the model so a bad input file is reported
    // without first paying for the (potentially slow) model load.
    std::vector <std::string> inputs;
    if (config.file != "") {
        std::ifstream finputs(config.file, std::ios::in);
        if (!finputs.is_open()) {
            fprintf(stderr, "Failed to open input file: %s\n", config.file.c_str());
            return -1;
        }
        std::string input;
        while (std::getline(finputs, input)) {
            // Skip blank lines instead of treating the first one as end-of-input,
            // so a blank line in the middle of the file does not drop later prompts.
            if (!input.empty()) {
                inputs.push_back(input);
            }
        }
        if (inputs.empty()) {
            // Without at least one prompt the padding loop below would compute rand() % 0.
            fprintf(stderr, "No prompts found in input file: %s\n", config.file.c_str());
            return -1;
        }
    } else {
        inputs.push_back("Hello!");
    }

    // A negative batch means "use exactly the prompts that were read".
    if (config.batch < 0) {
        config.batch = static_cast<int>(inputs.size());
    }
    if (config.batch == 0) {
        // Also catches a non-numeric -b value, which atoi turns into 0.
        fprintf(stderr, "Batch size must be positive\n");
        return -1;
    }
    const size_t batch = static_cast<size_t>(config.batch);
    // Pad up to the batch size with randomly chosen earlier prompts, then trim any excess.
    inputs.reserve(batch);
    while (inputs.size() < batch) {
        inputs.push_back(inputs[rand() % inputs.size()]);
    }
    inputs.resize(batch);

    // Open the output destination up front so an unwritable path fails before the run.
    FILE *fo = stdout;
    if (config.output != "") {
        fo = fopen(config.output.c_str(), "w");
        if (fo == nullptr) {
            fprintf(stderr, "Failed to open output file: %s\n", config.output.c_str());
            return -1;
        }
    }

    fastllm::ChatGLMModel chatGlm;
    chatGlm.LoadFromFile(config.path);
    chatGlm.WarmUp();
    chatGlm.output_token_limit = config.limit;

    std::vector <std::string> outputs;
    // `static` so the capture-less callback below can update it.
    static int tokens = 0;
    auto st = std::chrono::system_clock::now();
    chatGlm.ResponseBatch(inputs, outputs, [](int index, std::vector <std::string> &contents) {
        // Counts one token per non-empty entry in `contents`, skipping callbacks
        // with index == -1. NOTE(review): -1 presumably marks the final callback — confirm.
        if (index != -1) {
            for (size_t i = 0; i < contents.size(); i++) {
                tokens += (contents[i].size() > 0);
            }
        }
    });
    double spend = GetSpan(st, std::chrono::system_clock::now());

    for (size_t i = 0; i < outputs.size() && i < inputs.size(); i++) {
        fprintf(fo, "[ user: \"%s\", model: \"%s\"]\n", inputs[i].c_str(), outputs[i].c_str());
    }
    if (fo != stdout) {
        fclose(fo);
    }

    printf("batch: %d\n", (int)inputs.size());
    // Guard the throughput division against a zero elapsed time.
    printf("output %d tokens\nuse %f s\nspeed = %f tokens / s\n", tokens, spend, spend > 0 ? tokens / spend : 0.0);
    return 0;
}