# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import csv
import json
import os
import random
import time
from queue import Queue
from threading import Thread
from typing import List, Tuple, Union

import numpy as np
from tqdm import tqdm

from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
from lmdeploy.messages import (EngineGenerationConfig, PytorchEngineConfig,
                               TurbomindEngineConfig)
from lmdeploy.pytorch.engine.engine import EngineInstance
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
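
# Example invocation (illustrative only; the dataset and model paths are
# placeholders):
#   python profile_throughput.py /path/to/dataset.json /path/to/model \
#       --backend turbomind --concurrency 64 --num-prompts 1000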


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: Tokenizer,
) -> List[Tuple[str, int, int]]:
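    """Sample `num_requests` (prompt, prompt_len, output_len) tuples.

    The dataset is expected to be a JSON list of ShareGPT-style records, each
    carrying a 'conversations' list of {'value': ...} turns; only the first
    two turns (prompt and reference completion) of each record are used.
    """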
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data['conversations']) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data['conversations'][0]['value'],
                data['conversations'][1]['value']) for data in dataset]

    # pre-sample to avoid going through the whole dataset
    dataset = random.sample(dataset, max(int(num_requests * 1.2), 1000))

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


class Engine:
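    """Wrap a TurboMind or PyTorch engine for throughput benchmarking."""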

    def __init__(self, model_path: str,
                 engine_config: Union[PytorchEngineConfig,
                                      TurbomindEngineConfig], csv: str):
        if isinstance(engine_config, TurbomindEngineConfig):
            from lmdeploy.turbomind import TurboMind
            tm_model = TurboMind.from_pretrained(model_path,
                                                 engine_config=engine_config)
        elif isinstance(engine_config, PytorchEngineConfig):
            from lmdeploy.pytorch.engine import Engine as PytorchEngine
            tm_model = PytorchEngine(model_path, engine_config=engine_config)

        self.tm_model = tm_model
        self.tokenizer = tm_model.tokenizer
        self.csv = csv
        self.pbar = None

    def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                   temperature: float, top_p: float, top_k: int,
                   stream_output: bool):
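        """Worker loop for one benchmark thread.

        Pulls requests from `req_queue` until the [None, None, None] sentinel
        is seen, runs streamed generation for each request, and finally puts
        (session_id, stats, per_token_latency_stats) on `res_queue`.
        """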
        model_inst = self.tm_model.create_instance()
        stats = []
        # get each generated token's latency
        per_token_latency_stats = []
        for prompt, input_seqlen, output_seqlen in iter(
                req_queue.get, [None, None, None]):
            # slot 0 holds the first-token latency; slots 1..output_seqlen
            # hold the latency of each subsequently decoded token
            _per_token_latency_stats = [0] * (output_seqlen + 1)
            state = DetokenizeState()
            prev = time.perf_counter()
            n_prev_token = 0

            input_ids = self.tokenizer(prompt).input_ids
            for outputs in model_inst.stream_infer(
                    session_id,
                    input_ids=input_ids,
                    gen_config=EngineGenerationConfig(
                        max_new_tokens=output_seqlen,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        ignore_eos=True),  # always decode max_new_tokens tokens
                    sequence_start=True,
                    sequence_end=True,
                    stream_output=stream_output):
                _, res, n_token = outputs
                _, state = self.tokenizer.detokenize_incrementally(res, state)
                now = time.perf_counter()
                if n_prev_token != n_token:
                    _per_token_latency_stats[n_prev_token] = np.round(
                        now - prev, 3)
                    n_prev_token = n_token
                prev = now
            # for pytorch engine to restart a session
            if isinstance(model_inst, EngineInstance):
                model_inst.end(session_id)
            assert output_seqlen <= n_token <= output_seqlen + 1, \
                f'Error. session_id({session_id}) requested {output_seqlen} ' \
                f'tokens, but generated {n_token} tokens.\n' \
                f'prompt: {prompt}'

            first_token_latency = _per_token_latency_stats[0]
            completion_tokens = n_token
            total_tokens = n_token + input_seqlen
            stats.append([
                first_token_latency, completion_tokens, output_seqlen,
                total_tokens
            ])
            # skip the first token latency
            per_token_latency_stats.append(_per_token_latency_stats[1:])
            self.pbar.update(1)
        res_queue.put((session_id, stats, per_token_latency_stats))

    def process_request(self, requests, concurrency, temperature, top_p, top_k,
                        stream_output):
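        """Dispatch `requests` to `concurrency` worker threads and report
        latency/throughput statistics, optionally writing them to a CSV
        file."""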
        res_queue = Queue()
        req_queue = Queue()
        threads = []

        self.pbar = tqdm(total=len(requests))

        # feed requests to the queue
        for req in requests:
            req_queue.put(req)
        for i in range(concurrency):
            req_queue.put([None, None, None])

        start = time.time()

        # start threads
        for i in range(concurrency):
            t = Thread(target=self._inference,
                       args=(req_queue, res_queue, i, temperature, top_p,
                             top_k, stream_output),
                       daemon=True)
            t.start()
            threads.append(t)

        # wait for finish
        for t in threads:
            t.join()

        elapsed_time = time.time() - start

        stats = []
        per_token_latency_stats = []
        while not res_queue.empty():
            session_id, _stats, _per_token_latency_stats = res_queue.get()
            stats.append(np.array(_stats))
            per_token_latency_stats += [
                item for sublist in _per_token_latency_stats
                for item in sublist
            ]
        stats = np.concatenate(stats).reshape(-1, 4)
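        # stats columns: first-token latency, completion tokens,
        # requested output length, total tokens (prompt + completion)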

        first_token_latency_min = np.min(stats[:, 0], axis=0)
        first_token_latency_max = np.max(stats[:, 0], axis=0)
        first_token_latency_ave = np.mean(stats[:, 0], axis=0)
        completion_tokens = np.sum(stats[:, 1], axis=0)
        total_tokens = np.sum(stats[:, 3], axis=0)
        prompt_tokens = total_tokens - completion_tokens
        completion_token_throughput = completion_tokens / elapsed_time
        total_token_throughput = total_tokens / elapsed_time
        rps = len(requests) / elapsed_time
        rpm = rps * 60

        # nearest-rank percentiles of the sorted per-token (decode) latencies
        per_token_latency_stats.sort()
        percentiles = [
            np.round(
                per_token_latency_stats[int(percent *
                                            len(per_token_latency_stats))], 3)
            for percent in [0.5, 0.75, 0.95, 0.99]
        ]

        print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
              f'elapsed_time: {elapsed_time:.3f}s\n')
        if stream_output:
            print(f'first token latency(s)(min, max, ave): '
                  f'{first_token_latency_min:.3f}, '
                  f'{first_token_latency_max:.3f}, '
                  f'{first_token_latency_ave:.3f}')
            print(f'per-token latency(s) percentile(50, 75, 95, 99): '
                  f'{percentiles}\n')
        print(
            f'number of prompt tokens: {prompt_tokens:.0f}\n'
            f'number of completion tokens: {completion_tokens:.0f}\n'
            f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n'  # noqa
            f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n'  # noqa
            f'RPS (requests per second): {rps:.3f} req/s\n'
            f'RPM (requests per minute): {rpm:.3f} req/min\n'
            f'{"-" * 50}\n')

        if self.csv:
            with open(self.csv, 'w') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([
                    'batch', 'num_prompts', 'RPS', 'RPM', 'FTL(ave)(s)',
                    'FTL(min)(s)', 'FTL(max)(s)', '50%(s)', '75%(s)', '95%(s)',
                    '99%(s)', 'throughput(out tok/s)',
                    'throughput(total tok/s)'
                ])
                writer.writerow([
                    concurrency,
                    len(requests), f'{rps:.3f}', f'{rpm:.3f}',
                    f'{first_token_latency_ave:.3f}' if stream_output else '-',
                    f'{first_token_latency_min:.3f}' if stream_output else '-',
                    f'{first_token_latency_max:.3f}' if stream_output else '-',
                    f'{percentiles[0]:.3f}' if stream_output else '-',
                    f'{percentiles[1]:.3f}' if stream_output else '-',
                    f'{percentiles[2]:.3f}' if stream_output else '-',
                    f'{percentiles[3]:.3f}' if stream_output else '-',
                    f'{completion_token_throughput:.3f}',
                    f'{total_token_throughput:.3f}'
                ])


def parse_args():
    parser = argparse.ArgumentParser(
        description='Benchmark the request throughput of lmdeploy '
        'on localhost',
        formatter_class=DefaultsAndTypesHelpFormatter)
    parser.add_argument('dataset', type=str, help='the path of the dataset')
    parser.add_argument('model_path',
                        type=str,
                        help='the path of the model on localhost or '
                        'the repo_id of the model on huggingface.co')
    parser.add_argument(
        '-c',
        '--concurrency',
        type=int,
        help='Number of working threads to process the sampled prompts',
        default=256)
    parser.add_argument('-n',
                        '--num-prompts',
                        type=int,
                        help='Number of prompts to process',
                        default=5000)
    parser.add_argument('--csv',
                        type=str,
                        help='Where to save the result.',
                        default='./profile_throughput.csv')
    parser.add_argument('--seed',
                        type=int,
                        default=0,
                        help='Seed used in sampling prompts from dataset')
    # other args
    ArgumentHelper.top_p(parser)
    ArgumentHelper.temperature(parser)
    ArgumentHelper.top_k(parser)
    ArgumentHelper.log_level(parser)
    ArgumentHelper.backend(parser)

    # pytorch engine args
    pt_group = parser.add_argument_group('PyTorch engine arguments')
    tp_act = ArgumentHelper.tp(pt_group)
    session_len_act = ArgumentHelper.session_len(pt_group, default=4096)
    cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)

    # turbomind engine args
    tb_group = parser.add_argument_group('TurboMind engine argument')
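    # tp, session_len and cache_max_entry_count are shared with the PyTorch
    # engine group above, so the same argparse actions are reused here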
    tb_group._group_actions.append(tp_act)
    tb_group._group_actions.append(session_len_act)
    tb_group._group_actions.append(cache_count_act)
    ArgumentHelper.model_format(tb_group, default='hf')

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    random.seed(args.seed)
    os.environ['TM_LOG_LEVEL'] = args.log_level
    if args.backend == 'turbomind':
        engine_config = TurbomindEngineConfig(
            session_len=args.session_len,
            max_batch_size=args.concurrency,
            tp=args.tp,
            cache_max_entry_count=args.cache_max_entry_count,
            model_format=args.model_format)
    elif args.backend == 'pytorch':
        engine_config = PytorchEngineConfig(
            session_len=args.session_len,
            cache_max_entry_count=args.cache_max_entry_count,
            max_batch_size=args.concurrency,
            tp=args.tp,
            thread_safe=True)  # the engine is driven from multiple threads

    engine = Engine(args.model_path, engine_config, csv=args.csv)

    requests = sample_requests(args.dataset, args.num_prompts,
                               engine.tokenizer)

    engine.process_request(requests,
                           temperature=args.temperature,
                           top_p=args.top_p,
                           top_k=args.top_k,
                           concurrency=args.concurrency,
                           stream_output=True)


if __name__ == '__main__':
    main()