"""Benchmark huggingface models and maybe speedup by deepspeed.

In principle, this tool is compatible with all HuggingFace models.

Example 1: Test a HuggingFace llama2 model with a simulated input of 128 tokens and 128 generated tokens

```shell
python profile_hf_generation.py \
    --model_path $PATH_TO_HF_LLAMA2 \
    --batch_size 16 \
    --input_seqlen 128 \
    --gen_seqlen 128 \
    --out_file profile_hf.csv
```

Example 2: The same test as above, accelerated with DeepSpeed inference and run for more rounds

```shell
python profile_hf_generation.py \
    --model_path $PATH_TO_HF_LLAMA2 \
    --batch_size 16 \
    --input_seqlen 128 \
    --gen_seqlen 128 \
    --test_round 2 \
    --accel deepspeed \
    --out_file profile_hf.csv
```

Example 3: The same test as above, but without a streamer measuring the time of every token.
        Only the overall time is measured, which is slightly faster than streaming mode.

```shell
python profile_hf_generation.py \
    --model_path $PATH_TO_HF_LLAMA2 \
    --batch_size 16 \
    --input_seqlen 128 \
    --gen_seqlen 128 \
    --test_round 2 \
    --accel deepspeed \
    --no-streamer \
    --out_file profile_hf.csv
```

Results are saved to `profile_hf.csv`, a comma-separated file with the following fields.

1. model: name of the model, specified with --model_log_name, otherwise f"{model.__class__.__name__}-{accel}"
2. batch_size: batch size of the simulated input
3. input_seqlen: length of the simulated input sequence
4. gen_len: number of tokens to generate
5. total_len: gen_len + input_seqlen

In total, the model takes random input ids of shape (batch_size, input_seqlen) and
runs forward `gen_len` times to generate output ids of shape (batch_size, gen_len).

6. first_time: latency of forwarding the first (batch_size, input_seqlen) ids
    to get the `input_seqlen+1`-th batch of output of shape (batch_size, 1)
7. next_time: average latency of the next tokens (averaged over 5 tokens),
    which measures latency when the context is short
8. last_time: average latency of the last tokens (averaged over 5 tokens),
    which measures latency when the context is long
9. total time: total time to generate all tokens
10. throughput(total): batch_size * total_len / total_time (same measure as vLLM)
11. throughput(gen): batch_size * gen_len / total_time
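
For example (illustrative numbers only), with batch_size=16, input_seqlen=128,
gen_seqlen=128 and a measured total_time of 4000 ms:
    throughput(total) = 16 * 256 * 1000 / 4000 = 1024 tok/s
    throughput(gen)   = 16 * 128 * 1000 / 4000 = 512 tok/s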
"""   # noqa: E501

import csv
import logging
import os
import time
from typing import Optional

import fire
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

from lmdeploy.pytorch.accel import LoadNoInit
from lmdeploy.utils import get_logger

logger = get_logger(__file__)
logger.setLevel(logging.DEBUG)
info = logger.info
warning = logger.warning
debug = logger.debug
cinfo = lambda x: info('\033[1;32m' + x + '\033[0m')  # noqa: E731
cprint = lambda x: print('\033[1;32m' + x + '\033[0m')  # noqa: E731
avg = lambda x: sum(x) / len(x)  # noqa: E731


class TimingStreamer:
    """Timing helper for HuggingFace models."""

    def __init__(self) -> None:
        # self.token_cache = []
        # self.tokens = None

        torch.cuda.synchronize()

        self.evts = []
        # self.value = 0

    def put(self, value):
        """Record a CUDA event each time new tokens are emitted.

        Notes:
            When `put` is called for the first time, no prompt has been fed to the model yet.
            Each later call records an event for the previous generation step,
                so the interval between the first two events is the prompt (prefill) time.

            GenerationMixin will call `.cpu()` on the output token, which implies a sync.
        """  # noqa: E501
        # self.value += 1
        # self.token_cache.append(value)
        evt = torch.cuda.Event(enable_timing=True)
        evt.record()
        self.evts.append(evt)

    def end(self):
        torch.cuda.synchronize()
        # self.tokens = torch.hstack([_atleast_2d(v) for v in self.token_cache])    # noqa: E501

    def get_times(self):
        """Maybe deprecated.

        Returns:
            a tuple of times in ms: (total time, first prompt time, avg next-token time)
        """
        first = self.evts[0].elapsed_time(self.evts[1])
        rest = [
            self.evts[i].elapsed_time(self.evts[i + 1])
            for i in range(1, len(self.evts) - 1)
        ]
        avg_next = sum(rest) / len(rest)
        return first + sum(rest), first, avg_next

    def raw_times(self):
        """
        Returns:
            a list of per-step times in ms; the first entry is the prompt (prefill) time.
        """
        evts = self.evts
        r = [evts[i].elapsed_time(evts[i + 1]) for i in range(len(evts) - 1)]
        return r
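

# Minimal usage sketch (illustrative, not executed): pass a TimingStreamer to
# `generate`, then read the per-token latencies afterwards. `model` is assumed
# to be a causal LM on GPU and `input_ids` a CUDA LongTensor.
#
#   ts = TimingStreamer()
#   model.generate(input_ids,
#                  GenerationConfig(max_new_tokens=32, do_sample=False),
#                  streamer=ts)
#   per_token_ms = ts.raw_times()  # [prefill, token_1, token_2, ...] in ms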


class CSVWriter:
    """Append benchmark results to a CSV file; only rank 0 writes."""

    def __init__(
        self,
        file='unnamed.csv',
        header=(
            'model',
            'batch_size',
            'input_seqlen',
            'gen_len',
            'total_len',
            'first_time',
            'next_time',
            'last_time',
            'total time',
            'throughput(total)',
            'throughput(gen)',
        ),
    ):
        if self.on_master:
            self.file = file
            with open(file, 'a') as f:
                csv.writer(f).writerow(header)

    def write(self, line):
        if self.on_master:
            with open(self.file, 'a') as f:
                csv.writer(f).writerow(line)

    @property
    def on_master(self):
        # return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0  # noqa: E501
        rank = int(os.environ.get('RANK', 0))
        return rank == 0


def init_hf_model(model_path: str):
    """Load a HuggingFace causal LM in fp16, skipping redundant weight init."""
    start = time.monotonic()
    with LoadNoInit():
        model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     torch_dtype=torch.float16,
                                                     trust_remote_code=True)
    print(f'load model in {time.monotonic() - start:.1f} s')
    return model


def accel_deepspeed(model, max_out_tokens, tp_size=1):
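    # Note (assumption, not exercised by the examples above): with tp_size > 1,
    # the script would typically be launched via the DeepSpeed launcher, e.g.
    #   deepspeed --num_gpus 2 profile_hf_generation.py ...
    # so that each tensor-parallel rank runs in its own process and RANK is set.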
    import deepspeed
    ds_model = deepspeed.init_inference(
        model=model,  # Transformers models
        tensor_parallel={'tp_size': tp_size},
        dtype=torch.float16,  # dtype of the weights (fp16)
        replace_with_kernel_inject=True,
        max_out_tokens=max_out_tokens,
    )

    return ds_model


def main(model_path: str,
         batch_size: int,
         input_seqlen: int,
         gen_seqlen: int,
         test_round: int = 1,
         accel: Optional[str] = None,
         out_file: Optional[str] = 'profile_hf.csv',
         model_log_name: Optional[str] = None,
         no_streamer: bool = False):

    total_seqlen = input_seqlen + gen_seqlen

    model = init_hf_model(model_path)

    vocab_size = model.config.vocab_size
    if model_log_name is None:
        model_log_name = model.__class__.__name__
        if accel is not None:
            model_log_name += f'-{accel}'

    if accel is None:
        model = model.cuda()
    elif accel == 'deepspeed':
        model = accel_deepspeed(model, total_seqlen + 6)
        # longer total seqlen for fault tolerance
    else:
        raise NotImplementedError(f'accel {accel} not supported.')

    # log to file
    csv_writer = CSVWriter(out_file)

    cprint(f'Benchmarking {model_log_name} '
           f'with batch_size={batch_size}, input_seqlen={input_seqlen}, '
           f'gen_seqlen={gen_seqlen}, accel={accel}')
    for r in range(test_round):
        # TODO: for now, each measured round is written to csv;
        # use an external tool for analysis

        cprint(f'Test round {r}')
        # input_id = 0 sometimes causes CUDA errors
        fake_inputs = torch.randint(10, vocab_size, (batch_size, input_seqlen))
        fake_inputs = fake_inputs.cuda()

        ts = TimingStreamer() if not no_streamer else None

        torch.cuda.synchronize()
        start = time.monotonic()
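        # eos_token_id=[-1] never matches a generated token, so generation always
        # runs for the full gen_seqlen; do_sample=False keeps decoding greedy.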
        fake_outputs = model.generate(
            fake_inputs,
            GenerationConfig(max_new_tokens=gen_seqlen,
                             do_sample=False,
                             eos_token_id=[-1]),
            streamer=ts,
        )
        torch.cuda.synchronize()
        end = time.monotonic()
        assert fake_outputs.size() == (batch_size,
                                       total_seqlen), fake_outputs.size()

        # total_time, first_time, _ = ts.get_times()
        if no_streamer:
            total_time = (end - start) * 1000
            first_time = next_time = last_time = 0
        else:
            raw_times = ts.raw_times()  # You may further analyze this
            total_time = sum(raw_times)
            first_time = raw_times[0]
            next_time = avg(raw_times[1:6])
            last_time = avg(raw_times[-5:])
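        # total_time is in milliseconds, hence the *1000 to get tokens per second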
        tt = batch_size * total_seqlen * 1000 / total_time
        tg = batch_size * gen_seqlen * 1000 / total_time
        cprint(f'First token/ms: {first_time:.1f}, '
               f'Next tokens/ms: {next_time:.1f}, '
               f'Last tokens/ms: {last_time:.1f}, '
               f'Total Time/ms: {total_time:5.3f}, '
               f'Throughput Total(tok/s): {tt:5.3f}, '
               f'Throughput Gen(tok/s): {tg:5.3f}')

        if test_round > 1 and r > 0:
            # The first round is treated as warm-up and is not recorded,
            # so with test_round=1 nothing is written to the csv.
            csv_writer.write([
                model_log_name, batch_size, input_seqlen, gen_seqlen,
                total_seqlen, first_time, next_time, last_time, total_time, tt,
                tg
            ])


if __name__ == '__main__':
    fire.Fire(main)