benchmark_lora.py 34.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
import argparse
import copy
import json
import pickle
import time
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
12
from typing import Any, Callable, Optional
13
14
15
16
17
18
19

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

20
21
22
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
23
24
    from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
25

26
27
28
29
30
from vllm.utils import FlexibleArgumentParser

# Default sweep parameters used when the corresponding CLI flags are omitted.
DEFAULT_MODELS = list(WEIGHT_SHAPES)
DEFAULT_TP_SIZES = [1]
# 1, 16, 32, then multiples of 64 up to 512, multiples of 128 up to 1024 and
# multiples of 1024 up to 8192.
DEFAULT_BATCH_SIZES = (
    [1, 16, 32]
    + list(range(64, 512 + 1, 64))
    + list(range(640, 1024 + 1, 128))
    + list(range(2048, 8192 + 1, 1024))
)
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]


# Utilities
def dtype_to_str(dtype: torch.dtype):
    """
    Return the short tag used in benchmark descriptions for ``dtype``.

    Supported dtypes: float16 -> "f16", bfloat16 -> "bf16", float32 -> "f32".

    Raises:
        ValueError: for any other dtype.
    """
    short_names = {
        torch.float16: "f16",
        torch.bfloat16: "bf16",
        torch.float32: "f32",
    }
    name = short_names.get(dtype)
    if name is None:
        raise ValueError(f"Unsupported dtype {dtype}")
    return name


73
74
75
def make_rand_lora_weight_tensor(
    k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda"
) -> torch.Tensor:
    """
    Make a random stacked LoRA weight tensor of shape (num_loras, n, k).

    Each LoRA's weight for a (k -> n) matmul is stored as (n, k), i.e.
    "column major" with respect to the matmul. Values are uniform in [0, 1).

    The tensor is sampled directly on ``device``, avoiding a host-side
    allocation followed by a host-to-device copy.
    """
    # LoRA weights column major
    return torch.rand((num_loras, n, k), dtype=dtype, device=device)


def make_rand_tensors(
    a_shape: tuple[int, ...],
    b_shape: tuple[int, ...],
    c_shape: tuple[int, ...],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices for A x B = C.

    Args:
        a_shape, b_shape, c_shape: shapes of A, of each B slice, and of C.
        a_dtype, b_dtype, c_dtype: dtypes of A, of each B slice, and of C.
        num_slices: number of independent B tensors to create.
        device: device all tensors are allocated on.

    Returns:
        (A, Bs, C): A is uniform random in [0, 1), Bs is a list of
        ``num_slices`` independent uniform random tensors of shape
        ``b_shape``, and C is zero-initialised.

    All tensors are created directly on ``device`` rather than on the host
    and then copied.
    """
    A = torch.rand(a_shape, dtype=a_dtype, device=device)

    # LoRA weights column major
    Bs = [
        torch.rand(b_shape, dtype=b_dtype, device=device) for _ in range(num_slices)
    ]

    C = torch.zeros(c_shape, dtype=c_dtype, device=device)
    return A, Bs, C


102
103
104
def make_prompt_lora_mapping(
    num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str
) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.

    If ``sort_by_lora_id`` is False, each prompt gets a uniformly random ID.
    Otherwise prompts are split into contiguous runs of
    ``max(num_prompts // num_active_loras, 1)`` prompts per ID, in increasing
    ID order; any remainder is assigned to the last ID
    (``num_active_loras - 1``), and when there are fewer prompts than active
    LoRAs the higher IDs are simply unused.

    Returns an int64 tensor of length ``num_prompts`` on ``device`` (for both
    branches).
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        return torch.randint(
            0, num_active_loras, (num_prompts,), dtype=torch.long, device=device
        )

    # Divide LoRAs equally and in order: prompt i falls in chunk
    # i // part_size; chunks past the last active LoRA reuse its ID.
    part_size = max(num_prompts // num_active_loras, 1)
    chunk_ids = torch.arange(num_prompts, dtype=torch.long, device=device) // part_size
    return chunk_ids.clamp_(max=num_active_loras - 1)


def make_token_lora_mapping(
    num_tokens: int,
    num_prompts: int,
    prompt_lora_mapping: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    device: str,
):
    """
    Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor

    Prompt ``i`` contributes ``seq_len_tensor[i]`` consecutive tokens, each
    mapped to ``prompt_lora_mapping[i]``. The sequence lengths of the first
    ``num_prompts`` prompts must sum to exactly ``num_tokens``.

    Returns an int64 tensor of length ``num_tokens`` on ``device``.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    seq_lens = seq_len_tensor[:num_prompts].to(
        device=prompt_lora_mapping.device, dtype=torch.long
    )
    # The sequences must tile the token range exactly; otherwise the mapping
    # would be silently zero-padded or longer than num_tokens.
    assert int(seq_lens.sum()) == num_tokens, (
        f"sequence lengths sum to {int(seq_lens.sum())}, expected {num_tokens}"
    )

    # token to lora index mapping: repeat each prompt's LoRA ID per token.
    token_lora_mapping = torch.repeat_interleave(prompt_lora_mapping, seq_lens)
    return token_lora_mapping.to(device=device, dtype=torch.long)


153
154
155
156
157
158
159
160
161
def ref_group_gemm(
    ref_out: torch.Tensor,
    input: torch.Tensor,
    lora_weights: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    prompt_lora_mapping_cpu: torch.Tensor,
    scaling: float,
    add_inputs: Optional[bool],
):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.

    The rows of ``input`` are split into consecutive sequences whose lengths
    are given by ``seq_lens_cpu``. Sequence ``i`` is multiplied with
    ``lora_weights[prompt_lora_mapping_cpu[i]]`` via ``F.linear`` (i.e.
    ``x @ w.T``) and scaled by ``scaling``.

    Args:
        ref_out: destination; overwritten with the result, or accumulated
            into when ``add_inputs`` is truthy.
        input: 2-D ``(num_tokens, k)`` rows ordered by sequence.
        lora_weights: stacked per-LoRA weights indexed by LoRA ID along
            dim 0; each ``lora_weights[id]`` is an ``(n, k)`` matrix.
        seq_lens_cpu: per-sequence token counts.
        prompt_lora_mapping_cpu: per-sequence LoRA ID.
        scaling: multiplier applied to every product.
        add_inputs: accumulate into ``ref_out`` instead of overwriting it.
    """
    out_list = []
    current_offset = 0
    for seq_idx, seq_len in enumerate(seq_lens_cpu.tolist()):
        x = input[current_offset : current_offset + seq_len, :]
        current_offset += seq_len
        w = lora_weights[int(prompt_lora_mapping_cpu[seq_idx])]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)


class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """

    LORA_SHRINK = auto()
    LORA_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """
        Parse "lora_shrink" / "lora_expand" (case-insensitive).

        Raises:
            ValueError: for any other string.
        """
        if s.lower() == "lora_shrink":
            return OpType.LORA_SHRINK
        if s.lower() == "lora_expand":
            return OpType.LORA_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self is OpType.LORA_SHRINK

    def is_expand_fn(self) -> bool:
        return self is OpType.LORA_EXPAND

    def num_slices(self) -> list[int]:
        """Slice counts to benchmark for this op."""
        return [1, 2, 3]

    def mkn(
        self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
    ) -> tuple[int, int, int]:
        """
        Return (m, k, n) of the A(m, k) x B(k, n) matmul for this op.

        Shrink maps hidden_size -> lora_rank; expand maps lora_rank ->
        hidden_size. m is always the number of tokens.
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            m = num_tokens
            k = hidden_size
            n = lora_rank
        else:
            assert self.is_expand_fn()
            m = num_tokens
            k = lora_rank
            n = hidden_size
        return m, k, n

    def matmul_dtypes(
        self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C

        Shrink writes a float32 output; expand reads a float32 input.
        """
        if self.is_shrink_fn():
            return op_dtype, op_dtype, torch.float32
        else:
            assert self.is_expand_fn()
            return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
        self,
        batch_size: int,
        seq_length: int,
        hidden_size: int,
        lora_rank: int,
        num_loras: int,
        num_slices: int,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type.

        B is the shape of a single slice's stacked weights,
        (num_loras, n, k), i.e. column major per LoRA.
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major
        if self is OpType.LORA_SHRINK:
            # LoRA shrink kernels support num_slices inherently in the kernel:
            # one output plane per slice.
            return ((m, k), b_shape, (num_slices, m, n))
        if self is OpType.LORA_EXPAND:
            # LoRA expand kernels support num_slices inherently in the kernel:
            # one input plane per slice, slices concatenated along the output
            # columns.
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """Return the kernel entry point benchmarked for this op."""
        if self == OpType.LORA_SHRINK:
            return lora_shrink
        if self == OpType.LORA_EXPAND:
            return lora_expand

        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        lora_weights: list[torch.Tensor],
        **kwargs,
    ) -> None:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.

        ``output`` is written in place; ``kwargs`` are forwarded to
        ref_group_gemm (seq_lens_cpu, prompt_lora_mapping_cpu, scaling,
        add_inputs).
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self is OpType.LORA_SHRINK:
            # Shared input, one output plane per slice.
            for slice_idx in range(num_slices):
                ref_group_gemm(
                    ref_out=output[slice_idx, :],
                    input=input,
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        elif self is OpType.LORA_EXPAND:
            # One input plane per slice, each writing its own column range.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset : slice_offset + hidden_size],
                    # Expand inputs are float32; match the weight dtype.
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs,
                )
        else:
            raise ValueError(f"Unrecognized optype {self}")
303
304
305
306
307
308
309


@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context: the problem parameters of one benchmark
    configuration.

    ``seq_length`` and ``num_slices`` default to None and are filled in per
    benchmarked op through the ``with_*`` helpers, which return shallow copies
    and leave ``self`` untouched.
    """

    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def _copy_with(self, **overrides) -> "BenchmarkContext":
        # Shallow-copy self, then overwrite the requested attributes.
        clone = copy.copy(self)
        for attr, value in overrides.items():
            setattr(clone, attr, value)
        return clone

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        return self._copy_with(seq_length=seq_length)

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        return self._copy_with(num_slices=num_slices)

    def bench_label(self) -> str:
        return "lora-" + str(self.dtype)

    def bench_sublabel(self, op_type: OpType) -> str:
        """JSON description of this context's problem size for ``op_type``."""
        m, k, n = op_type.mkn(
            self.batch_size, self.seq_length, self.hidden_size, self.lora_rank
        )
        return json.dumps(
            dict(
                bs=self.batch_size,
                sl=self.seq_length,
                m=m,
                k=k,
                n=n,
                num_loras=self.num_loras,
                sort_by_lora=self.sort_by_lora_id,
                num_slices=self.num_slices,
            )
        )


@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks

    Holds one set of matmul operands for a LoRA op (layouts per
    OpType.matmul_shapes()), the LoRAKernelMeta built from the token -> LoRA
    mapping, and the per-sequence metadata used by the reference
    implementation in test_correctness().
    """

    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # LoRA kernel metadata
    lora_kernel_meta: LoRAKernelMeta
    # Metadata tensors used in testing correctness
    seq_lens: torch.Tensor
    prompt_lora_mapping: torch.Tensor

    def io_types(self) -> str:
        """Describe operand dtypes as "<input>x<weight>=><output>"."""
        return (
            f"{dtype_to_str(self.input.dtype)}x"
            f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
            f"{dtype_to_str(self.output.dtype)}"
        )

    @staticmethod
    def make(
        ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
    ) -> "BenchmarkTensors":
        """
        Build random benchmark tensors for ``op_type`` from ``ctx``.

        The matmul tensors are allocated on ``device``. The metadata tensors
        (seq lens, prompt/token LoRA mappings, LoRAKernelMeta) are built on
        the CPU; to_device() moves them later.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size,
            ctx.seq_length,
            ctx.hidden_size,
            ctx.lora_rank,
            ctx.num_loras,
            ctx.num_slices,
        )
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        # Honour the requested device for the matmul tensors.
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape,
            b_shape,
            c_shape,
            a_type,
            b_type,
            c_type,
            num_slices=ctx.num_slices,
            device=device,
        )

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Make metadata tensors involved in correctness testing.
        # Prepare seq lens tensor: randint over [seq_length, seq_length + 1)
        # gives every sequence exactly ctx.seq_length tokens.
        seq_len_tensor = torch.randint(
            ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,)
        )
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu"
        )

        # Make LoRAKernelMeta
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens,
            ctx.batch_size,
            prompt_lora_indices_tensor,
            seq_len_tensor,
            "cpu",
        )
        lora_kernel_meta = LoRAKernelMeta.make(
            max_loras=ctx.num_loras,
            max_num_tokens=token_lora_indices_tensor.size(0),
            device="cpu",
        )
        lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(
            input_tensor,
            lora_weights,
            output_tensor,
            lora_kernel_meta,
            seq_len_tensor,
            prompt_lora_indices_tensor,
        )

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.

        Checks that the sequence lengths sum to the number of input tokens,
        that there is one prompt LoRA ID per sequence and one token LoRA ID
        per token.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str):
        """
        Transfer tensors to device if the tensors aren't already on the device
        """

        def _move(tensor: torch.Tensor):
            if tensor.device != device:
                tensor = tensor.to(device=device)
            return tensor

        self.input = _move(self.input)
        self.output = _move(self.output)
        self.seq_lens = _move(self.seq_lens)
        self.prompt_lora_mapping = _move(self.prompt_lora_mapping)
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = _move(self.lora_weights_lst[i])

        # LoRA meta: every dataclass field is expected to be a tensor.
        for field_name in LoRAKernelMeta.__dataclass_fields__:
            field = getattr(self.lora_kernel_meta, field_name)
            assert isinstance(field, torch.Tensor)
            setattr(self.lora_kernel_meta, field_name, _move(field))

    def metadata(self) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
        """
        Validate shapes and return keyword arguments for lora_shrink.

        All tensors are first moved to the device of ``self.input``.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            "inputs": self.input,
            "lora_a_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "scaling": 1.0,
        }

    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """
        Validate shapes and return keyword arguments for lora_expand.

        All tensors are first moved to the device of ``self.input``.
        """
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = (
            self.input.shape,
            self.lora_weights_lst[0].shape,
            self.output.shape,
        )
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            "inputs": self.input,
            "lora_b_weights": self.lora_weights_lst,
            "output_tensor": self.output,
            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
            "token_indices_sorted_by_lora_ids": (
                self.lora_kernel_meta.token_indices_sorted_by_lora_ids
            ),
            "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora,
            "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
            "lora_ids": self.lora_kernel_meta.active_lora_ids,
            "offset_start": 0,
            "add_inputs": add_inputs,
        }

    def bench_fn_kwargs(
        self, op_type: OpType, add_inputs: Optional[bool] = None
    ) -> dict[str, Any]:
        """
        Return the kernel keyword arguments for ``op_type``.

        ``add_inputs`` must be None for shrink ops and a bool for expand ops.

        Raises:
            ValueError: for an unrecognized op_type.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.LORA_SHRINK:
            return self.as_lora_shrink_kwargs()
        if op_type == OpType.LORA_EXPAND:
            return self.as_lora_expand_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(
        self, op_type: OpType, expand_fn_add_inputs: Optional[bool]
    ) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.

        Overwrites ``self.output``. Returns True when the kernel output is
        allclose to the reference within dtype-dependent tolerances.
        """
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")

        # Start the kernel and the reference from identical (zeroed) outputs
        # so that add_inputs=True accumulates onto the same contents in both.
        self.output.zero_()
        ref_output = self.output.clone()

        op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs,
        )

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


603
604
605
606
607
608
609
610
def bench_optype(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
    expand_fn_add_inputs: Optional[bool] = None,
    test_correctness: bool = False,
) -> TMeasurement:
    """
    Benchmark the LoRA kernel for ``op_type`` on tensors described by ``ctx``.

    ``arg_pool_size`` independent tensor sets are created and passed to the
    benchmark as ArgPools. ``expand_fn_add_inputs`` must be None for shrink
    ops and a bool for expand ops. When ``test_correctness`` is set, every
    tensor set is checked against the reference grouped gemm first.

    Raises:
        AssertionError: if ``test_correctness`` is set and any check fails.
    """
    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    bench_fn = op_type.bench_fn()

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation. This is an explicit check rather
    # than `assert all(...)`: a bare assert is stripped under `python -O`,
    # which would silently skip running the correctness tests altogether.
    if test_correctness:
        failed = [
            idx
            for idx, bt in enumerate(bench_tensors)
            if not bt.test_correctness(op_type, expand_fn_add_inputs)
        ]
        if failed:
            raise AssertionError(
                f"{op_type.name} correctness check failed for arg pool "
                f"entries {failed}: {ctx.bench_sublabel(op_type)}"
            )

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for kwargs in kwargs_list:
        bench_fn(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (
        f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else ""
    )
    description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})"

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)

    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        bench_fn,
        **kwargs,
    ) as bench:
        return bench.run()


671
672
673
674
675
676
def bench_torch_mm(
    ctx: BenchmarkContext,
    arg_pool_size: int,
    op_type: OpType,
    cuda_graph_nops: Optional[int] = None,
) -> TMeasurement:
    """
    Benchmark basic torch.mm as a roofline.

    When all the input tokens share one LoRA ID, the LoRA kernels reduce to a
    plain matmul; this torch.mm measurement is the roofline for that case.
    ``op_type`` only determines the (m, k, n) dimensions of the matmul, with
    n widened by ``ctx.num_slices`` so all slices are covered by one matmul.
    """
    dtype = ctx.dtype
    m, k, n = op_type.mkn(
        ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank
    )
    # For a fairer comparison.
    n *= ctx.num_slices

    # Get matmul input and output tensors for A x B = C, drawn in (A, B, C)
    # order for each pool entry.
    lhs_pool, rhs_pool, out_pool = [], [], []
    for _ in range(arg_pool_size):
        lhs_pool.append(torch.rand((m, k), dtype=dtype).to("cuda"))
        # B is generated as (n, k) and transposed into a column-major view.
        rhs_pool.append(torch.rand((n, k), dtype=dtype).to("cuda").t())
        out_pool.append(torch.rand((m, n), dtype=dtype).to("cuda"))

    tag = dtype_to_str(dtype)
    description = f"single-lora roofline using torch.mm ({tag}x{tag}=>{tag})"

    cuda_graph_params = (
        CudaGraphBenchParams(cuda_graph_nops) if cuda_graph_nops else None
    )

    with Bench(
        cuda_graph_params,
        ctx.bench_label(),
        ctx.bench_sublabel(op_type),
        description,
        torch.mm,
        input=ArgPool(lhs_pool),
        mat2=ArgPool(rhs_pool),
        out=ArgPool(out_pool),
    ) as bench:
        return bench.run()


# runner
def use_cuda_graph_recommendation() -> str:
    """Return the advice printed when CUDA graphs are not enabled."""
    return """
            Triton kernels have a significant launch overhead when
            launched directly via python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from
            inside a CUDA Graph. Note that the returned measurement is for N
            invocations of the operation.
            """


740
def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None):
    """
    Print a comparison table of ``timers`` followed by explanatory notes.

    When ``args`` is given and ``args.cuda_graph_nops`` is set, an extra note
    explains that each timing covers that many consecutive invocations.
    """
    TBenchmark.Compare(timers).print()

    nops = args.cuda_graph_nops if args else None
    if nops:
        print(
            f"Note : The timings reported above is for {nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {nops} for single invocation "
            "timings."
        )

    roofline_note = (
        "Note on Comparison with torch.mm : The torch.mm numbers are "
        "benchmark numbers of a simple matmul emulating the single lora "
        "case. It is provided as a roofline for comparing our LoRA Kernel "
        "implementations. It is expected that the LoRA kernels will be "
        "slower than torch.mm in cases where num_loras is big. But for "
        "small num_loras the goal should be to match the torch.mm numbers."
    )
    print(roofline_note)
760
761


762
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """
    Run all benchmarks for ``bench_ctxs`` and report the results.

    For each context, sequence length, op type and slice count, a torch.mm
    roofline is benchmarked first. The LoRA op itself follows, once per
    ``add_inputs`` value for expand ops. Results are printed per
    (context, seq_len) group and then all together. If
    ``args.output_directory`` is set, the measurements are pickled to
    ``<output_directory>/lora_bench-<unix timestamp>.pkl``.
    """
    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types
            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices
                    )
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(
                            _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops
                        )
                    )

                    # Benchmark bench_op
                    expand_fn_add_inputs = (
                        [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    )
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(
                                _ctx,
                                args.arg_pool_size,
                                bench_op,
                                args.cuda_graph_nops,
                                add_input_arg,
                                args.test_correctness,
                            )
                        )

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # Also create missing parent directories; no-op if od already exists.
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)


822
823
824
def as_benchmark_contexts(
    hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
) -> list[BenchmarkContext]:
    """
    Build one BenchmarkContext per combination of batch size, hidden size,
    LoRA rank, number of LoRAs and sort-by-LoRA-ID flag.

    ``num_active_loras`` is ``args.num_active_loras`` when it is set (truthy)
    and otherwise equals the number of LoRAs. ``seq_length`` and
    ``num_slices`` are left as None; run() fills them in per op.
    """
    combos = product(  # noqa
        args.batch_sizes,
        list(hidden_sizes),
        lora_ranks,
        args.num_loras,
        args.sort_by_lora_id,
    )
    return [
        BenchmarkContext(
            batch_size=bs,
            hidden_size=hs,
            lora_rank=rank,
            num_loras=n_loras,
            num_active_loras=args.num_active_loras or n_loras,
            # To be filled based on the OpType to benchmark
            seq_length=None,
            sort_by_lora_id=sort_ids,
            dtype=args.dtype,
            # To be filled based on the OpType to benchmark
            num_slices=None,
        )
        for bs, hs, rank, n_loras, sort_ids in combos
    ]


def run_list_bench(args: argparse.Namespace):
    """Benchmark the explicit hidden-size and LoRA-rank lists given in args."""
    print(args)

    summary = (
        "List bench :\n"
        f"  Hidden Sizes {args.hidden_sizes}"
        f"  LoRA Ranks {args.lora_ranks}"
    )
    print(summary)

    # Get all benchmarking contexts and run them.
    contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )
    run(args, contexts)


def run_range_bench(args: argparse.Namespace):
    """Benchmark hidden sizes and LoRA ranks drawn from inclusive ranges.

    Hidden sizes are ``start, start + increment, ...`` up to and including
    ``--hidden-sizes-end`` when it is reachable by the increment; LoRA ranks
    are built the same way from the ``--lora-ranks-*`` arguments.

    Raises:
        ValueError: if an increment is not positive or a start exceeds its
            end. Without this check a zero increment fails inside ``range``
            with an unhelpful message, and an inverted range silently
            produces an empty sweep.
    """
    print(args)

    def inclusive_range(name: str, start: int, end: int, step: int) -> list[int]:
        # ``name`` is the CLI option prefix, used only for error messages.
        if step <= 0:
            raise ValueError(f"--{name}-increment must be positive, got {step}")
        if start > end:
            raise ValueError(
                f"--{name}-start ({start}) must not exceed --{name}-end ({end})"
            )
        return list(range(start, end + 1, step))

    hidden_sizes = inclusive_range(
        "hidden-sizes",
        args.hidden_sizes_start,
        args.hidden_sizes_end,
        args.hidden_sizes_increment,
    )
    lora_ranks = inclusive_range(
        "lora-ranks",
        args.lora_ranks_start,
        args.lora_ranks_end,
        args.lora_ranks_increment,
    )

    print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args
    )

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    """Benchmark hidden sizes derived from model weight shapes.

    For every (model, TP size) pair in ``args.models`` x ``args.tp_sizes``,
    each ``(KN, tp_split_dim)`` entry of ``WEIGHT_SHAPES[model]`` has its
    ``tp_split_dim`` dimension divided by the TP size, and the resulting
    ``KN[1]`` is collected as a hidden size. The de-duplicated hidden sizes
    are swept in ascending order together with ``args.lora_ranks``.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Split a copy: WEIGHT_SHAPES is shared module-level data, and
            # dividing it in place would compound the TP split across
            # successive (model, tp_size) pairs.
            split_kn = list(KN)
            split_kn[tp_split_dim] //= tp_size
            hidden_sizes.add(split_kn[1])
        return hidden_sizes

    # Get all hidden sizes; sorted so the sweep order is deterministic and
    # matches the list[int] that as_benchmark_contexts expects.
    hidden_size_set: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_size_set |= hidden_sizes_from_model(model_name, tp_size)
    hidden_sizes = sorted(hidden_size_set)

    print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args
    )

    run(args, bench_contexts)


if __name__ == "__main__":
    # CLI entry point: three sub-commands (list_bench, range_bench,
    # model_bench) that share a common set of sweep/benchmark arguments.

    def to_torch_dtype(dt):
        """argparse ``type=`` converter: map a dtype name to a torch dtype."""
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """argparse ``type=`` converter for boolean values.

        Accepts ``true``/``1`` and ``false``/``0`` (case-insensitive). Any
        other value is rejected rather than silently being read as False.
        """
        value = s.lower()
        if value in ("true", "1"):
            return True
        if value in ("false", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected one of true/false/1/0, got {s!r}"
        )

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the arguments shared by every benchmark sub-command."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']",
        )

        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.",
        )

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=(
                "When set, profiling is done using cudagraph, "
                "with the given number of operations in a graph. "
                "Note that the measurement returned is the time "
                "taken for N consecutive executions of the benchmarking "
                "functions, where N is the value of this argument."
            ),
        )
        p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS)
        p.add_argument(
            "--num-active-loras",
            type=int,
            default=None,
            help="Active LoRAs. When None, all LoRAs are active",
        )
        p.add_argument(
            "--sort-by-lora-id",
            nargs="+",
            type=get_bool,
            default=DEFAULT_SORT_BY_LORA_IDS,
        )
        p.add_argument(
            "--op-types", nargs="+", type=OpType.from_str, default=list(OpType)
        )
        p.add_argument(
            "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS
        )
        p.add_argument(
            "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
        )
        p.add_argument(
            "--expand-fn-add-inputs",
            nargs="+",
            type=get_bool,
            default=DEFAULT_EXPAND_FN_ADD_INPUTS,
        )
        p.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help=(
                "Output directory to store the list of benchmarking "
                "TMeasurement objects as a pickle file"
            ),
        )

        p.add_argument(
            "--test-correctness",
            action="store_true",
            help=(
                "When enabled, the benchmarking functions are tested "
                "for correctness before the actual benchmarking"
            ),
        )

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16  --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16   --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # list_bench: explicit lists of hidden sizes and LoRA ranks.
    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument(
        "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES
    )
    list_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    # range_bench: hidden sizes and LoRA ranks given as start/end/increment.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment", type=int, required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment", type=int, required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: hidden sizes derived from WEIGHT_SHAPES for models/TP sizes.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS
    )
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)