# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import json
import pickle
import time
from collections.abc import Callable
from dataclasses import dataclass, replace
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

from vllm.utils.argparse_utils import FlexibleArgumentParser

# Default sweep values (presumably the CLI defaults wired up in __main__).
DEFAULT_MODELS = list(WEIGHT_SHAPES)
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1,
    16,
    32,
    64,
    128,
    192,
    256,
    320,
    384,
    448,
    512,
    640,
    768,
    896,
    1024,
    2048,
    3072,
    4096,
    5120,
    6144,
    7168,
    8192,
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]


# Utilities
def dtype_to_str(dtype: torch.dtype):
    """
    Return the short tag used in benchmark descriptions for ``dtype``:
    "f16", "bf16" or "f32". Any other dtype raises ValueError.
    """
    known_tags = (
        (torch.float16, "f16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "f32"),
    )
    for candidate, tag in known_tags:
        if dtype == candidate:
            return tag
    raise ValueError(f"Unsupported dtype {dtype}")


75
76
77
def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    """
    Make a random stack of LoRA weights of shape (num_loras, n, k).

    Each per-LoRA (n, k) matrix is the column-major layout of a k x n weight,
    matching the B layout used by OpType.matmul_shapes. Values are uniform
    in [0, 1).
    """
    # LoRA weights column major. Allocate directly on the target device
    # instead of building on the CPU and copying.
    return torch.rand((num_loras, n, k), dtype=dtype, device=device)


def make_rand_tensors(
    a_shape: tuple[int, ...],
    b_shape: tuple[int, ...],
    c_shape: tuple[int, ...],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices for A x B = C.

    Returns:
        A: uniform random tensor of shape ``a_shape`` / dtype ``a_dtype``.
        Bs: ``num_slices`` independent random LoRA weight tensors, each of
            shape ``b_shape`` / dtype ``b_dtype`` (column-major LoRA layout).
        C: zero-initialized output of shape ``c_shape`` / dtype ``c_dtype``.

    All tensors are allocated directly on ``device``.
    """
    A = torch.rand(a_shape, dtype=a_dtype, device=device)

    # LoRA weights column major
    Bs = [
        torch.rand(b_shape, dtype=b_dtype, device=device) for _ in range(num_slices)
    ]

    C = torch.zeros(c_shape, dtype=c_dtype, device=device)
    return A, Bs, C


104
105
106
def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    Map every prompt to a LoRA ID in range [0, num_active_loras), where 0
    refers to the first lora, 1 refers to the second lora and so on.

    If ``sort_by_lora_id`` is False, IDs are drawn uniformly at random.
    Otherwise prompts are split into contiguous runs of
    ``max(num_prompts // num_active_loras, 1)`` prompts per LoRA, in
    increasing LoRA ID order. Any remainder goes to the last active LoRA.

    Returns an int64 tensor of length ``num_prompts`` on ``device``.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Honor `device` here too, so both branches return tensors on the
        # same device.
        return torch.randint(
            0, num_active_loras, (num_prompts,), dtype=torch.long, device=device
        )

    # Divide LoRAs equally and in order: prompt i belongs to run
    # i // part_size, capped at the last active LoRA.
    part_size = max(num_prompts // num_active_loras, 1)
    lora_ids = (
        torch.arange(num_prompts, dtype=torch.long, device=device) // part_size
    )
    return lora_ids.clamp_(max=num_active_loras - 1)


def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_len_tensor.

    Prompt ``b`` contributes ``seq_len_tensor[b]`` consecutive tokens, all
    mapped to LoRA ``prompt_lora_mapping[b]``. Returns an int64 tensor of
    length ``num_tokens`` on ``device``.

    Raises AssertionError if the first ``num_prompts`` sequence lengths do not
    add up to ``num_tokens``.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts
    seq_lens = seq_len_tensor[:num_prompts].to(device="cpu", dtype=torch.long)
    assert seq_lens.shape[0] == num_prompts
    # Require an exact match instead of silently zero-padding a short mapping
    # or growing the mapping past num_tokens.
    assert int(seq_lens.sum()) == num_tokens

    # token to lora index mapping
    token_lora_mapping = torch.repeat_interleave(
        prompt_lora_mapping.to(device="cpu"), seq_lens
    )
    return token_lora_mapping.to(dtype=torch.long, device=device)


155
156
157
158
159
160
161
def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: list[torch.Tensor],
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
162
    add_inputs: bool | None,
163
):
164
165
166
167
168
169
170
171
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.
    """
    batches = seq_lens_cpu.size(0)
    out_list = []
    current_offset = 0
    for lora_index, b_length in zip(range(batches), seq_lens_cpu):
172
        x = input[current_offset : b_length + current_offset, :]
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
        current_offset += b_length
        w = lora_weights[prompt_lora_mapping_cpu[lora_index]]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)


class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """Parse "lora_shrink" / "lora_expand" (case-insensitive)."""
        key = s.lower()
        if key == "lora_shrink":
            return OpType.LORA_SHRINK
        if key == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self is OpType.LORA_SHRINK

    def is_expand_fn(self) -> bool:
        return self is OpType.LORA_EXPAND

    def num_slices(self) -> list[int]:
        """Slice counts to sweep over (both ops handle slices in-kernel)."""
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        """
        Return (m, k, n) of the per-slice A(m x k) x B(k x n) matmul.

        m is always num_tokens = batch_size * seq_length. Shrink maps
        hidden_size -> lora_rank; expand maps lora_rank -> hidden_size.
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            return num_tokens, hidden_size, lora_rank
        assert self.is_expand_fn()
        return num_tokens, lora_rank, hidden_size

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        assert self.is_expand_fn()
        return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self is OpType.LORA_SHRINK:
            # LoRA shrink kernels support num_slices inherently in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self is OpType.LORA_EXPAND:
            # LoRA expand kernels support num_slices inherently in the kernel
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the Triton kernel entry point benchmarked for this op."""
        if self is OpType.LORA_SHRINK:
            return lora_shrink
        if self is OpType.LORA_EXPAND:
            return lora_expand
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> None:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing. ``output`` is written
        in place; nothing is returned.
        """
        num_slices = len(lora_weights)
        if self is OpType.LORA_SHRINK:
            # output is (num_slices, m, n): one grouped gemm per slice, all
            # sharing the same input.
            for slice_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[slice_idx, :],
                    input=input,
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        elif self is OpType.LORA_EXPAND:
            # input is (num_slices, m, k) and output is (m, n * num_slices):
            # slice i writes its own n-wide column band of the output.
            # The input is cast to the weight dtype because matmul_dtypes()
            # makes the expand input float32.
            w_dtype = lora_weights[0].dtype
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset : slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")
305
306
307
308
309
310
311


@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """
312

313
314
315
316
317
318
319
    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
320
321
    seq_length: int | None = None
    num_slices: int | None = None  # num_slices for slice based ops
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.seq_length = seq_length
        return ctx

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        ctx = copy.copy(self)
        ctx.num_slices = num_slices
        return ctx

    def bench_label(self) -> str:
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
337
338
339
        m, k, n = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
340
        desc = {
341
342
343
344
345
346
347
348
            "bs": self.batch_size,
            "sl": self.seq_length,
            "m": m,
            "k": k,
            "n": n,
            "num_loras": self.num_loras,
            "sort_by_lora": self.sort_by_lora_id,
            "num_slices": self.num_slices,
349
350
351
352
353
354
355
356
357
        }
        return json.dumps(desc)


@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks.

    The layouts of input / lora_weights_lst / output depend on the OpType the
    tensors were made for (see OpType.matmul_shapes). The as_*_kwargs helpers
    assert the expected layout before building kernel kwargs.
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        """Return short dtype tags as "<input>x<weight>=><output>"."""
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        """
        Build one randomly initialized set of benchmark tensors for op_type.

        The matmul tensors are allocated on ``device``. The metadata tensors
        are kept on the CPU for further processing and are moved next to the
        matmul tensors by to_device() before benchmarking.

        ctx.seq_length and ctx.num_slices must already be set.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size,
            ctx.seq_length,
            ctx.hidden_size,
            ctx.lora_rank,
            ctx.num_loras,
            ctx.num_slices,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        # Pass `device` through so the tensors land where the caller asked.
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape,
            b_shape,
            c_shape,
            a_type,
            b_type,
            c_type,
            num_slices=ctx.num_slices,
            device=device,
        )

        # Make metadata tensors (kept on the CPU).
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Every sequence has exactly ctx.seq_length tokens.
        seq_len_tensor = torch.full(
            (ctx.batch_size,), ctx.seq_length, dtype=torch.long
        )
        assert total_tokens == seq_len_tensor.sum()

        # Prompt -> LoRA ID, then expanded to token -> LoRA ID.
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens,
            ctx.batch_size,
            prompt_lora_indices_tensor,
            seq_len_tensor,
            "cpu",
        )

        # Make LoRAKernelMeta
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor,
            lora_weights,
            output_tensor,
            lora_kernel_meta,
            seq_len_tensor,
            prompt_lora_indices_tensor,
        )

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.

        Checks that the sequence lengths add up to the number of tokens in
        ``input`` (its second-to-last dim), that there is one prompt->LoRA
        entry per sequence, and one token->LoRA entry per token.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device.

        Moves the matmul tensors, the correctness metadata and every
        LoRAKernelMeta field except ``no_lora_flag_cpu``. The weight list is
        updated in place, so references to it stay valid.
        """

        def _move(tensor: torch.Tensor) -> torch.Tensor:
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = _move(self.input)
        self.output = _move(self.output)
        self.seq_lens = _move(self.seq_lens)
        self.prompt_lora_mapping = _move(self.prompt_lora_mapping)
        for i, weight in enumerate(self.lora_weights_lst):
            self.lora_weights_lst[i] = _move(weight)

        # LoRA meta
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            # no_lora_flag_cpu is deliberately not moved; judging by its name
            # it is meant to stay on the host (NOTE(review): confirm).
            if field_name != "no_lora_flag_cpu":
                field = _move(field)
            setattr(self.lora_kernel_meta, field_name, field)

    def metadata(self) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices.
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
        """
        Validate the shrink layout, move everything next to ``input`` and
        return the keyword arguments for lora_shrink.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "scaling": 1.0,
            "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
        }

    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """
        Validate the expand layout, move everything next to ``input`` and
        return the keyword arguments for lora_expand.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "offset_start": 0,
            "add_inputs": add_inputs,
            "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
        }

    def bench_fn_kwargs(
        self, op_type: OpType, add_inputs: bool | None = None
    ) -> dict[str, Any]:
        """
        Return the kernel kwargs for op_type. add_inputs must be None for
        shrink ops and a bool for expand ops.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs()
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(
        self, op_type: OpType, expand_fn_add_inputs: bool | None
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.

        Both the kernel and the reference start from a zeroed output, so the
        comparison also holds for add_inputs=True. self.output is left
        holding the kernel result.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")

        # Zero first, then snapshot, so the kernel and the reference start
        # from identical output contents.
        self.output.zero_()
        ref_output = self.output.clone()

        op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


611
612
613
614
def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: int | None = None,
    expand_fn_add_inputs: bool | None = None,
    test_correctness: bool = False,
) -> TMeasurement:
    """
    Benchmark the Triton LoRA kernel for ``op_type`` under ``ctx``.

    ``arg_pool_size`` independent tensor sets are created and merged into
    ArgPool arguments for Bench. ``expand_fn_add_inputs`` must be None for
    shrink ops and a bool for expand ops. With ``test_correctness``, every
    tensor set is first checked against the torch reference implementation.
    A truthy ``cuda_graph_nops`` benchmarks inside a CUDA Graph.
    """
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation; report which argument set failed.
    if test_correctness:
        for idx, bt in enumerate(bench_tensors):
            assert bt.test_correctness(op_type, expand_fn_add_inputs), (
                f"{op_type.name} correctness check failed for argument set "
                f"{idx} (add_inputs={expand_fn_add_inputs})"
            )

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    merged_kwargs = {
        name: ArgPool([kw[name] for kw in kwargs_list]) for name in kwargs_list[0]
    }

    describe_args = (
        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
    )
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)

    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        op_type.bench_fn(),
        **merged_kwargs,
    ) as bench:
        return bench.run()


679
680
681
682
def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: int | None = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens have the same LoRA ID, the LoRA kernels are just
    a matmul. This torch.mm benchmark serves as a roofline for that case.

    input op_type is used in determining the m, k, n dimensions for the matmul.
    n is multiplied by ctx.num_slices so the roofline produces as many output
    elements as the sliced LoRA op. ctx.seq_length and ctx.num_slices must be
    set.
    """
    assert ctx.num_slices is not None, "ctx.num_slices must be set"
    dtype = ctx.dtype

    m, k, n = op_type.mkn(ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank)
    # For a fairer comparison.
    n = n * ctx.num_slices

    # Get matmul input and output tensors for A x B = C, allocated directly
    # on the GPU.
    As, Bs, Cs = [], [], []
    for _ in range(arg_pool_size):
        As.append(torch.rand((m, k), dtype=dtype, device="cuda"))
        # Transposed view: B is stored column major, like the LoRA weights.
        Bs.append(torch.rand((n, k), dtype=dtype, device="cuda").t())
        Cs.append(torch.rand((m, n), dtype=dtype, device="cuda"))

    # Make torch.mm kwargs
    mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)}

    tag = dtype_to_str(dtype)
    description = f"single-lora roofline using torch.mm ({tag}x{tag}=>{tag})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)

    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        **mm_kwargs,
    ) as bench:
        return bench.run()


# runner
def use_cuda_graph_recommendation() -> str:
    """Return the advice printed when benchmarking without CUDA Graphs."""
    return """
            Triton kernels have a significant launch overhead when
            launched directly via python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
            """


748
def print_timers(timers: list[TMeasurement], args: argparse.Namespace | None = None):
    """
    Print a torch.utils.benchmark comparison table for ``timers`` followed by
    notes on how to read it.

    When ``args`` is given and ``args.cuda_graph_nops`` is set, also remind
    the reader that each timing covers that many consecutive invocations.
    """
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        nops = args.cuda_graph_nops
        print(
            f"Note : The timings reported above are for {nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {nops} for single invocation "
            "timings."
        )

    print(
        "Note on Comparison with torch.mm : The torch.mm numbers are "
        "benchmark numbers of a simple matmul emulating the single lora "
        "case. It is provided as a roofline for comparing our LoRA Kernel "
        "implementations. It is expected that the LoRA kernels will be "
        "slower than torch.mm in cases where num_loras is big. But for "
        "small num_loras the goal should be to match the torch.mm numbers."
    )
768
769


770
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """
    Run the benchmark sweep over bench_ctxs x args.seq_lengths x
    args.op_types x each op's num_slices.

    Each point benchmarks torch.mm as a roofline and then the LoRA op, once
    per add_inputs setting for expand ops. Results are printed per sequence
    length and in total. If args.output_directory is set, they are also
    pickled to a timestamped file there; the directory, including any missing
    parents, is created as needed.
    """
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    bench_ops: list[OpType] = args.op_types
    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices
                    )
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(
                            _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
                        )
                    )

                    # Benchmark bench_op; shrink ops take no add_inputs flag.
                    expand_fn_add_inputs = (
                        [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump. exist_ok avoids the exists()/mkdir() race and
        # parents=True allows nested, not-yet-existing directories.
        od = Path(args.output_directory)
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with pkl_file.open("wb") as f:
            pickle.dump(timers, f)


830
831
832
def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    """
    Build one BenchmarkContext per combination of args.batch_sizes,
    hidden_sizes, lora_ranks, args.num_loras and args.sort_by_lora_id, in
    itertools.product order.

    ``hidden_sizes`` may be any iterable (run_model_bench passes a set).
    num_active_loras is args.num_active_loras when set, else num_loras.
    seq_length and num_slices are left as None, to be filled in per
    benchmarked op.
    """
    return [
        BenchmarkContext(
            batch_size=batch_size,
            hidden_size=hidden_size,
            lora_rank=lora_rank,
            num_loras=num_loras,
            num_active_loras=args.num_active_loras or num_loras,
            # To be filled based on the OpType to benchmark
            seq_length=None,
            sort_by_lora_id=sort_by_lora_id,
            dtype=args.dtype,
            # To be filled based on the OpType to benchmark
            num_slices=None,
        )
        for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product(
            args.batch_sizes,
            list(hidden_sizes),
            lora_ranks,
            args.num_loras,
            args.sort_by_lora_id,
        )
    ]


def run_list_bench(args: argparse.Namespace):
    """Benchmark the explicit args.hidden_sizes x args.lora_ranks lists."""
    print(args)

    print(
        "List bench :\n"
        f"  Hidden Sizes {args.hidden_sizes}"
        f"  LoRA Ranks {args.lora_ranks}"
    )

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_range_bench(args: argparse.Namespace):
    """Benchmark hidden sizes and LoRA ranks generated from CLI ranges.

    Each ``--*-start`` / ``--*-end`` / ``--*-increment`` triple is expanded
    with ``range(start, end + 1, increment)``, so ``end`` itself is included
    whenever it lies on the step grid.

    Args:
        args: Parsed ``range_bench`` CLI namespace.
    """
    print(args)

    hidden_sizes = list(
        range(
            args.hidden_sizes_start,
            args.hidden_sizes_end + 1,
            args.hidden_sizes_increment,
        )
    )
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment)
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    """Benchmark the hidden sizes implied by the selected models' weights.

    Hidden sizes are collected from ``WEIGHT_SHAPES`` for every combination
    of ``args.models`` and ``args.tp_sizes``, de-duplicated, and swept
    against ``args.lora_ranks``.

    Args:
        args: Parsed ``model_bench`` CLI namespace.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        """Return the distinct ``KN[1]`` values of ``model``'s weights after
        dividing each weight's ``tp_split_dim`` dimension by ``tp_size``."""
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Shard a copy. WEIGHT_SHAPES is module-global; writing the divided
            # value back into it made every later tp_size (or repeated model)
            # divide an already-divided shape, e.g. tp-sizes 2 4 yielded /8.
            sharded_kn = list(KN)
            sharded_kn[tp_split_dim] //= tp_size
            hidden_sizes.add(sharded_kn[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes |= hidden_sizes_from_model(model_name, tp_size)

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts. Sorted so the sweep runs in ascending,
    # deterministic order rather than set-iteration order.
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=sorted(hidden_sizes), lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        """Parse a ``--dtype`` CLI string into a torch dtype.

        Only "torch.float16" and "torch.bfloat16" are accepted; anything else
        raises ValueError, which argparse reports as an invalid value.
        """
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """Parse a CLI flag value: "true"/"1" (any case) -> True, else False."""
        return s.lower() in ["true", "1"]

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the arguments shared by every *_bench subcommand on ``p``."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        # NOTE: adjacent string literals are concatenated verbatim, so each
        # fragment below ends with an explicit space.
        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "when set profiling is done using cudagraph, "
                "with the given number of operations in a graph. "
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument(
            "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
        )
        p.add_argument(
            "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
        )
        p.add_argument(
            "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
        )
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help=(
                "Output directory to store the list of benchmarking "
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            help=(
                "When enabled, the benchmarking functions are tested "
                "for correctness before the actual benchmarking"
            ),
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16  --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16   --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # list_bench: explicit lists of hidden sizes and LoRA ranks.
    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument(
        "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
    )
    list_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    # range_bench: hidden sizes and LoRA ranks given as start/end/increment.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: hidden sizes derived from WEIGHT_SHAPES of known models.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)