profile_serving.py 6.61 KB
Newer Older
lvhan028's avatar
lvhan028 committed
1
import json
2
import logging
lvhan028's avatar
lvhan028 committed
3
4
5
6
7
8
9
10
11
12
import multiprocessing as mp
import os
import random
import time
from typing import List

import fire
import numpy as np
from sentencepiece import SentencePieceProcessor

13
from lmdeploy.serve.turbomind.chatbot import Chatbot
lvhan028's avatar
lvhan028 committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


class Tokenizer:
    """Thin SentencePiece wrapper used here only to count tokens."""

    def __init__(self, model_path: str):
        """Load the SentencePiece model file located at ``model_path``.

        Raises:
            AssertionError: if ``model_path`` is not an existing file
                (the check is an ``assert``, so it is skipped under ``-O``).
        """
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

    def encode(self, prompts: List):
        """Return the number of tokens of every entry in ``prompts``.

        No BOS/EOS markers are added, so each count covers the text only.
        Despite the method name, token ids are not returned, only lengths.
        """
        encoded = self.sp_model.Encode(prompts, add_bos=False, add_eos=False)
        return list(map(len, encoded))


def infer(chatbot, session_id: int, req_que: mp.Queue, res_que: mp.Queue):
    """Worker loop: serve requests from ``req_que`` and report latency stats.

    Each request is a ``[prompt, input_seqlen, output_seqlen]`` list; the
    loop stops at the ``[None, None, None]`` sentinel. For every request one
    row ``[first_token_latency, n_tokens, token_latency]`` is collected.
    ``(session_id, stats)`` is put on ``res_que`` exactly once, even if a
    request raises, so the consumer waiting on ``res_que`` is never left
    blocked.

    Args:
        chatbot: client whose ``stream_infer`` yields
            ``(status, response, token)`` tuples.
        session_id: session id used for every request of this worker.
        req_que: queue of requests shared by all workers.
        res_que: queue receiving ``(session_id, stats)``.
    """
    stats = []
    try:
        for prompt, input_seqlen, output_seqlen in iter(
                req_que.get, [None, None, None]):
            timestamps = []
            tokens = []
            start = time.perf_counter()
            for _, _, token in chatbot.stream_infer(
                    session_id,
                    prompt,
                    request_output_len=output_seqlen,
                    sequence_start=True,
                    sequence_end=True):
                timestamps.append(time.perf_counter())
                tokens.append(token)

            if len(timestamps) < 2:
                # The metrics below index timestamps[1]; with a shorter
                # stream that raised IndexError and killed this worker.
                print(f'session {session_id}: skip request, got '
                      f'{len(timestamps)} response(s)')
                continue

            # NOTE(review): timestamps[1] rather than [0] is used for the
            # first-token latency; presumably the first streamed item is not
            # a generated token -- confirm against Chatbot.stream_infer.
            first_token_latency = np.round(timestamps[1] - start, 3)
            token_latency = np.round(timestamps[-1] - timestamps[0], 3)
            n_tokens = tokens[-1] - tokens[0]
            stats.append([first_token_latency, n_tokens, token_latency])
            print(f'session {session_id}: '
                  f'input_seqlen {input_seqlen}, output_seqlen {output_seqlen}')
    finally:
        # Always report so the reader of res_que does not wait forever.
        res_que.put((session_id, stats))


def _warmup_infer(chatbot, session_id: int, output_seqlen: int,
                  warmup_round: int):
    """Run ``warmup_round`` empty-prompt generations on one chatbot.

    Kept at module level (not nested inside ``warmup``) because a nested
    function cannot be pickled, which ``multiprocessing`` requires for a
    ``Process`` target under the ``spawn`` start method.
    """
    for _ in range(warmup_round):
        # Drain the stream; only the server-side work matters here.
        for _ in chatbot.stream_infer(session_id,
                                      prompt='',
                                      request_output_len=output_seqlen,
                                      sequence_start=True,
                                      sequence_end=True):
            pass
        chatbot.reset_session()


def warmup(tritonserver_addr: str,
           concurrency: int,
           output_seqlen: int,
           warmup_round: int = 1):
    """Warm up the server with ``concurrency`` parallel sessions.

    One ``Chatbot`` and one process are created per session. Each process
    runs ``warmup_round`` generations on an empty prompt, requesting
    ``output_seqlen`` tokens with ``ignore_eos=True``. Returns once every
    process has finished, after printing the elapsed wall time.

    Args:
        tritonserver_addr: server address passed to each ``Chatbot``.
        concurrency: number of parallel sessions/processes.
        output_seqlen: ``request_output_len`` of each warmup request.
        warmup_round: number of generations per session.
    """
    print('start to warmup ...')

    _start = time.perf_counter()
    chatbots = [
        Chatbot(tritonserver_addr=tritonserver_addr,
                ignore_eos=True,
                log_level=logging.ERROR,
                profile_generation=True) for _ in range(concurrency)
    ]
    procs = []
    # Session ids start at 1.
    for i, chatbot in enumerate(chatbots):
        proc = mp.Process(target=_warmup_infer,
                          args=(chatbot, i + 1, output_seqlen, warmup_round))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
    _end = time.perf_counter()
    print(f'end warmup, elapsed time: {round(_end - _start, 2)} s')


90
def read_dataset(tokenizer_path: str, dataset_path: str, samples: int,
                 session_len: int, que: mp.Queue):
    """Load a conversation dataset, filter it by length and enqueue requests.

    ``dataset_path`` must hold a JSON list of records with a
    ``'conversations'`` list whose items carry a ``'value'`` text. This is
    the schema the code indexes. Records with fewer than two turns are
    dropped. Turn 0 is used as the prompt. The token count of turn 1 becomes
    the requested output length.

    Args:
        tokenizer_path: SentencePiece model file used to count tokens.
        dataset_path: path of the JSON dataset.
        samples: number of requests drawn at random; ``<= 0`` keeps all.
            Values above the number of usable conversations are clamped.
        session_len: conversations whose prompt + completion token count
            exceeds this value are dropped.
        que: queue receiving ``[prompt, input_len, output_len]`` items.

    Returns:
        The number of requests put on ``que``.
    """
    start = time.perf_counter()
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with open(dataset_path, encoding='utf-8') as f:
        dataset = json.load(f)
    dataset = [data for data in dataset if len(data['conversations']) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data['conversations'][0]['value'],
                data['conversations'][1]['value']) for data in dataset]
    prompts = [prompt for prompt, _ in dataset]
    completions = [completion for _, completion in dataset]
    print(f'elapsed time for read data: '
          f'{round(time.perf_counter() - start, 2)} s')

    start = time.perf_counter()
    tokenizer = Tokenizer(tokenizer_path)
    prompts_token_lens = tokenizer.encode(prompts)
    completions_token_lens = tokenizer.encode(completions)
    print(f'elapsed time for tokenization: '
          f'{round(time.perf_counter() - start, 2)} s')

    start = time.perf_counter()
    filtered_dataset = []
    for (prompt, _), input_len, output_len in zip(dataset, prompts_token_lens,
                                                  completions_token_lens):
        if input_len + output_len > session_len:
            # ignore too long conversation
            continue
        filtered_dataset.append([prompt, input_len, output_len])

    if samples > 0:
        # random.sample raises ValueError if asked for more items than exist.
        # Workers may already be waiting on `que`, so crash here is costly.
        filtered_dataset = random.sample(filtered_dataset,
                                         min(samples, len(filtered_dataset)))

    for data in filtered_dataset:
        que.put(data)
    print(f'elapsed time for filtering: '
          f'{round(time.perf_counter() - start, 2)} s')
    return len(filtered_dataset)


def main(tritonserver_addr: str,
         tokenizer_path: str,
         dataset_path: str,
         concurrency: int = 1,
         session_len: int = 2048,
         samples: int = 1000):
    """Benchmark a served model with ``concurrency`` streaming clients.

    After a warmup, one ``infer`` worker process is started per session. The
    dataset is then read and its requests are enqueued, followed by one stop
    sentinel per worker. Per-worker stats are collected and a summary is
    printed.

    Summary fields:
        - first-token latency (min, max, mean);
        - token throughput: sum of per-request token counts / wall time;
        - request throughput: enqueued requests / wall time.

    Wall time runs from starting the workers to receiving the last result,
    so it includes dataset loading.

    Args:
        tritonserver_addr: server address passed to every ``Chatbot``.
        tokenizer_path: SentencePiece model used to count tokens.
        dataset_path: JSON conversation dataset.
        concurrency: number of worker processes (concurrent sessions).
        session_len: max prompt + completion tokens per request; warmup
            requests ask for ``session_len - 1`` output tokens.
        samples: number of requests to sample; ``<= 0`` uses all.
    """
    warmup(tritonserver_addr, concurrency, session_len - 1)
    req_que = mp.Queue()
    res_que = mp.Queue()

    procs = []
    _start = time.perf_counter()
    for i in range(concurrency):
        chatbot = Chatbot(tritonserver_addr=tritonserver_addr,
                          display=False,
                          profile_serving=True,
                          ignore_eos=True,
                          log_level=logging.ERROR)
        proc = mp.Process(target=infer,
                          args=(chatbot, i + 1, req_que, res_que))
        procs.append(proc)
        proc.start()

    # read data and put it to queue
    n_req = read_dataset(tokenizer_path, dataset_path, samples, session_len,
                         req_que)
    # One stop sentinel per worker.
    for _ in range(concurrency):
        req_que.put([None, None, None])

    rows = []
    for _ in range(concurrency):
        session_id, _stats = res_que.get()
        print(f'\n{"-" * 50}\n'
              f'session {session_id}: processed reqs {len(_stats)}, '
              f'stats: \n{_stats}\n{"-" * 50}\n')
        rows.extend(_stats)

    _end = time.perf_counter()
    elapsed_time = _end - _start

    # Build the (n, 3) matrix from all rows at once. Concatenating per-worker
    # arrays fails when one worker processed nothing: np.array([]) is 1-D
    # while the others are 2-D.
    stats = np.array(rows, dtype=float).reshape(-1, 3)

    if stats.shape[0] == 0:
        # min/max over an empty array would raise ValueError.
        print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
              f'elapsed_time: {elapsed_time:.3f}s\n'
              f'no request was processed\n'
              f'{"-" * 50}\n')
    else:
        first_token_latency_min = np.min(stats[:, 0], axis=0)
        first_token_latency_max = np.max(stats[:, 0], axis=0)
        first_token_latency_ave = np.mean(stats[:, 0], axis=0)
        token_throughput = np.sum(stats[:, 1], axis=0) / elapsed_time
        req_throughput = n_req / elapsed_time

        print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
              f'elapsed_time: {elapsed_time:.3f}s\n'
              f'first_token latency(min, max, ave): '
              f'{first_token_latency_min:.3f}s, '
              f'{first_token_latency_max:.3f}s, '
              f'{first_token_latency_ave:.3f}s\n'
              f'token throughput: {token_throughput:.3f} token/s\n'
              f'req throughput: {req_throughput:.3f} req/s\n'
              f'{"-" * 50}\n')

    for proc in procs:
        proc.join()

if __name__ == '__main__':
    # Build the command-line interface from main()'s signature (python-fire).
    fire.Fire(component=main)