benchmark_latency.py
import argparse
import time

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.master.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams


def main(args: argparse.Namespace):
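    # Bring up the server (execution engine) and its frontend from the CLI args.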
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    # Beam search runs with temperature 0.0; otherwise this is plain random
    # sampling (temperature 1.0) with no top-p truncation.
    sampling_params_dict = {
        'n': args.n,
        'temperature': 0.0 if args.use_beam_search else 1.0,
        'top_p': 1.0,
        'use_beam_search': args.use_beam_search,
        'stop_token_ids': set(),
        'max_num_steps': args.output_len,
    }
    sampling_params = SamplingParams.from_dict(sampling_params_dict)
    print(sampling_params)
    # Dummy prompt: args.input_len tokens of id 0. Only the length matters here.
    input_token_ids = [0] * args.input_len

    def profile_step(profile=False):
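        # Optionally bracket the run with CUDA profiler start/stop markers so
        # external profilers (e.g. Nsight Systems) capture exactly one step.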
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
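        # Enqueue `batch_size` identical requests, then hand the whole batch
        # to the server at once.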
        for _ in range(args.batch_size):
            frontend._add_query(input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        start_time = time.perf_counter()
        # Step the engine until every request in the batch has finished.
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.perf_counter()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    # Warm-up run so one-time initialization costs don't skew the measurement.
    print("Warm up step")
    profile_step()

    # Benchmark: average end-to-end latency over three measured runs.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies):.3f} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of requests.')
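    # Server-side flags come from cacheflow's own argument helper.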
    parser = add_server_arguments(parser)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n', type=int, default=1)
    parser.add_argument('--use-beam-search', action='store_true')
    args = parser.parse_args()
    args = process_server_arguments(args)
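    # Ensure one full batch of prompts fits into a single batched step.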
    args.max_num_batched_tokens = max(
        args.max_num_batched_tokens, args.batch_size * args.input_len)
    print(args)
    main(args)