profile_generation.py 16.2 KB
Newer Older
1
2
3
4
5
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import csv
import logging
import os
lvhan028's avatar
lvhan028 committed
6
import time
7
from dataclasses import dataclass
8
9
from queue import Queue
from threading import Thread
10
from typing import List
lvhan028's avatar
lvhan028 committed
11
12

import numpy as np
13
14
15
16
17
from pynvml import (NVMLError, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
                    nvmlDeviceGetMemoryInfo, nvmlDeviceGetName,
                    nvmlDeviceGetPowerState, nvmlDeviceGetTemperature,
                    nvmlInit, nvmlShutdown, nvmlSystemGetDriverVersion)
from tqdm import tqdm
lvhan028's avatar
lvhan028 committed
18

19
from lmdeploy.turbomind import TurboMind
lvhan028's avatar
lvhan028 committed
20
21


22
def infer(model, session_id: int, input_ids: List, output_seqlen: int,
          top_k: int, top_p: float, temperature: float, test_round: int,
          que: Queue):
    """Run ``test_round`` streaming generations on one session.

    Each round sends ``input_ids`` with ``ignore_eos=True``, requests
    ``output_seqlen`` tokens and times every step of the stream. When all
    rounds finish, ``(session_id, stats)`` is put on ``que``. ``stats`` holds
    one list per round of ``output_seqlen`` latencies (seconds, rounded to
    3 decimals). Session 1 additionally drives a tqdm progress bar.

    Raises:
        AssertionError: if a round ends with fewer than ``output_seqlen`` or
            more than ``output_seqlen + 1`` generated tokens. This includes
            a round whose stream yields nothing. ``que`` receives nothing in
            that case.
    """
    # Only one thread drives the progress bar so concurrent sessions do not
    # draw over each other.
    pbar = tqdm(total=test_round) if session_id == 1 else None
    chatbot = model.create_instance()
    stats = []
    try:
        for _ in range(test_round):
            token_latency_stats = [0] * (output_seqlen + 1)
            prev = time.perf_counter()
            n_prev_token = 0
            # Reset every round. An empty stream then fails the length check
            # below, instead of hitting an unbound name or reusing the
            # previous round's count.
            n_token = 0
            # The iterator provided by `stream_infer` reports the number of
            # tokens generated so far (`n_token`). That number is not
            # continuous: it may go 5, 7, 8, 16, ... rather than 1, 2, 3, 4,
            # so the latency of every individual token cannot be measured.
            # As a work-around, the latency `now - prev` of an iteration is
            # assigned to the first of the newly generated tokens, and the
            # rest stay 0. For example, if the first iteration yields 5
            # tokens, its elapsed time goes to `token_latency_stats[0]` and
            # `token_latency_stats[1:5]` stay 0.
            for outputs in chatbot.stream_infer(
                    session_id,
                    input_ids,
                    request_output_len=output_seqlen,
                    sequence_start=True,
                    sequence_end=True,
                    ignore_eos=True,
                    stream_output=True,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature):
                _, n_token = outputs[0]
                now = time.perf_counter()
                if n_prev_token != n_token:
                    token_latency_stats[n_prev_token] = np.round(now - prev, 3)
                    n_prev_token = n_token
                prev = now

            if pbar is not None:
                pbar.update(1)

            assert output_seqlen <= n_token <= output_seqlen + 1, \
                f'Error. session_id({session_id}) request {output_seqlen} ' \
                f'tokens, but generate {n_token} tokens'
            stats.append(token_latency_stats[:output_seqlen])
    finally:
        if pbar is not None:
            pbar.close()

    que.put((session_id, stats))


71
72
73
74
75
def warmup(model, concurrency: int, input_ids: List[int], output_seqlen: int,
           warmup_round: int):
    """Exercise the engine before profiling and discard the outputs.

    Starts ``concurrency`` threads. Thread ``k`` (1-based) creates its own
    instance and runs ``warmup_round`` generations on session ``k`` with
    fixed sampling settings (top_k=1, top_p=1.0, temperature=1.0,
    ignore_eos=True). The function returns immediately when ``warmup_round``
    is falsy.
    """
    if not warmup_round:
        return

    print('start to warmup ...')

    def _drain(engine, sid):
        # Exhaust each stream; only the work matters, not the tokens.
        instance = engine.create_instance()
        for _ in range(warmup_round):
            stream = instance.stream_infer(sid,
                                           input_ids=input_ids,
                                           request_output_len=output_seqlen,
                                           sequence_start=True,
                                           sequence_end=True,
                                           ignore_eos=True,
                                           top_k=1,
                                           top_p=1.0,
                                           temperature=1.0)
            for _ in stream:
                pass

    begin = time.perf_counter()
    workers = []
    for sid in range(1, concurrency + 1):
        worker = Thread(target=_drain, args=(model, sid))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
    elapsed = time.perf_counter() - begin
    print(f'end warmup, elapsed time: {round(elapsed, 2)}s')


106
107
108
def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
                       output_seqlen: int, tp: int, top_k: int, top_p: float,
                       temperature: float, test_round: int, warmup_round: int,
                       **kwargs):
    """Profile first-token latency, token latency and throughput.

    Loads ``model_path`` with TurboMind and runs ``warmup``. It then starts
    ``concurrency`` threads, each running ``infer`` for ``test_round``
    rounds. Every round uses the same random prompt of ``input_seqlen``
    token ids and requests ``output_seqlen`` tokens.

    Args:
        model_path: forwarded to ``TurboMind``.
        concurrency: number of concurrent sessions (threads).
        input_seqlen: prompt length in tokens, must be > 0.
        output_seqlen: tokens generated per request, must be > 0.
        tp: tensor-parallel degree forwarded to ``TurboMind``.
        top_k, top_p, temperature: sampling parameters for ``infer``.
        test_round: measured rounds per session.
        warmup_round: warm-up rounds per session.
        **kwargs: extra keyword arguments forwarded to ``TurboMind``.

    Returns:
        tuple: ``(model_name, [first_token_min, first_token_max,
        first_token_ave], percentiles, throughput, tm_model.gpu_count)``.
        ``percentiles`` are the 50/75/95/99th percentile latencies (seconds)
        of all tokens except each request's first one. They are NaN when
        ``output_seqlen == 1``. ``throughput`` is generated tokens per
        second of wall-clock time.

    Raises:
        AssertionError: on non-positive ``input_seqlen``/``output_seqlen``.
        RuntimeError: if no latency statistics were collected.
    """
    # Validate the request shape before paying for model loading.
    assert input_seqlen > 0, 'input_seqlen should > 0'
    assert output_seqlen > 0, 'output_seqlen should > 0'

    print(f'profiling ... concurrency: {concurrency}, '
          f'n_prompt_token: {input_seqlen}, '
          f'n_completion_token: {output_seqlen}, '
          f'test_round: {test_round}, warmup_round: {warmup_round}')

    # avoid turbomind checking chat template name by setting `model_name='llama'` # noqa
    tm_model = TurboMind(model_path=model_path,
                         tp=tp,
                         model_name='llama',
                         **kwargs)

    # make up a dummy `input_ids` with the length of `input_seqlen` exactly
    input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist()
    warmup(tm_model, concurrency, input_ids, output_seqlen, warmup_round)

    que = Queue()
    procs = []
    _start = time.perf_counter()
    for i in range(concurrency):
        proc = Thread(target=infer,
                      args=(tm_model, i + 1, input_ids, output_seqlen, top_k,
                            top_p, temperature, test_round, que))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    elapsed_time = time.perf_counter() - _start

    # All threads have been joined, so draining until empty is safe.
    token_latency_stats = []
    while not que.empty():
        _, _stats = que.get()
        token_latency_stats += _stats
    if not token_latency_stats:
        raise RuntimeError('no latency statistics collected (concurrency or '
                           'test_round is 0, or every inference thread '
                           'failed)')

    # The shape is [concurrency*test_round, output_seqlen]
    token_latency_stats = np.stack(token_latency_stats, axis=0)

    first_token_latency_min = np.round(
        np.min(token_latency_stats[:, 0], axis=0), 3)
    first_token_latency_max = np.round(
        np.max(token_latency_stats[:, 0], axis=0), 3)
    first_token_latency_ave = np.round(
        np.mean(token_latency_stats[:, 0], axis=0), 3)
    # Per-request end-to-end latency: sum over each row.
    total_latency = np.sum(token_latency_stats, axis=1)
    token_latency_max = np.round(np.max(total_latency), 3)
    token_latency_min = np.round(np.min(total_latency), 3)
    token_latency_ave = np.round(np.mean(total_latency), 3)

    # sort token_latency without the first token's latency
    percents = [0.5, 0.75, 0.95, 0.99]
    sorted_token_latency = np.sort(token_latency_stats[:, 1:].flatten())
    if sorted_token_latency.size:
        # int(p * n) < n because p < 1, so the index is always in range.
        percentiles = [
            np.round(
                sorted_token_latency[int(percent *
                                         len(sorted_token_latency))], 3)
            for percent in percents
        ]
    else:
        # output_seqlen == 1: only first tokens exist, percentiles undefined.
        percentiles = [float('nan')] * len(percents)

    throughput = np.round(token_latency_stats.size / elapsed_time, 2)
    print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n'
          f'concurrency: {concurrency}, test_round: {test_round}\n'
          f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n'
          f'first_token latency(min, max, ave): '
          f'{first_token_latency_min}s, {first_token_latency_max}s, '
          f'{first_token_latency_ave}s\ntotal_token latency(min, max, ave): '
          f'{token_latency_min}s, {token_latency_max}s, '
          f'{token_latency_ave}s\n'
          f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n'
          f'throughput: {throughput} token/s\n{"-" * 50}')
    return tm_model.model_name, \
        [first_token_latency_min, first_token_latency_max,
         first_token_latency_ave], \
        percentiles, throughput, tm_model.gpu_count
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266


class MemoryMonitor:
    """Measure peak GPU memory growth from a background process.

    ``start`` launches ``mem_monitor`` in a separate process. That process
    polls NVML in an endless loop and publishes results through
    manager-backed shared values:

    * ``device_count``: number of GPUs seen in the first sample.
    * ``max_mem``: the largest increase, in GB, of used memory summed over
      all GPUs, relative to that first sample.

    ``terminate`` stops the process and returns ``max_mem``.
    """
    from multiprocessing import Manager

    # A single manager server process backs both shared values.
    _manager = Manager()
    max_mem = _manager.Value('f', 0)  # GB
    device_count = _manager.Value('f', 0)

    @staticmethod
    def nvidia_info():
        """Query NVML for the driver version and per-GPU memory/status.

        Requires ``nvidia-ml-py`` (``pip install nvidia-ml-py``).

        Returns:
            dict: keys ``state``, ``nvidia_version``, ``nvidia_count`` and
            ``gpus``.

            * ``state`` is False if any query failed. ``gpus`` then holds
              only the devices queried before the failure.
            * Each ``gpus`` entry has ``gpu_name``,
              ``total``/``free``/``used`` memory (bytes, as reported by
              NVML), ``temperature`` and ``powerStatus``.
        """
        nvidia_dict = {
            'state': True,
            'nvidia_version': '',
            'nvidia_count': 0,
            'gpus': []
        }
        try:
            nvmlInit()
            nvidia_dict['nvidia_version'] = nvmlSystemGetDriverVersion()
            nvidia_dict['nvidia_count'] = nvmlDeviceGetCount()
            for i in range(nvidia_dict['nvidia_count']):
                handle = nvmlDeviceGetHandleByIndex(i)
                memory_info = nvmlDeviceGetMemoryInfo(handle)
                gpu = {
                    'gpu_name': nvmlDeviceGetName(handle),
                    'total': memory_info.total,
                    'free': memory_info.free,
                    'used': memory_info.used,
                    'temperature': f'{nvmlDeviceGetTemperature(handle, 0)}℃',
                    'powerStatus': nvmlDeviceGetPowerState(handle)
                }
                nvidia_dict['gpus'].append(gpu)
        except Exception:  # NVMLError included: report via `state`
            nvidia_dict['state'] = False
        finally:
            # Best effort, e.g. shutdown can fail when nvmlInit did not
            # succeed. KeyboardInterrupt/SystemExit are not swallowed.
            try:
                nvmlShutdown()
            except Exception:  # noqa
                pass
        return nvidia_dict

    @classmethod
    def mem_monitor(cls):
        """Poll GPU memory forever and publish the peak growth.

        Runs in the child process created by ``start``. It never returns and
        is stopped by ``terminate``. Samples are taken back to back with no
        sleep in between.
        """
        info = cls.nvidia_info()
        max_mem = 0
        mem_start = 0
        cls.device_count.value = len(info['gpus'])
        for used_total in info['gpus']:
            mem_start += used_total['used']
        while True:
            info = cls.nvidia_info()
            used = 0
            for used_total in info['gpus']:
                used += used_total['used']
            if used > max_mem:
                max_mem = used
                cls.max_mem.value = (max_mem - mem_start) / (1 << 30)

    @classmethod
    def start(cls):
        """Start the monitoring subprocess.

        It is daemonic, so a process that is never terminated cannot keep
        the interpreter from exiting.
        """
        cls._running = True
        from multiprocessing import Process
        cls.proc = Process(target=cls.mem_monitor, daemon=True)
        cls.proc.start()

    @classmethod
    def terminate(cls) -> float:
        """Terminate the subprocess and return maximum memory (GB)."""
        cls.proc.kill()
        # Reap the killed child so start/terminate cycles leave no zombies.
        cls.proc.join()
        return cls.max_mem.value


@dataclass
class ProfileResult:
    """Profiling outcome for one (batch, prompt, completion) configuration.

    Field order is significant: it defines the positional order of the
    generated ``__init__``.
    """
    # --- workload description ---
    model_name: str
    batch: int
    prompt_tokens: int
    completion_tokens: int
    # --- latency (seconds) ---
    # [min, max, ave] latency of the first generated token
    first_token_latency: List
    # 50/75/95/99th percentile latency of the remaining tokens
    percentiles: List
    # --- throughput (token/s) ---
    throughput_per_proc: float
    throughput_per_node: float
    # --- memory (GB) ---
    mem_per_proc: float
    mem_per_gpu: float
    mem_per_node: float


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    ``--prompt-tokens`` and ``--completion-tokens`` are meant to be paired
    element-wise, so they should have the same length.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Regression Test')
    parser.add_argument('model_path',
                        type=str,
                        help='the path of the model in localhost or '
                        'the repo_id of the model in huggingface.co')
    parser.add_argument('--concurrency',
                        nargs='+',
                        type=int,
                        help='how many requests launched concurrently',
                        default=[1, 16, 32, 64])
    parser.add_argument(
        '--prompt-tokens',
        nargs='+',
        type=int,
        help='how many tokens in the prompt. One-to-one '
        'correspondence with completion-tokens',
        default=[1, 128, 128, 2048, 2048])
    parser.add_argument('--completion-tokens',
                        nargs='+',
                        type=int,
                        help='how many tokens to be generated. One-to-one '
                        'correspondence with prompt-tokens',
                        default=[128, 128, 2048, 128, 2048])
    parser.add_argument('--tp', type=int, help='Tensor parallel', default=1)
    parser.add_argument('--top_k',
                        type=int,
                        help='The number of highest probability vocabulary '
                        'tokens to keep for top-k-filtering',
                        default=1)
    parser.add_argument('--top_p',
                        type=float,
                        help='the set of most probable tokens with '
                        'probabilities that add up to top_p or higher '
                        'are kept for generation',
                        default=1.0)
    parser.add_argument('--temperature',
                        type=float,
                        help='The value used to modulate the next token '
                        'probabilities',
                        default=1.0)
    parser.add_argument('--csv',
                        type=str,
                        help='Where to save the result.',
                        default='profile_generation.csv')
    parser.add_argument('--log-level',
                        help='set log level',
                        default='ERROR',
                        choices=list(logging._nameToLevel.keys()))
    parser.add_argument('--test-round',
                        type=int,
                        help='number of test rounds',
                        default=6)
    parser.add_argument('--warmup-round',
                        type=int,
                        help='number of warmup rounds',
                        default=1)
    args = parser.parse_args()
    return args


def main():
    """Profile every configured workload and optionally write a CSV report.

    Each ``--concurrency`` value is combined with each paired
    (``--prompt-tokens``, ``--completion-tokens``) entry. Every combination
    is profiled by ``profile_throughput`` in a single-worker process pool,
    while ``MemoryMonitor`` records peak GPU memory growth in another
    process.
    """
    from functools import partial
    from multiprocessing import Pool

    args = parse_args()
    assert len(args.prompt_tokens) == len(args.completion_tokens), \
        f'mismatched size between `prompt-tokens` and `completion-tokens`' \
        f', {len(args.prompt_tokens)} vs {len(args.completion_tokens)}'

    os.environ['TM_LOG_LEVEL'] = args.log_level
    results: List[ProfileResult] = []
    for batch in args.concurrency:
        for prompt_tokens, completion_tokens in zip(args.prompt_tokens,
                                                    args.completion_tokens):
            profile_target = partial(profile_throughput,
                                     concurrency=batch,
                                     input_seqlen=prompt_tokens,
                                     output_seqlen=completion_tokens,
                                     tp=args.tp,
                                     top_k=args.top_k,
                                     top_p=args.top_p,
                                     temperature=args.temperature,
                                     test_round=args.test_round,
                                     warmup_round=args.warmup_round)
            MemoryMonitor.start()
            try:
                # Profile in a child process. Leaving the `with` block
                # terminates the worker, so the GPU memory it held is
                # released before the next configuration takes its baseline.
                with Pool(1) as pool:
                    output = pool.map(profile_target, (args.model_path, ))
                time.sleep(5)  # wait a while for releasing GPU mem
            finally:
                # The monitor polls in an endless loop; always stop it, even
                # when profiling raised.
                memory = MemoryMonitor.terminate()
            model_name, first_token_latency, percentiles, \
                throughput_per_proc, gpu_count = output[0]
            device_count = MemoryMonitor.device_count.value
            results.append(
                ProfileResult(model_name=model_name,
                              batch=batch,
                              prompt_tokens=prompt_tokens,
                              completion_tokens=completion_tokens,
                              first_token_latency=first_token_latency,
                              percentiles=percentiles,
                              throughput_per_proc=throughput_per_proc,
                              throughput_per_node=throughput_per_proc /
                              gpu_count * device_count,
                              mem_per_proc=memory,
                              mem_per_gpu=memory / gpu_count,
                              mem_per_node=memory / gpu_count * device_count))

    if args.csv:
        # newline='' lets the csv module control line endings itself.
        with open(args.csv, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([
                'batch', 'prompt_tokens', 'completion_tokens',
                '1st_token_latency(min)(s)', '1st_token_latency(max)(s)',
                '1st_token_latency(ave)(s)', 'percentile50(s)',
                'percentile75(s)', 'percentile95(s)', 'percentile99(s)',
                'throughput(token/s)', 'mem_per_proc(GB)', 'mem_per_gpu(GB)'
            ])
            for res in results:
                writer.writerow([
                    res.batch, res.prompt_tokens, res.completion_tokens,
                    res.first_token_latency[0], res.first_token_latency[1],
                    res.first_token_latency[2], res.percentiles[0],
                    res.percentiles[1], res.percentiles[2],
                    res.percentiles[3], f'{res.throughput_per_proc:.2f}',
                    f'{res.mem_per_proc:.2f}', f'{res.mem_per_gpu:.2f}'
                ])


if __name__ == '__main__':
    # Entry point when run as a script; the stray column-0 line that broke
    # this block's indentation has been removed.
    main()