"...research_projects/controlnetxs/pipeline_controlnet_xs.py" did not exist on "14e3a28c120eea88093442eb0a2a3df35d21a22d"
benchmark_rope.py 4.33 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from itertools import accumulate
from typing import Optional

import nvtx
import torch

from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
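
# Benchmark for rotary embedding (RoPE) kernels when serving multiple LoRA
# adapters that use different "linear" RoPE scaling factors. It compares a
# single batched RoPE (one kernel launch with per-token offsets into a shared
# cos/sin cache) against one non-batched RoPE instance per scaling factor,
# marking each variant with an NVTX range so the kernels can be inspected in a
# GPU profiler.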


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: float = 10000,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        {"rope_type": "linear", "factor": tuple(scaling_factors)},
    )
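    # under the hood this builds one cos/sin cache that covers every scaling
    # factor, so a single kernel launch (given per-token offsets) can serve
    # requests that use different factors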
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes: list[RotaryEmbedding] = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                {"rope_type": "linear", "factor": (scaling_factor,)},
            )
        )

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE: we concatenate multiple KV caches
    # together and each query needs to find the right cache for its type
    offset_map = torch.tensor(
        list(
            accumulate(
                [0]
                + [
                    max_position * scaling_factor * 2
                    for scaling_factor in scaling_factors[:-1]
                ]
            )
        )
    )
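    # with the defaults (max_position=8192, scaling_factors=[1, 2, 4, 8]) this
    # evaluates to [0, 16384, 49152, 114688]: a running sum of
    # max_position * scaling_factor * 2 over the preceding scaling factors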
    query_types = torch.randint(
        0, len(scaling_factors), (batch_size, seq_len), device=device
    )
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()
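    # flatten_offsets has shape (batch_size * seq_len,): one offset per token,
    # matching the flattened token layout the batched kernel operates on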

    # group queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    # each non-batched instance only processes the positions of its own type,
    # so the per-type calls together cover the same tokens as the batched call
    grouped_positions = [
        positions[query_types == i] for i in range(len(scaling_factors))
    ]
    packed_pqkr = zip(grouped_positions, queries, keys, non_batched_ropes)
    # synchronize before timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for p, q, k, r in packed_pqkr:
            r.forward(p, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument(
        "--dtype", type=str, choices=["bfloat16", "float"], default="float"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
    )
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
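
# Example usage (assumes a CUDA-capable GPU and the `nvtx` Python package; the
# exact profiler flags may differ across Nsight Systems versions):
#   nsys profile --trace=cuda,nvtx -o rope_bench \
#       python benchmark_rope.py --batch-size 16 --seq-len 512 --dtype bfloat16
# The "non-batched" and "batched" NVTX ranges can then be compared on the
# Nsight Systems timeline.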