"vllm/entrypoints/openai/engine/serving.py" did not exist on "136d750f5f421ca5be2e24b0a913e813d99bb831"
benchmark_rope.py 4.52 KB
Newer Older
1
from itertools import accumulate
2
from typing import List, Optional
Terry's avatar
Terry committed
3
4

import nvtx
5
6
import torch

7
8
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
                                                         get_rope)
9
10
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
Terry's avatar
Terry committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Run batched and non-batched linear-scaling RoPE under NVTX ranges.

    One rotary embedding is built with all scaling factors at once (the
    "batched" variant) and one per scaling factor (the "non-batched"
    variant). Both are run once on random inputs inside separately named
    NVTX ranges so they can be compared in a profiler. Nothing is returned
    or printed.

    Args:
        is_neox_style: Forwarded to ``get_rope``.
        batch_size: Leading dimension of the random inputs.
        seq_len: Second dimension of the random inputs.
        num_heads: Number of heads; the hidden size of query/key is
            ``num_heads * head_size``.
        head_size: Per-head size, forwarded to ``get_rope``.
        rotary_dim: Rotary dimension; ``None`` means ``head_size``.
        dtype: dtype of the random query/key tensors.
        seed: Seed passed to ``current_platform.seed_everything``.
        device: Device string; installed as the torch default device.
        max_position: Upper bound (exclusive) for the random positions and
            the max position forwarded to ``get_rope``.
        base: RoPE base forwarded to ``get_rope``.
    """
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size

    # torch.cuda.synchronize() without an argument waits on the *current*
    # CUDA device, which set_default_device() does not change. Synchronize
    # on the benchmark device explicitly so the NVTX ranges are fenced
    # correctly when e.g. device="cuda:1". For non-CUDA device strings keep
    # the original argument-less call.
    target_device = torch.device(device)
    sync_device = target_device if target_device.type == "cuda" else None

    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "rope_type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor, we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: List[RotaryEmbedding] = [
        get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
            "rope_type": "linear",
            "factor": (scaling_factor, )
        }) for scaling_factor in scaling_factors
    ]

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE, we concat multiple kv cache
    # together and each query needs to find the right kv cache of its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batch queries of the same type together for non-batched RoPE; the
    # positions are filtered with the same mask so that every group passes
    # exactly one position per selected token
    type_masks = [query_types == i for i in range(len(scaling_factors))]
    grouped_positions = [positions[mask] for mask in type_masks]
    queries = [query[mask] for mask in type_masks]
    keys = [key[mask] for mask in type_masks]
    packed_pqkr = list(
        zip(grouped_positions, queries, keys, non_batched_ropes))

    # synchronize before start timing
    torch.cuda.synchronize(sync_device)
    with nvtx.annotate("non-batched", color="yellow"):
        for p, q, k, r in packed_pqkr:
            r.forward(p, q, k)
    torch.cuda.synchronize(sync_device)
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize(sync_device)


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark configuration, echo it,
    # and run the multi-LoRA RoPE benchmark once.
    import argparse

    def _parse_bool(value: str) -> bool:
        """Convert a command-line string such as "true"/"false" to a bool.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy, so the flag could
        never be switched off from the command line.
        """
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    parser.add_argument("--is-neox-style", type=_parse_bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        # the --dtype choices are attribute names on the torch module
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )