benchmark_rope.py 7.27 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import torch
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch

device = infer_device()


def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        LlamaRotaryEmbedding,
        LlamaRotaryEmbedding,
        before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        elif provider == "huggingface":
            return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            raise ValueError(f"Invalid provider: {provider} for RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        LlamaRotaryEmbedding,
        LlamaRotaryEmbedding,
        before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs_varying_hidden_size = {
        "kernel_name": "rope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    common_configs_varying_seq_len = {
        "kernel_name": "rope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )