# benchmark_latency.py
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.core.server import (
    add_server_arguments, process_server_arguments,
    init_local_server_and_frontend_with_arguments)
from cacheflow.sampling_params import SamplingParams


def main(args: argparse.Namespace):
    """Benchmark end-to-end latency of decoding one batch of dummy prompts.

    Starts a local CacheFlow server/frontend from ``args``, enqueues
    ``args.batch_size`` prompts of ``args.input_len`` dummy tokens each,
    decodes ``args.output_len`` tokens per sequence, and prints the mean
    wall-clock latency over three measured runs (after one warm-up run).
    """
    server, frontend = init_local_server_and_frontend_with_arguments(args)

    sampling_params = SamplingParams(
        n=args.n,
        # Beam search scores deterministically, so use temperature 0.0;
        # otherwise sample at temperature 1.0.
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        stop_token_ids=set(),
        max_tokens=args.output_len,
    )
    print(sampling_params)
    # Dummy prompt: token id 0 repeated input_len times (content is
    # irrelevant for latency, only the length matters).
    input_token_ids = [0] * args.input_len

    def profile_step(profile: bool = False) -> float:
        """Run one full decode of the batch; return its latency in seconds.

        When ``profile`` is True, brackets the run with CUDA profiler
        start/stop markers so external tools (e.g. Nsight) capture it.
        """
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        for _ in range(args.batch_size):
            dummy_prompt = ""
            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        # FIX: use a monotonic, high-resolution clock for the interval.
        # time.time() is wall-clock time and can jump (e.g. NTP sync),
        # which would corrupt the measured latency.
        start_time = time.perf_counter()
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.perf_counter()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warm up step")
    profile_step()

    # Benchmark.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    # Build the CLI: shared server options first, then the knobs that
    # are specific to this latency benchmark.
    arg_parser = argparse.ArgumentParser(
        description='Benchmark the latency of decoding a single sentence.')
    arg_parser = add_server_arguments(arg_parser)
    arg_parser.add_argument('--input-len', type=int, default=32)
    arg_parser.add_argument('--output-len', type=int, default=128)
    arg_parser.add_argument('--batch-size', type=int, default=8)
    arg_parser.add_argument('--n', type=int, default=1)
    arg_parser.add_argument('--use-beam-search', action='store_true')
    parsed_args = process_server_arguments(arg_parser.parse_args())
    # Ensure the scheduler can admit the entire batch of prompts in one
    # step; otherwise the measurement would include queueing delays.
    parsed_args.max_num_batched_tokens = max(
        parsed_args.max_num_batched_tokens,
        parsed_args.batch_size * parsed_args.input_len)
    print(parsed_args)
    main(parsed_args)