profile_throughput.py 10.9 KB
Newer Older
1
2
# Copyright (c) OpenMMLab. All rights reserved.
import csv
3
import json
4
import os
5
6
7
8
9
10
11
import random
import time
from queue import Queue
from threading import Thread
from typing import List, Tuple

import fire
12
import numpy as np
13
from tqdm import tqdm
14

15
16
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.turbomind import TurboMind
17
18
19
20
21
22
23
24
25
26
27
28
29
30


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: Tokenizer,
    min_seq_len: int = 4,
    max_prompt_len: int = 1024,
    max_total_len: int = 2048,
) -> List[Tuple[str, int, int]]:
    """Sample benchmark requests from a JSON conversation dataset.

    The dataset is a JSON list of records, each holding a ``conversations``
    list of turns with a ``value`` text field. The first turn is used as the
    prompt; the token count of the second turn becomes the requested output
    length. Records with fewer than two turns are ignored.

    Args:
        dataset_path: path to the JSON dataset (read as UTF-8).
        num_requests: number of requests to sample without replacement.
        tokenizer: callable such that ``tokenizer(list_of_str).input_ids``
            yields one token-id sequence per input string.
        min_seq_len: pairs whose prompt or completion has fewer tokens than
            this are dropped. Defaults to 4.
        max_prompt_len: pairs whose prompt has more tokens than this are
            dropped. Defaults to 1024.
        max_total_len: pairs whose prompt + completion token count exceeds
            this are dropped. Defaults to 2048.

    Returns:
        A list of ``(prompt, prompt_len, output_len)`` tuples.

    Raises:
        ValueError: if fewer than ``num_requests`` pairs survive the length
            filters, or if ``num_requests`` is negative.
    """
    # Explicit UTF-8: conversation datasets routinely contain non-ASCII text
    # and the platform default encoding is not UTF-8 everywhere.
    with open(dataset_path, encoding='utf-8') as f:
        dataset = json.load(f)

    # Keep only the first two turns of conversations that have at least two.
    pairs = [(data['conversations'][0]['value'],
              data['conversations'][1]['value']) for data in dataset
             if len(data['conversations']) >= 2]

    # Tokenize the prompts and completions (batched, one call each).
    prompts = [prompt for prompt, _ in pairs]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in pairs]
    completion_token_ids = tokenizer(completions).input_ids

    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_ids, completion_ids in zip(prompts, prompt_token_ids,
                                                  completion_token_ids):
        prompt_len = len(prompt_ids)
        output_len = len(completion_ids)
        if prompt_len < min_seq_len or output_len < min_seq_len:
            # Prune too short sequences.
            continue
        if (prompt_len > max_prompt_len
                or prompt_len + output_len > max_total_len):
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # random.sample would also raise ValueError here, but without saying why.
    if num_requests > len(filtered_dataset):
        raise ValueError(
            f'requested {num_requests} prompts, but only '
            f'{len(filtered_dataset)} of the {len(pairs)} conversations in '
            f'{dataset_path} pass the length filters')
    return random.sample(filtered_dataset, num_requests)


class Engine:
    """Drive a TurboMind model from several worker threads and report
    throughput and latency statistics for a batch of requests."""

    def __init__(self, model_path: str, tp: int, csv: str, **kwargs):
        """Build the TurboMind model.

        Args:
            model_path: model location, forwarded to ``TurboMind``.
            tp: tensor-parallel degree, forwarded to ``TurboMind``.
            csv: path of the CSV report; a falsy value skips writing it.
            **kwargs: extra keyword arguments forwarded to ``TurboMind``.
        """
        # avoid turbomind checking chat template name by setting
        # `model_name='llama'`
        tm_model = TurboMind(model_path=model_path,
                             model_name='llama',
                             tp=tp,
                             **kwargs)
        self.tm_model = tm_model
        self.tokenizer = tm_model.tokenizer
        self.csv = csv
        self.pbar = None

    def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int,
                   stream_output: bool):
        """Worker loop run by each thread.

        Pulls ``(prompt, input_seqlen, output_seqlen)`` items from
        ``req_queue`` until the ``[None, None, None]`` sentinel, and runs each
        one through a fresh model instance. The worker puts exactly one item
        on ``res_queue``: ``(session_id, stats, per_token_latency_stats)``.
        ``stats`` holds one ``[first_token_latency, completion_tokens,
        output_seqlen, total_tokens]`` row per finished request.
        ``per_token_latency_stats`` holds one list per request with the
        measured latencies (s) after the first token.

        Raises:
            RuntimeError: if the engine generates a token count outside
                ``[output_seqlen, output_seqlen + 1]``.
        """
        model_inst = self.tm_model.create_instance()
        stats = []
        # get each generated token's latency
        per_token_latency_stats = []
        try:
            for prompt, input_seqlen, output_seqlen in iter(
                    req_queue.get, [None, None, None]):
                # Slot ``n`` receives the gap (s) since the previous streamed
                # output at the moment the generated-token count moves past
                # ``n``. Slots that are never reached (the final one, or
                # tokens delivered in the same chunk) stay None so they are
                # not mistaken for 0-second latencies.
                _per_token_latency_stats = [None] * (output_seqlen + 1)
                offset = 0
                n_token = 0
                n_prev_token = 0
                prev = time.perf_counter()

                input_ids = self.tokenizer(prompt).input_ids
                # NOTE(review): sampling parameters are hard-coded here; the
                # CLI's temperature/top_p only reach the TurboMind
                # constructor. Confirm whether these per-call values win.
                for outputs in model_inst.stream_infer(
                        session_id,
                        input_ids=input_ids,
                        request_output_len=output_seqlen,
                        temperature=1.0,
                        top_p=1.0,
                        sequence_start=True,
                        sequence_end=True,
                        ignore_eos=True,
                        stream_output=stream_output):
                    res, n_token = outputs[0]
                    # The decoded text is discarded; presumably this call is
                    # kept so detokenization cost is part of the measurement.
                    self.tokenizer.decode(res, offset)
                    offset = n_token
                    now = time.perf_counter()
                    if n_prev_token != n_token:
                        _per_token_latency_stats[n_prev_token] = np.round(
                            now - prev, 3)
                        n_prev_token = n_token
                    prev = now

                # Explicit check rather than `assert`, which `python -O` strips.
                if not output_seqlen <= n_token <= output_seqlen + 1:
                    raise RuntimeError(
                        f'Error. session_id({session_id}) request '
                        f'{output_seqlen} tokens, but generate {n_token} '
                        f'tokens.\nprompt: {prompt}')

                first_token_latency = _per_token_latency_stats[0]
                completion_tokens = n_token
                total_tokens = n_token + input_seqlen
                stats.append([
                    first_token_latency, completion_tokens, output_seqlen,
                    total_tokens
                ])
                # skip the first token latency and every unmeasured slot
                per_token_latency_stats.append([
                    latency for latency in _per_token_latency_stats[1:]
                    if latency is not None
                ])
                self.pbar.update(1)
        finally:
            # Report what this worker finished even if a request failed, so
            # one bad request does not discard the whole thread's results.
            res_queue.put((session_id, stats, per_token_latency_stats))

    @staticmethod
    def _collect_results(res_queue: Queue):
        """Drain ``res_queue`` and merge the results of all workers.

        Returns:
            ``(stats, per_token_latency_stats)``: ``stats`` is a float array
            of shape ``(num_finished_requests, 4)``; the second item is a
            flat list of every per-token latency.
        """
        stats = []
        per_token_latency_stats = []
        while not res_queue.empty():
            _, worker_stats, worker_latency = res_queue.get()
            # Extend plain lists and convert once at the end. Converting per
            # worker breaks on idle workers: np.array([]) is 1-D and cannot
            # be concatenated with the (k, 4) arrays of busy workers.
            stats.extend(worker_stats)
            for request_latency in worker_latency:
                per_token_latency_stats.extend(request_latency)
        return (np.array(stats, dtype=float).reshape(-1, 4),
                per_token_latency_stats)

    @staticmethod
    def _latency_percentiles(latencies, percents=(0.5, 0.75, 0.95, 0.99)):
        """Return nearest-rank percentiles of ``latencies``, rounded to 3
        decimals as plain floats; NaN for each percent when empty."""
        if not latencies:
            return [float('nan')] * len(percents)
        ordered = sorted(latencies)
        return [
            float(np.round(ordered[int(percent * len(ordered))], 3))
            for percent in percents
        ]

    def process_request(self,
                        requests,
                        concurrency: int = 1,
                        stream_output: bool = True):
        """Process ``requests`` with ``concurrency`` worker threads.

        Prints a summary and, when ``self.csv`` is set, writes it to that
        path (overwriting the file).

        Args:
            requests: sequence of ``(prompt, input_seqlen, output_seqlen)``.
            concurrency: number of worker threads. Defaults to 1.
            stream_output: forwarded to the engine; latency figures are only
                reported when True. Defaults to True.

        Raises:
            ValueError: if ``requests`` is empty or ``concurrency < 1``.
        """
        if not requests:
            raise ValueError('no requests to profile')
        if concurrency < 1:
            raise ValueError(f'concurrency must be >= 1, got {concurrency}')

        res_queue = Queue()
        req_queue = Queue()
        self.pbar = tqdm(total=len(requests))

        # feed every request, then one stop sentinel per worker
        for req in requests:
            req_queue.put(req)
        for _ in range(concurrency):
            req_queue.put([None, None, None])

        # perf_counter is monotonic, so wall-clock adjustments during the run
        # cannot distort the measured duration (time.time() can jump).
        start = time.perf_counter()
        threads = []
        for i in range(concurrency):
            t = Thread(target=self._inference,
                       args=(req_queue, res_queue, i, stream_output))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()
        elapsed_time = time.perf_counter() - start
        self.pbar.close()

        stats, per_token_latency_stats = self._collect_results(res_queue)

        first_token_latency_min = np.min(stats[:, 0], axis=0)
        first_token_latency_max = np.max(stats[:, 0], axis=0)
        first_token_latency_ave = np.mean(stats[:, 0], axis=0)
        completion_tokens = np.sum(stats[:, 1], axis=0)
        total_tokens = np.sum(stats[:, 3], axis=0)
        prompt_tokens = total_tokens - completion_tokens
        completion_token_throughput = completion_tokens / elapsed_time
        total_token_throughput = total_tokens / elapsed_time
        rps = len(requests) / elapsed_time
        rpm = rps * 60

        # plain floats so the printed list reads the same on NumPy 1 and 2
        percentiles = self._latency_percentiles(per_token_latency_stats)

        print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
              f'elapsed_time: {elapsed_time:.3f}s\n')
        if stream_output:
            print(f'first token latency(s)(min, max, ave): '
                  f'{first_token_latency_min:.3f}, '
                  f'{first_token_latency_max:.3f}, '
                  f'{first_token_latency_ave:.3f}')
            print(f'per-token latency(s) percentile(50, 75, 95, 99): '
                  f'{percentiles}\n')
        print(
            f'number of prompt tokens: {prompt_tokens:.0f}\n'
            f'number of completion tokens: {completion_tokens:.0f}\n'
            f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n'  # noqa
            f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n'  # noqa
            f'RPS (request per second): {rps:.3f} req/s\n'
            f'RPM (request per minute): {rpm:.3f} req/min\n'
            f'{"-" * 50}\n')

        if self.csv:
            # newline='' lets the csv module control line endings; without it
            # every row is followed by an empty line on Windows.
            with open(self.csv, 'w', newline='') as csvfile:
                writer = csv.writer(csvfile)
                # NOTE: 'num_promts' spelling kept as-is in case downstream
                # tooling matches on the header name.
                writer.writerow([
                    'batch', 'num_promts', 'prompt_tokens',
                    'completion_tokens', '1st_token_latency(min)(s)',
                    '1st_token_latency(max)(s)', '1st_token_latency(ave)(s)',
                    'percentile50(s)', 'percentile75(s)', 'percentile95(s)',
                    'percentile99(s)', 'output token thr(tokens/s)',
                    'total token thr(token/s)', 'RPS', 'RPM'
                ])
                writer.writerow([
                    concurrency,
                    len(requests), prompt_tokens, completion_tokens,
                    f'{first_token_latency_min:.3f}' if stream_output else '-',
                    f'{first_token_latency_max:.3f}' if stream_output else '-',
                    f'{first_token_latency_ave:.3f}' if stream_output else '-',
                    f'{percentiles[0]:.3f}' if stream_output else '-',
                    f'{percentiles[1]:.3f}' if stream_output else '-',
                    f'{percentiles[2]:.3f}' if stream_output else '-',
                    f'{percentiles[3]:.3f}' if stream_output else '-',
                    f'{completion_token_throughput:.3f}',
                    f'{total_token_throughput:.3f}', f'{rps:.3f}', f'{rpm:.3f}'
                ])

229
230
231

def main(dataset: str,
         model_path: str,
         concurrency: int = 64,
         num_prompts: int = 2000,
         tp: int = 1,
         top_k: int = 1,
         top_p: float = 1.0,
         temperature: float = 1.0,
         stream_output: bool = True,
         csv: str = './profile_throughput.csv',
         log_level: str = 'ERROR',
         seed: int = 0):
    """Benchmark the request throughput of lmdeploy in localhost.

    Args:
        dataset (str): Path to the dataset
        model_path (str): Path to a model in localhost or a model_repo_id in huggingface.co
        concurrency (int, optional): Number of working threads to process the sampled prompts.
            Defaults to 64.
        num_prompts (int, optional): Number of prompts to process. Defaults to 2000.
        tp (int, optional): Number of GPUs for tensor parallel. Defaults to 1.
        top_k (int, optional): The number of highest probability vocabulary tokens
            to keep for top-k-filtering. Forwarded to the TurboMind constructor.
            Defaults to 1.
        top_p (float, optional): the set of most probable tokens with
            probabilities that add up to top_p or higher
            are kept for generation. Forwarded to the TurboMind constructor.
            Defaults to 1.0.
        temperature (float, optional): The value used to modulate the next token probabilities.
            Forwarded to the TurboMind constructor. Defaults to 1.0.
        stream_output (bool, optional): Indicator for streaming output. Latency
            figures are only reported when True. Defaults to True.
        csv (str, optional): The path to save the result. Pass an empty
            string to skip writing it. Defaults to './profile_throughput.csv'.
        log_level (str, optional): The log level, exported as the
            TM_LOG_LEVEL environment variable. Defaults to 'ERROR'.
        seed (int, optional): Seed used in sampling prompts from dataset. Defaults to 0.
    """    # noqa
    random.seed(seed)
    # Exported before the engine is built; presumably read by turbomind
    # while it initializes.
    os.environ['TM_LOG_LEVEL'] = log_level

    engine = Engine(model_path,
                    tp=tp,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    csv=csv)

    requests = sample_requests(dataset, num_prompts, engine.tokenizer)
    engine.process_request(requests, concurrency, stream_output)
276
277
278
279


if __name__ == '__main__':
    # Fire maps `main`'s parameters onto command-line arguments.
    fire.Fire(component=main)