//
// Created by huangyuyang on 6/9/23.
//

#include "fstream"
#include <chrono>
#include <cstdio>
#include <unistd.h>

#include "chatglm.h"
//static factoryllm fllm;
//static int modeltype = 0;
//static char* modelpath = NULL;
//static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
//static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
//static fastllm::basellm* vicuna = fllm.createllm(LLM_TYPE_VICUNA);

// Command-line configuration for the benchmark; every field has a default so
// the tool can run with no arguments.
struct BenchmarkConfig {
    std::string path = "chatglm-6b-fp16.bin"; // path to the model file
    int limit = -1; // limit on the number of output tokens; < 0 means unlimited
    int batch = -1; // batch size; -1 means use the number of lines in the input file
    std::string file; // input file with one prompt per line
    std::string output; // output file; results go to the screen when empty
    int runloop = 0; // set to 1 by --loop: repeat the benchmark until the process is killed
};

// Prints the command-line help text to stdout, one line per supported option.
void Usage() {
    std::cout << "Usage:" << std::endl;
    std::cout << "[-h|--help]:                  显示帮助" << std::endl;
    std::cout << "<-p|--path> <args>:           模型文件的路径" << std::endl;
    std::cout << "<-l|--limit> <args>:          输出token数限制,不设定则表示无限制" << std::endl;
    std::cout << "<-b|--batch> <args>:          batch数,不设定时使用文件中的行数作为batch"      << std::endl;
    std::cout << "<-f|--file> <args>:           输入文件,文件中每行一个prompt,如果行数不足batch则用之前的prompt补充"      << std::endl;
    std::cout << "<-o|--output> <args>:         输出结果写文件,如果不设定则输出到屏幕"      << std::endl;
    // --loop is accepted by the argument parser, so it is listed here as well.
    std::cout << "[--loop]:                     循环运行benchmark,每轮输出一次速度统计" << std::endl;
}

// Parses the command line into `config`.
//
// argc/argv: the arguments as received by main (argv[0] is skipped).
// config:    updated in place; fields whose option is absent keep their value.
//
// Prints the usage text and terminates the process on -h/--help (exit code 0),
// on an unknown option, or when an option that needs a value is the last
// argument (exit code -1).
void ParseArgs(int argc, char **argv, BenchmarkConfig &config) {
    std::vector <std::string> sargv(argv, argv + argc);

    // Returns the argument following option `i` and advances `i` onto it.
    // Without this check a trailing "-p" would index one past the end of sargv.
    auto nextValue = [&](int &i) -> const std::string & {
        if (i + 1 >= argc) {
            std::cerr << "Missing value for option " << sargv[i] << std::endl;
            Usage();
            exit(-1);
        }
        return sargv[++i];
    };

    for (int i = 1; i < argc; i++) {
        const std::string &arg = sargv[i];
        if (arg == "-h" || arg == "--help") {
            Usage();
            exit(0);
        } else if (arg == "-p" || arg == "--path") {
            config.path = nextValue(i);
        } else if (arg == "-l" || arg == "--limit") {
            config.limit = atoi(nextValue(i).c_str());
        } else if (arg == "-b" || arg == "--batch") {
            config.batch = atoi(nextValue(i).c_str());
        } else if (arg == "-f" || arg == "--file") {
            config.file = nextValue(i);
        } else if (arg == "-o" || arg == "--output") {
            config.output = nextValue(i);
        } else if (arg == "--loop") {
            // Flag without a value.
            config.runloop = 1;
        } else {
            Usage();
            exit(-1);
        }
    }
}

// Returns the time from `time1` to `time2` in seconds. The difference is first
// truncated to whole microseconds; the result is negative when time2 < time1.
static double GetSpan(std::chrono::system_clock::time_point time1, std::chrono::system_clock::time_point time2) {
    using Micro = std::chrono::microseconds;
    const Micro elapsed = std::chrono::duration_cast<Micro>(time2 - time1);
    const double ticks = static_cast<double>(elapsed.count());
    // Scale tick count by the period ratio (1 / 1'000'000) to get seconds.
    return ticks * Micro::period::num / Micro::period::den;
}

// Writes one `[ user: "...", model: "..."]` line per prompt/reply pair to `sink`.
// Stops at the shorter of the two vectors so a size mismatch cannot index out of range.
static void DumpResults(FILE *sink, const std::vector <std::string> &inputs, const std::vector <std::string> &outputs) {
    for (size_t i = 0; i < outputs.size() && i < inputs.size(); i++) {
        fprintf(sink, "[ user: \"%s\", model: \"%s\"]\n", inputs[i].c_str(), outputs[i].c_str());
    }
}

// Overwrites the file at `path` with the prompt/reply pairs.
// Returns false (after reporting on stderr) when the file cannot be opened.
static bool SaveResults(const std::string &path, const std::vector <std::string> &inputs, const std::vector <std::string> &outputs) {
    FILE *fo = fopen(path.c_str(), "w");
    if (fo == nullptr) {
        fprintf(stderr, "Failed to open output file \"%s\"\n", path.c_str());
        return false;
    }
    DumpResults(fo, inputs, outputs);
    fclose(fo);
    return true;
}

// Benchmark entry point: loads the model, builds a batch of prompts, runs
// ResponseBatch over it and reports the token count and throughput.
//
// Without --loop the batch runs once, the replies are written to the output
// file (or the screen) and a summary is printed. With --loop the batch is
// repeated indefinitely and one statistics line, tagged with the process id,
// is printed per round.
//
// Returns 0 on success, -1 when the prompt file or the output file is unusable.
int main(int argc, char **argv) {
    BenchmarkConfig config;
    ParseArgs(argc, argv, config);
    fastllm::ChatGLMModel chatGlm;
    chatGlm.LoadFromFile(config.path);
    chatGlm.WarmUp();
    chatGlm.output_token_limit = config.limit;

    // Collect prompts: one per line, stopping at the first empty line or EOF.
    std::vector <std::string> inputs;
    if (config.file != "") {
        std::ifstream finputs(config.file, std::ios::in);
        if (!finputs.is_open()) {
            fprintf(stderr, "Failed to open input file \"%s\"\n", config.file.c_str());
            return -1;
        }
        std::string input;
        while (std::getline(finputs, input) && !input.empty()) {
            inputs.push_back(input);
        }
    } else {
        inputs.push_back("Hello!");
    }
    if (inputs.empty()) {
        // Padding below picks from existing prompts, so it needs at least one
        // (rand() % 0 would otherwise divide by zero).
        fprintf(stderr, "Input file \"%s\" contains no prompts\n", config.file.c_str());
        return -1;
    }

    // Fit the prompt list to the requested batch size.
    if (config.batch < 0) {
        config.batch = (int) inputs.size();
    }
    const size_t batch = (size_t) config.batch;
    while (inputs.size() < batch) {
        // Copy before push_back so the source element cannot be affected by reallocation.
        const std::string pick = inputs[rand() % inputs.size()];
        inputs.push_back(pick);
    }
    if (inputs.size() > batch) {
        inputs.resize(batch);
    }

    // Static storage lets the capture-less callback below update the counter.
    static int tokens = 0;
    const bool loopForever = (config.runloop == 1);
    do {
        tokens = 0;
        std::vector <std::string> outputs;
        auto st = std::chrono::system_clock::now();
        chatGlm.ResponseBatch(inputs, outputs, [](int index, std::vector <std::string> &contents) {
            if (index != -1) {
                // Count one token for every non-empty piece reported in this callback.
                for (const std::string &piece : contents) {
                    tokens += piece.empty() ? 0 : 1;
                }
            }
        });
        float spend = GetSpan(st, std::chrono::system_clock::now());

        if (loopForever) {
            // Loop mode: replies are only persisted when an output file is set.
            if (config.output != "" && !SaveResults(config.output, inputs, outputs)) {
                return -1;
            }
            pid_t pid = getpid();
            printf("pid %d : output %d tokens, use %f s, speed = %f tokens/s\n", (int) pid, tokens, spend, tokens / spend);
        } else {
            if (config.output != "") {
                if (!SaveResults(config.output, inputs, outputs)) {
                    return -1;
                }
            } else {
                DumpResults(stdout, inputs, outputs);
            }
            printf("batch: %d\n", (int)inputs.size());
            printf("output %d tokens\nuse %f s\nspeed = %f tokens / s\n", tokens, spend, tokens / spend);
        }
    } while (loopForever);

    return 0;
}