benchmark_lora.py 47.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
import argparse
import copy
import json
import pickle
import time
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
12
from typing import Any, Callable, Optional
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
26
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from vllm.utils import FlexibleArgumentParser

# Default values for the benchmark parameters.
# NOTE(review): how each list is consumed (CLI defaults, sweep axes, ...) is
# not visible in this part of the file - confirm against the argument parser.

# Every model name that has an entry in WEIGHT_SHAPES.
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]
DEFAULT_BATCH_SIZES = [
    1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024,
    2048, 3072, 4096, 5120, 6144, 7168, 8192
]
DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384]
DEFAULT_LORA_RANKS = [16]
DEFAULT_NUM_LORAS = [1, 2, 3, 4]
# presumably toggles the `sort_by_lora_id` behaviour of
# make_prompt_lora_mapping - TODO confirm
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
# presumably the `add_inputs` values tried for expand ops - TODO confirm
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]


# Utilities
def dtype_to_str(dtype: torch.dtype):
    """
    Return the short label used for a floating point torch dtype.

    float16 -> "f16", bfloat16 -> "bf16", float32 -> "f32".
    Raises ValueError for every other dtype.
    """
    known_labels = (
        (torch.float16, "f16"),
        (torch.bfloat16, "bf16"),
        (torch.float32, "f32"),
    )
    for known_dtype, label in known_labels:
        if dtype == known_dtype:
            return label
    raise ValueError(f"Unsupported dtype {dtype}")


def make_rand_lora_weight_tensor(k: int,
                                 n: int,
                                 num_loras: int,
                                 dtype: torch.dtype,
                                 device: str = "cuda") -> torch.Tensor:
    """
    Build a random stacked LoRA weight tensor of shape (num_loras, n, k).

    The values are drawn with torch.rand and the result is then moved to
    `device`.
    """
    # LoRA weights are column major: each LoRA is stored as (n, k).
    weight_shape = (num_loras, n, k)
    weights = torch.rand(weight_shape, dtype=dtype)
    return weights.to(device)


def make_rand_tensors(
    a_shape: tuple[int],
    b_shape: tuple[int],
    c_shape: tuple[int],
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    num_slices: int,
    device: str = "cuda",
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
    """
    Make LoRA input/output matrices.

    Returns (A, Bs, C) where A is a random tensor of shape `a_shape`, Bs is a
    list of `num_slices` random tensors of shape `b_shape` and C is a zero
    tensor of shape `c_shape`. All tensors are moved to `device`.
    """
    # A is drawn before the Bs; keep this order so that a seeded RNG yields
    # the same tensors from run to run.
    A = torch.rand(a_shape, dtype=a_dtype).to(device)

    # LoRA weights column major
    Bs = [
        torch.rand(b_shape, dtype=b_dtype).to(device)
        for _ in range(num_slices)
    ]

    C = torch.zeros(c_shape, dtype=c_dtype).to(device)
    return A, Bs, C


def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
                             sort_by_lora_id: bool,
                             device: str) -> torch.Tensor:
    """
    All prompts are mapped to a LoRA ID in range [0, num_active_loras).
    where 0 refers to first lora, 1 refers to second lora and so on.

    When `sort_by_lora_id` is False the IDs are drawn at random. Otherwise
    the prompts are split into equally sized, consecutive groups, one per
    LoRA, and any leftover prompts are assigned to the last LoRA.

    Returns a torch.long tensor of shape (num_prompts, ) on `device`.
    """
    assert num_active_loras > 0

    if not sort_by_lora_id:
        # Draw with the default generator and then honour `device`, so that
        # both branches return a tensor on the requested device.
        return torch.randint(0,
                             num_active_loras, (num_prompts, ),
                             dtype=torch.long).to(device)

    # Divide LoRAs equally and in order.
    part_size = max(num_prompts // num_active_loras, 1)
    last_lora_id = num_active_loras - 1
    # Prompts beyond the last full group stay on the last LoRA.
    prompt_lora_mapping = [
        min(prompt_idx // part_size, last_lora_id)
        for prompt_idx in range(num_prompts)
    ]
    return torch.tensor(prompt_lora_mapping, dtype=torch.long, device=device)


def make_token_lora_mapping(num_tokens: int, num_prompts: int,
                            prompt_lora_mapping: torch.Tensor,
                            seq_len_tensor: torch.Tensor, device: str):
    """
    Expand the per-prompt LoRA IDs into a per-token LoRA ID tensor.

    Prompt `i` contributes `seq_len_tensor[i]` consecutive tokens that all
    carry `prompt_lora_mapping[i]`. Token positions not covered by any prompt
    keep the LoRA ID 0. Returns a torch.long tensor on `device`.
    """
    assert prompt_lora_mapping.shape[0] == num_prompts

    per_token_lora = [0] * num_tokens
    start = 0
    for prompt_idx in range(num_prompts):
        lora_id = prompt_lora_mapping[prompt_idx].item()
        prompt_len = seq_len_tensor[prompt_idx].item()
        stop = start + prompt_len
        per_token_lora[start:stop] = [lora_id] * prompt_len
        start = stop

    return torch.tensor(per_token_lora, dtype=torch.long, device=device)


def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
                   lora_weights: list[torch.Tensor],
                   seq_lens_cpu: torch.Tensor,
                   prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
                   add_inputs: Optional[bool]):
    """
    Torch group gemm reference implementation to test correctness of
    benchmarking operations.

    The rows of `input` are consumed sequence by sequence, using the lengths
    in `seq_lens_cpu`. Each sequence is multiplied with the weight selected
    by `prompt_lora_mapping_cpu` and scaled by `scaling`. The concatenated
    result is accumulated into `ref_out` when `add_inputs` is truthy and
    copied into `ref_out` otherwise. Nothing is returned.
    """
    batches = seq_lens_cpu.size(0)
    out_list = []
    current_offset = 0
    for batch_idx, b_length in zip(range(batches), seq_lens_cpu):
        x = input[current_offset:b_length + current_offset, :]
        current_offset += b_length
        # The prompt mapping is indexed by sequence and yields the LoRA ID.
        w = lora_weights[prompt_lora_mapping_cpu[batch_idx]]
        result = torch.nn.functional.linear(x, w)
        result *= scaling
        out_list.append(result)

    cat_result = torch.cat(out_list, dim=0)

    if add_inputs:
        ref_out += cat_result
    else:
        ref_out.copy_(cat_result)


class OpType(Enum):
    """
    LoRA Ops to benchmark and its properties.
    """
    SGMV_SHRINK = auto()
    BGMV_SHRINK = auto()
    SGMV_EXPAND = auto()
    BGMV_EXPAND = auto()
    BGMV_EXPAND_SLICE = auto()
    V1_SHRINK = auto()
    V1_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
        """
        Convert a case-insensitive member name (e.g. "sgmv_shrink") into the
        matching OpType. Raises ValueError for unknown names.
        """
        wanted = s.lower()
        for op_type in OpType:
            if op_type.name.lower() == wanted:
                return op_type
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        """True for the shrink ops."""
        return self in (OpType.SGMV_SHRINK, OpType.BGMV_SHRINK,
                        OpType.V1_SHRINK)

    def is_expand_fn(self) -> bool:
        """
        True for the expand ops. BGMV_EXPAND_SLICE is not included here, see
        is_expand_slice_fn().
        """
        return self in (OpType.SGMV_EXPAND, OpType.BGMV_EXPAND,
                        OpType.V1_EXPAND)

    def is_prefill_op(self) -> bool:
        """True for the SGMV and V1 ops."""
        return self in (OpType.SGMV_SHRINK, OpType.SGMV_EXPAND,
                        OpType.V1_SHRINK, OpType.V1_EXPAND)

    def is_decode_op(self) -> bool:
        """True for the BGMV and V1 ops."""
        return self in (OpType.BGMV_SHRINK, OpType.BGMV_EXPAND,
                        OpType.BGMV_EXPAND_SLICE, OpType.V1_SHRINK,
                        OpType.V1_EXPAND)

    def is_expand_slice_fn(self) -> bool:
        """True only for BGMV_EXPAND_SLICE."""
        return self in (OpType.BGMV_EXPAND_SLICE, )

    def num_slices(self) -> list[int]:
        """
        Return the slice counts listed for this op.
        Raises ValueError for an unrecognized op.
        """
        if self in (OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
                    OpType.V1_EXPAND):
            # SGMV kernels and v1 kernels supports slices
            return [1, 2, 3]
        if self in (OpType.BGMV_SHRINK, OpType.BGMV_EXPAND):
            return [1]
        if self in (OpType.BGMV_EXPAND_SLICE, ):
            return [2, 3]
        raise ValueError(f"Unrecognized OpType {self}")

    def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int) -> tuple[int, int, int]:
        """
        Return the (m, k, n) matmul dimensions of this op.

        m is always batch_size * seq_length. Shrink ops reduce hidden_size
        (k) to lora_rank (n); expand ops go from lora_rank (k) to
        hidden_size (n).
        """
        num_tokens = batch_size * seq_length
        if self.is_shrink_fn():
            return num_tokens, hidden_size, lora_rank
        assert self.is_expand_fn() or self.is_expand_slice_fn()
        return num_tokens, lora_rank, hidden_size

    def matmul_dtypes(
            self, op_dtype: torch.dtype
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
        """
        return a type, b type and c type for A x B = C
        """
        if self.is_shrink_fn():
            # Shrink ops write a float32 output.
            return op_dtype, op_dtype, torch.float32
        assert self.is_expand_fn() or self.is_expand_slice_fn()
        # Expand ops read a float32 input.
        return torch.float32, op_dtype, op_dtype

    def matmul_shapes(
            self, batch_size: int, seq_length: int, hidden_size: int,
            lora_rank: int, num_loras: int,
            num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
        """
        Given num_slices, return the shapes of the A, B, and C matrices
        in A x B = C, for the op_type
        """
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

        b_shape = (num_loras, n, k)  # col-major

        if self in (OpType.SGMV_SHRINK, OpType.V1_SHRINK):
            # SGMV shrink and V1 shrink kernels support num_slices inherently
            # in the kernel: one output plane per slice.
            return ((m, k), b_shape, (num_slices, m, n))
        if self in (OpType.SGMV_EXPAND, OpType.V1_EXPAND,
                    OpType.BGMV_EXPAND_SLICE):
            # One input plane per slice; the slices are laid out side by side
            # along the last output dimension.
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        if self in (OpType.BGMV_SHRINK, OpType.BGMV_EXPAND):
            return ((m, k), b_shape, (m, n))

        raise ValueError(f"Unrecognized op_type {self}")

    def bench_fn(self) -> Callable:
        """
        Return the callable that is benchmarked for this op.
        Raises ValueError for an unrecognized op.
        """

        def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
            # One bgmv_expand_slice call per slice.
            for x in kwargs_list:
                bgmv_expand_slice(**x)

        dispatch = {
            OpType.SGMV_SHRINK: sgmv_shrink,
            OpType.SGMV_EXPAND: sgmv_expand,
            OpType.BGMV_SHRINK: bgmv_shrink,
            OpType.BGMV_EXPAND: bgmv_expand,
            OpType.BGMV_EXPAND_SLICE: emulate_bgmv_expand_slice,
            OpType.V1_SHRINK: v1_shrink,
            OpType.V1_EXPAND: v1_expand,
        }
        fn = dispatch.get(self)
        if fn is None:
            raise ValueError(f"Unrecognized optype {self}")
        return fn

    def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                           lora_weights: list[torch.Tensor],
                           **kwargs) -> None:
        """Each benchmark operation expects the input, lora_weights and outputs
           in a slightly different format. Refer to self.matmul_shapes().
           run_ref_group_gemm accounts for those differences in executing a
           reference group gemm for correctness testing.

           The result is written into `output`; nothing is returned. The
           remaining keyword arguments are forwarded to ref_group_gemm().
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self in (OpType.SGMV_SHRINK, OpType.V1_SHRINK):
            # One output plane per slice, all slices share the input.
            for slice_idx in range(num_slices):
                ref_group_gemm(ref_out=output[slice_idx, :],
                               input=input,
                               lora_weights=lora_weights[slice_idx],
                               **kwargs)
        elif self in (OpType.SGMV_EXPAND, OpType.V1_EXPAND,
                      OpType.BGMV_EXPAND_SLICE):
            # One input plane per slice; every slice owns a hidden_size wide
            # column range of the output.
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
                ref_group_gemm(
                    ref_out=output[:, slice_offset:slice_offset + hidden_size],
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
        elif self == OpType.BGMV_SHRINK:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input,
                           lora_weights=lora_weights[0],
                           **kwargs)
        elif self == OpType.BGMV_EXPAND:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input.clone().to(dtype=w_dtype),
                           lora_weights=lora_weights[0],
                           **kwargs)
        else:
            raise ValueError(f"Unrecognized optype {self}")
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411


@dataclass
class BenchmarkContext:
    """
    LoRA benchmark context
    """
    batch_size: int
    hidden_size: int
    num_loras: int
    num_active_loras: int
    lora_rank: int
    sort_by_lora_id: bool
    dtype: torch.dtype
    seq_length: Optional[int] = None
    num_slices: Optional[int] = None  # num_slices for slice based ops

    def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with `seq_length` set."""
        clone = copy.copy(self)
        clone.seq_length = seq_length
        return clone

    def with_num_slices(self, num_slices: int) -> "BenchmarkContext":
        """Return a shallow copy of this context with `num_slices` set."""
        clone = copy.copy(self)
        clone.num_slices = num_slices
        return clone

    def bench_label(self) -> str:
        """Return the benchmark label, "lora-<dtype>"."""
        return f"lora-{self.dtype}"

    def bench_sublabel(self, op_type: OpType) -> str:
        """
        Return a JSON string describing this context together with the
        (m, k, n) matmul dimensions `op_type` derives from it.
        """
        m, k, n = op_type.mkn(self.batch_size, self.seq_length,
                              self.hidden_size, self.lora_rank)
        # Keyword order is the key order of the emitted JSON.
        return json.dumps(
            dict(bs=self.batch_size,
                 sl=self.seq_length,
                 m=m,
                 k=k,
                 n=n,
                 num_loras=self.num_loras,
                 sort_by_lora=self.sort_by_lora_id,
                 num_slices=self.num_slices))


@dataclass
class BenchmarkTensors:
    """
    Input/Output tensors used for benchmarks
    """
    # matmul tensors
    input: torch.Tensor
    lora_weights_lst: list[torch.Tensor]
    output: torch.Tensor
    # metadata tensors
    seq_lens: torch.Tensor
    seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    token_lora_mapping: torch.Tensor
    # v1 kernel metadata
    v1_kernel_meta: Optional[V1KernelMeta] = None

    def io_types(self) -> str:
        """Describe the operand dtypes as "<input>x<weight>=><output>"."""
        return (f"{dtype_to_str(self.input.dtype)}x"
                f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>"
                f"{dtype_to_str(self.output.dtype)}")

    @staticmethod
    def make(ctx: BenchmarkContext,
             op_type: OpType,
             device: str = "cuda") -> "BenchmarkTensors":
        """
        Create the tensors needed to benchmark `op_type` under `ctx`.

        The matmul tensors are created on `device`. The metadata tensors are
        created on the CPU.
        """
        # Make input / output matmul tensors.
        a_shape, b_shape, c_shape = op_type.matmul_shapes(
            ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank,
            ctx.num_loras, ctx.num_slices)
        a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
        input_tensor, lora_weights, output_tensor = make_rand_tensors(
            a_shape,
            b_shape,
            c_shape,
            a_type,
            b_type,
            c_type,
            num_slices=ctx.num_slices,
            device=device)

        # Make metadata tensors.
        # Keep the metadata tensors in the CPU for further processing if needed.
        # The tensors get moved to the GPU before benchmarking.
        assert ctx.num_active_loras <= ctx.num_loras
        total_tokens = ctx.batch_size * ctx.seq_length

        # Prepare seq lens tensor. Every sequence has length ctx.seq_length.
        seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
                                       (ctx.batch_size, ))
        # Prepare seq_start_loc tensor (exclusive prefix sum of seq lens).
        seq_start_loc_tensor = torch.cumsum(torch.tensor(
            [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
                                            dim=0)
        assert total_tokens == seq_len_tensor.sum()
        # Prepare prompt lora indices tensor
        prompt_lora_indices_tensor = make_prompt_lora_mapping(
            ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
        # Prepare token lora indices tensor
        token_lora_indices_tensor = make_token_lora_mapping(
            total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
            seq_len_tensor, "cpu")

        v1_kernel_meta = None
        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
            v1_kernel_meta = V1KernelMeta.make(
                max_loras=ctx.num_loras,
                max_num_tokens=token_lora_indices_tensor.size(0),
                device="cpu")
            v1_kernel_meta.prepare_tensors(
                token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                seq_len_tensor, seq_start_loc_tensor,
                                prompt_lora_indices_tensor,
                                token_lora_indices_tensor, v1_kernel_meta)

    def sanity_check(self) -> None:
        """
        Fails asserts when non-conformality is detected.
        """
        num_tokens = self.input.shape[-2]
        # check metadata tensors
        assert torch.sum(self.seq_lens) == num_tokens
        num_seqs = self.seq_lens.shape[0]
        assert self.seq_start_loc.shape[0] == num_seqs
        assert self.prompt_lora_mapping.shape[0] == num_seqs
        assert self.token_lora_mapping.shape[0] == num_tokens

    def to_device(self, device: str) -> None:
        """
        Transfer tensors to device if the tensors aren't already on the device

        `device` may be a device string or a torch.device.
        """
        # Normalize once: a torch.device never compares equal to a plain str,
        # which would make the "already there" check below ineffective.
        target = torch.device(device)

        def to_device(tensor: torch.Tensor):
            if tensor.device != target:
                tensor = tensor.to(device=target)
            return tensor

        self.input = to_device(self.input)
        self.output = to_device(self.output)
        self.seq_lens = to_device(self.seq_lens)
        self.seq_start_loc = to_device(self.seq_start_loc)
        self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
        self.token_lora_mapping = to_device(self.token_lora_mapping)
        self.lora_weights_lst = [to_device(w) for w in self.lora_weights_lst]

        # v1 meta
        if self.v1_kernel_meta is not None:
            for field_name in V1KernelMeta.__dataclass_fields__:
                field = getattr(self.v1_kernel_meta, field_name)
                assert isinstance(field, torch.Tensor)
                setattr(self.v1_kernel_meta, field_name, to_device(field))

    def metadata(self) -> tuple[int, int, int, int]:
        """
        Return num_seqs, num_tokens, max_seq_len and num_slices
        """
        num_seqs = self.seq_lens.shape[0]
        num_tokens = self.token_lora_mapping.shape[0]
        max_seq_len = torch.max(self.seq_lens).item()
        num_slices = len(self.lora_weights_lst)
        return num_seqs, num_tokens, max_seq_len, num_slices

    def convert_to_sgmv_benchmark_tensors(self):
        """
        For sgmv punica kernels, when consecutive sequences have the
        same LoRA ID, we just merge them together.
        This happens in punica.py::compute_metadata
        """

        # Collapse seq_lens and seq_start_loc
        _, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
                                               return_counts=True)
        cum_result = torch.cumsum(seq_lens, dim=0)
        seq_start_loc = torch.zeros_like(seq_lens)
        seq_start_loc[1:].copy_(cum_result[:-1])

        # Collapse prompt mapping
        prompt_lora_mapping = torch.unique_consecutive(
            self.prompt_lora_mapping)

        assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
         f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"

        self.prompt_lora_mapping = prompt_lora_mapping.to(
            dtype=self.prompt_lora_mapping.dtype)
        self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
        self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)

    def _assert_shrink_shapes(self, num_tokens: int, num_slices: int,
                              sliced_output: bool) -> None:
        """Assert the input / weight / output shapes expected by shrink ops."""
        i_shape = self.input.shape
        lw_shape = self.lora_weights_lst[0].shape
        o_shape = self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        if sliced_output:
            # Expected output shape [num_slices, num_tokens, lora_rank]
            assert len(o_shape) == 3
            assert o_shape == (num_slices, num_tokens, lora_rank)
        else:
            # Expected output shape [num_tokens, lora_rank]
            assert len(o_shape) == 2
            assert o_shape == (num_tokens, lora_rank)

    def _assert_expand_shapes(self, num_tokens: int, num_slices: int,
                              sliced_input: bool) -> int:
        """
        Assert the input / weight / output shapes expected by expand ops and
        return the hidden size read from the weights.
        """
        i_shape = self.input.shape
        lw_shape = self.lora_weights_lst[0].shape
        o_shape = self.output.shape
        if sliced_input:
            # Expected input shape [num_slices, num_tokens, lora_rank]
            assert len(i_shape) == 3
            assert i_shape[0] == num_slices
            assert i_shape[1] == num_tokens
            lora_rank = i_shape[2]
        else:
            # Expected input shape [num_tokens, lora_rank]
            assert len(i_shape) == 2
            assert i_shape[0] == num_tokens
            lora_rank = i_shape[1]
        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape [num_tokens, hidden_size * num_slices]
        # (a plain [num_tokens, hidden_size] for the unsliced case).
        out_width = hidden_size * num_slices if sliced_input else hidden_size
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, out_width)
        return hidden_size

    def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments passed to sgmv_shrink."""
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
        self._assert_shrink_shapes(num_tokens, num_slices, sliced_output=True)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'scaling': 1.0,
        }

    def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """Return the keyword arguments passed to sgmv_expand."""
        self.convert_to_sgmv_benchmark_tensors()
        self.sanity_check()
        self.to_device(self.input.device)

        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
        self._assert_expand_shapes(num_tokens, num_slices, sliced_input=True)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'b_seq_start_loc': self.seq_start_loc,
            'seq_len_tensor': self.seq_lens,
            'lora_indices_tensor': self.prompt_lora_mapping,
            'batches': num_seqs,
            'max_seq_length': max_seq_len,
            'token_nums': num_tokens,
            'offset_start': 0,
            'add_inputs': add_inputs,
        }

    def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments passed to bgmv_shrink."""
        assert len(self.lora_weights_lst) == 1
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()
        self._assert_shrink_shapes(num_tokens, num_slices, sliced_output=False)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'scaling': 1.0
        }

    def as_bgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """Return the keyword arguments passed to bgmv_expand."""
        assert len(self.lora_weights_lst) == 1
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()
        self._assert_expand_shapes(num_tokens, num_slices, sliced_input=False)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst[0],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'add_inputs': add_inputs
        }

    def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """
        Return {'kwargs_list': [...]} holding one bgmv_expand_slice keyword
        dict per slice. Every slice writes its own hidden_size wide column
        range of the shared output tensor.
        """
        _, num_tokens, _, num_slices = self.metadata()
        hidden_size = self._assert_expand_shapes(num_tokens,
                                                 num_slices,
                                                 sliced_input=True)

        self.to_device(self.input.device)

        kwargs_list = [{
            'inputs': self.input[i],
            'lora_b_weights': self.lora_weights_lst[i],
            'output_tensor': self.output,
            'lora_indices_tensor': self.token_lora_mapping,
            'slice_offset': i * hidden_size,
            'slice_size': hidden_size,
            'add_inputs': add_inputs,
        } for i in range(num_slices)]
        return {'kwargs_list': kwargs_list}

    def as_v1_shrink_kwargs(self) -> dict[str, Any]:
        """Return the keyword arguments passed to v1_shrink."""
        assert self.v1_kernel_meta is not None
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()
        self._assert_shrink_shapes(num_tokens, num_slices, sliced_output=True)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
            'token_indices_sorted_by_lora_ids':
            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
            'lora_ids': self.v1_kernel_meta.active_lora_ids,
            'scaling': 1.0,
        }

    def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        """Return the keyword arguments passed to v1_expand."""
        assert self.v1_kernel_meta is not None
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()
        self._assert_expand_shapes(num_tokens, num_slices, sliced_input=True)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
            'token_indices_sorted_by_lora_ids':
            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
            'lora_ids': self.v1_kernel_meta.active_lora_ids,
            'offset_start': 0,
            'add_inputs': add_inputs,
        }

    def bench_fn_kwargs(self,
                        op_type: OpType,
                        add_inputs: Optional[bool] = None) -> dict[str, Any]:
        """
        Return the keyword arguments for the bench function of `op_type`.

        `add_inputs` must be None for shrink ops and must be given for all
        other ops. Raises ValueError for an unrecognized op.
        """
        if op_type.is_shrink_fn():
            assert add_inputs is None
        else:
            assert add_inputs is not None

        if op_type == OpType.SGMV_SHRINK:
            return self.as_sgmv_shrink_kwargs()
        if op_type == OpType.SGMV_EXPAND:
            return self.as_sgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_SHRINK:
            return self.as_bgmv_shrink_kwargs()
        if op_type == OpType.BGMV_EXPAND:
            return self.as_bgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_EXPAND_SLICE:
            return self.as_bgmv_expand_slice_kwargs(add_inputs)
        if op_type == OpType.V1_SHRINK:
            return self.as_v1_shrink_kwargs()
        if op_type == OpType.V1_EXPAND:
            return self.as_v1_expand_kwargs(add_inputs)
        # Report the offending op, not this tensor container.
        raise ValueError(f"Unrecognized optype {op_type}")

    def test_correctness(self, op_type: OpType,
                         expand_fn_add_inputs: Optional[bool]) -> bool:
        """
        Test correctness of op_type implementation against a grouped gemm
        reference implementation.
        """
        # Snapshot the metadata before bench_fn_kwargs() runs: the sgmv
        # conversions rewrite seq_lens / prompt_lora_mapping in place.
        seq_lens_cpu = self.seq_lens.to(device="cpu")
        prompt_lora_mapping_cpu = self.prompt_lora_mapping.to(device="cpu")
        ref_output = self.output.clone()

        self.output.zero_()
        op_type.bench_fn()(
            **self.bench_fn_kwargs(op_type, expand_fn_add_inputs))

        op_type.run_ref_group_gemm(
            ref_output,
            self.input,
            self.lora_weights_lst,
            seq_lens_cpu=seq_lens_cpu,
            prompt_lora_mapping_cpu=prompt_lora_mapping_cpu,
            scaling=1.0,
            add_inputs=expand_fn_add_inputs)

        rtol, atol = {
            torch.float16: (6e-2, 6e-2),
            torch.bfloat16: (6e-2, 6e-2),
            torch.float32: (1e-2, 1e-2),
        }[self.output.dtype]

        return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol)


def bench_optype(ctx: BenchmarkContext,
                 arg_pool_size: int,
                 op_type: OpType,
                 cuda_graph_nops: Optional[int] = None,
                 expand_fn_add_inputs: Optional[bool] = None,
                 test_correctness: bool = False) -> TMeasurement:
    """
    Benchmark the LoRA kernel selected by op_type.

    Builds arg_pool_size BenchmarkTensors from ctx, optionally checks each of
    them for correctness, converts them into kernel kwargs, merges those
    kwargs into ArgPool-s and times the kernel through Bench.

    Args:
        ctx: Benchmark configuration the tensors are created from.
        arg_pool_size: Number of independent argument sets to create.
            Must be >= 1.
        op_type: The operation to benchmark.
        cuda_graph_nops: When set (and non-zero), it is passed to
            CudaGraphBenchParams and handed to Bench.
        expand_fn_add_inputs: Must be None for shrink ops and not None
            otherwise; forwarded as the add_inputs argument.
        test_correctness: When True, every argument set is verified with
            BenchmarkTensors.test_correctness before benchmarking.

    Returns:
        The measurement returned by Bench.run().
    """

    assert arg_pool_size >= 1
    if op_type.is_shrink_fn():
        assert expand_fn_add_inputs is None
    else:
        assert expand_fn_add_inputs is not None

    # BenchmarkContext -> BenchmarkTensors
    bench_tensors: list[BenchmarkTensors] = [
        BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
    ]
    for bt in bench_tensors:
        bt.sanity_check()

    # Test correctness of our implementation.
    if test_correctness:
        assert all([
            bt.test_correctness(op_type, expand_fn_add_inputs)
            for bt in bench_tensors
        ])

    # BenchmarkTensors -> dict (kwargs)
    kwargs_list = [
        bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
        for bt in bench_tensors
    ]

    # Clear LoRA optimization hash-maps.
    _LORA_A_PTR_DICT.clear()
    _LORA_B_PTR_DICT.clear()
    # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
    for kwargs in kwargs_list:
        op_type.bench_fn()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool: every
    # keyword maps to the list of values it takes across the argument sets.
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    describe_args = (f"add_inputs={expand_fn_add_inputs}"
                     if expand_fn_add_inputs is not None else "")
    description = (
        f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})")

    cuda_graph_params = None
    if cuda_graph_nops:
        cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops)
    with Bench(cuda_graph_params,
               ctx.bench_label(), ctx.bench_sublabel(op_type), description,
               op_type.bench_fn(), **kwargs) as bench:
        return bench.run()


def bench_torch_mm(ctx: BenchmarkContext,
                   arg_pool_size: int,
                   op_type: OpType,
                   cuda_graph_nops: Optional[int] = None) -> TMeasurement:
    """
    Time a plain torch.mm to serve as a roofline for the LoRA kernels.

    If every input token maps to the same LoRA ID, a LoRA kernel reduces to a
    single matmul, so torch.mm provides the reference number for that case.

    op_type is only consulted to derive the m, k, n dimensions of the matmul.
    """
    dtype = ctx.dtype
    m, k, n = op_type.mkn(ctx.batch_size, ctx.seq_length, ctx.hidden_size,
                          ctx.lora_rank)
    # Scale n by the number of slices for a fairer comparison.
    n *= ctx.num_slices

    # One (A, B, C) triple per pool entry for A x B = C. B is created as
    # (n, k) and transposed afterwards.
    operands = [(torch.rand((m, k), dtype=dtype).to("cuda"),
                 torch.rand((n, k), dtype=dtype).to("cuda").t(),
                 torch.rand((m, n), dtype=dtype).to("cuda"))
                for _ in range(arg_pool_size)]

    # torch.mm keyword arguments, each wrapped in an ArgPool.
    mm_kwargs = dict(input=ArgPool([a for a, _, _ in operands]),
                     mat2=ArgPool([b for _, b, _ in operands]),
                     out=ArgPool([c for _, _, c in operands]))

    dtype_name = dtype_to_str(dtype)
    description = (f"single-lora roofline using torch.mm ({dtype_name}"
                   f"x{dtype_name}"
                   f"=>{dtype_name})")

    cuda_graph_params = (CudaGraphBenchParams(cuda_graph_nops)
                         if cuda_graph_nops else None)
    with Bench(cuda_graph_params, ctx.bench_label(),
               ctx.bench_sublabel(op_type), description, torch.mm,
               **mm_kwargs) as bench:
        return bench.run()


# runner
def use_cuda_graph_recommendation() -> str:
    """
    Return the help text recommending the `--cuda-graph-nops` option.

    The text is printed when CUDA Graphs are not enabled and is embedded in
    the command line parser description, so its layout (leading newline and
    indentation) is kept as-is.
    """
    return """
            Triton kernels have a significant launch overhead when
            launched directly via python. This overhead is more noticeable
            for small problem sizes. For these cases, it is recommended
            to use the script with `--cuda-graph-nops N` to benchmark N
            consecutive invocations of the benchmarking operations from 
            inside a CUDA Graph. Note that the returned measurement is for N 
            invocations of the operation.
            """


963
def print_timers(timers: list[TMeasurement],
                 args: Optional[argparse.Namespace] = None):
    """
    Print a comparison table of the given measurements followed by notes.

    Args:
        timers: Measurements to render with torch.utils.benchmark.Compare.
        args: Parsed command line arguments. When given and
            args.cuda_graph_nops is set (non-zero), a reminder is printed that
            the timings cover that many consecutive invocations.
    """
    compare = TBenchmark.Compare(timers)
    compare.print()

    if args and args.cuda_graph_nops:
        print(
            f"Note : The timings reported above is for {args.cuda_graph_nops} "
            "consecutive invocations of the benchmarking functions. "
            f"Please divide by {args.cuda_graph_nops} for single invocation "
            "timings.")

    # This note is printed unconditionally.
    print("Note on Comparison with torch.mm : The torch.mm numbers are "
          "benchmark numbers of a simple matmul emulating the single lora "
          "case. It is provided as a roofline for comparing our LoRA Kernel "
          "implementations. It is expected that the LoRA kernels will be "
          "slower than torch.mm in cases where num_loras is big. But for "
          "small num_loras the goal should be to match the torch.mm numbers.")


983
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    """
    Run the benchmarks described by args for every benchmark context.

    For each context and sequence length, every selected op type is
    benchmarked for each of its num_slices values: first a torch.mm roofline,
    then the op itself (once per add_inputs setting for non-shrink ops). The
    results are printed per sequence length and once more in aggregate. When
    args.output_directory is set, all measurements are pickled into that
    directory.

    Args:
        args: Parsed command line arguments.
        bench_ctxs: Benchmark contexts; seq_length and num_slices are filled
            in here for every run.
    """

    if args.cuda_graph_nops is not None:
        assert args.cuda_graph_nops > 0
        print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA "
              "Graph")
    else:
        print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}")

    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types
            if seq_len > 1:
                # bench only prefill ops
                bench_ops = [op for op in args.op_types if op.is_prefill_op()]

            seq_len_timers = []
            for bench_op in bench_ops:
                for num_slices in bench_op.num_slices():
                    _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices(
                        num_slices)
                    # Benchmark torch.mm as a roofline
                    seq_len_timers.append(
                        bench_torch_mm(_ctx, args.arg_pool_size, bench_op,
                                       args.cuda_graph_nops))

                    # Benchmark bench_op. Shrink ops take no add_inputs
                    # argument, so they are benchmarked once with None.
                    expand_fn_add_inputs = [
                        None
                    ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
                    for add_input_arg in expand_fn_add_inputs:
                        seq_len_timers.append(
                            bench_optype(_ctx, args.arg_pool_size, bench_op,
                                         args.cuda_graph_nops, add_input_arg,
                                         args.test_correctness))

            print_timers(seq_len_timers)
            timers.extend(seq_len_timers)

    # Result stdout dump
    print("== All Results ====")
    print_timers(timers, args)

    if args.output_directory:
        # Result file dump
        od = Path(args.output_directory)
        # Create missing parent directories as well and tolerate an already
        # existing directory.
        od.mkdir(parents=True, exist_ok=True)

        timestamp = int(time.time())
        pkl_file = od / f"lora_bench-{timestamp}.pkl"
        print(f"Writing benchmarks to {pkl_file}")
        with open(pkl_file, "wb") as f:
            pickle.dump(timers, f)


1040
1041
def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
                          args: argparse.Namespace) -> list[BenchmarkContext]:
    """
    Build one BenchmarkContext per parameter combination.

    The combinations are the cartesian product of args.batch_sizes,
    hidden_sizes, lora_ranks, args.num_loras and args.sort_by_lora_id.

    Args:
        hidden_sizes: Hidden sizes to benchmark; converted with list() before
            being used in the product.
        lora_ranks: LoRA ranks to benchmark.
        args: Parsed command line arguments providing the remaining
            dimensions, the dtype and num_active_loras.

    Returns:
        The list of contexts, with seq_length and num_slices left as None.
    """

    ctxs: list[BenchmarkContext] = []
    for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product(  # noqa
            args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
            args.sort_by_lora_id):
        ctxs.append(
            BenchmarkContext(
                batch_size=batch_size,
                hidden_size=hidden_size,
                lora_rank=lora_rank,
                num_loras=num_loras,
                # Falls back to num_loras when num_active_loras is unset
                # (None) or 0.
                num_active_loras=args.num_active_loras
                if args.num_active_loras else num_loras,
                # To be filled based on the OpType to benchmark
                seq_length=None,
                sort_by_lora_id=sort_by_lora_id,
                dtype=args.dtype,
                # To be filled based on the OpType to benchmark
                num_slices=None))

    return ctxs


def run_list_bench(args: argparse.Namespace):
    """
    Entry point of the `list_bench` sub-command.

    Benchmarks the explicitly listed args.hidden_sizes and args.lora_ranks.
    """
    print(args)

    print("List bench :\n"
          f"  Hidden Sizes {args.hidden_sizes}"
          f"  LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)

    run(args, bench_contexts)


def run_range_bench(args: argparse.Namespace):
    """
    Entry point of the `range_bench` sub-command.

    Hidden sizes and LoRA ranks are generated from the start / end / increment
    arguments; the end values are inclusive.
    """
    print(args)

    # "+ 1" makes the end values part of the ranges.
    hidden_sizes = list(
        range(args.hidden_sizes_start, args.hidden_sizes_end + 1,
              args.hidden_sizes_increment))
    lora_ranks = list(
        range(args.lora_ranks_start, args.lora_ranks_end + 1,
              args.lora_ranks_increment))

    print("Range bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)

    run(args, bench_contexts)


def run_model_bench(args: argparse.Namespace):
    """
    Entry point of the `model_bench` sub-command.

    The hidden sizes to benchmark are collected from the WEIGHT_SHAPES entries
    of every model in args.models, for every tensor-parallel size in
    args.tp_sizes.
    """
    print(args)

    def hidden_sizes_from_model(model: str, tp_size: int) -> set[int]:
        """Return the set of KN[1] values of `model` after the TP split."""
        hidden_sizes = set()
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Work on a copy: writing into the WEIGHT_SHAPES entry itself
            # would make the division compound across successive calls for
            # the same model (e.g. with several --tp-sizes values).
            KN = list(KN)
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            hidden_sizes.add(KN[1])
        return hidden_sizes

    # Get all hidden sizes
    hidden_sizes: set[int] = set()
    for model_name, tp_size in product(args.models, args.tp_sizes):
        hidden_sizes = hidden_sizes.union(
            hidden_sizes_from_model(model_name, tp_size))

    print("Model bench :\n"
          f" Hidden Sizes {hidden_sizes}"
          f" LoRA Ranks {args.lora_ranks}")

    # Get all benchmarking contexts
    bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
        hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)

    run(args, bench_contexts)


if __name__ == '__main__':
    # Command line entry point: builds a parser with the three sub-commands
    # list_bench, range_bench and model_bench and dispatches to the function
    # registered for the selected sub-command.

    def to_torch_dtype(dt):
        """Map the command line dtype string to the torch dtype."""
        if dt == "torch.float16":
            return torch.float16
        if dt == "torch.bfloat16":
            return torch.bfloat16
        raise ValueError("unsupported dtype")

    def get_bool(s: str) -> bool:
        """Return True for 'true' (any casing) or '1', False otherwise."""
        return s.lower() in ['true', '1']

    def add_common_command_args(p: argparse.ArgumentParser):
        """Register the arguments shared by all sub-commands on parser p."""
        p.add_argument(
            "--dtype",
            type=to_torch_dtype,
            required=True,
            help="Available options are ['torch.float16', 'torch.bfloat16']")

        # Adjacent string literals are concatenated verbatim, so every
        # fragment of a multi-part help text ends with an explicit space.
        p.add_argument(
            "--arg-pool-size",
            type=int,
            default=32,
            help="Run profiles with a pool of input/output/meta tensors instead "
            "of simply reusing the same tensors for all runs. A bigger arg-pool "
            "mitigates hardware caching effects during benchmarking.")

        p.add_argument(
            "--cuda-graph-nops",
            type=int,
            help=("when set profiling is done using cudagraph, "
                  "with the given number of operations in a graph. "
                  "Note that the measurement returned is the time "
                  "taken for N consecutive executions of the benchmarking "
                  "functions, where N is the value of this argument."))
        p.add_argument("--num-loras",
                       nargs="+",
                       type=int,
                       default=DEFAULT_NUM_LORAS)
        p.add_argument("--num-active-loras",
                       type=int,
                       default=None,
                       help="Active LoRAs. When None, all LoRAs are active")
        p.add_argument("--sort-by-lora-id",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_SORT_BY_LORA_IDS)
        p.add_argument("--op-types",
                       nargs="+",
                       type=OpType.from_str,
                       default=list(OpType))
        p.add_argument('--seq-lengths',
                       nargs="+",
                       type=int,
                       default=DEFAULT_SEQ_LENGTHS)
        p.add_argument("--batch-sizes",
                       nargs="+",
                       type=int,
                       default=DEFAULT_BATCH_SIZES)
        p.add_argument("--expand-fn-add-inputs",
                       nargs="+",
                       type=get_bool,
                       default=DEFAULT_EXPAND_FN_ADD_INPUTS)
        p.add_argument(
            '-o',
            '--output-directory',
            type=str,
            help=("Output directory to store the list of benchmarking "
                  "TMeasurement objects as a pickle file"))

        p.add_argument(
            "--test-correctness",
            action='store_true',
            help=("When enabled, the benchmarking functions are tested "
                  "for correctness before the actual benchmarking"))

    parser = FlexibleArgumentParser(
        description=f"""
Benchmark LoRA kernels:
    {use_cuda_graph_recommendation()}

    list_bench example:
        python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

    model_bench example:
        python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16  --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 

    range_bench example:
        python3 benchmarks/kernels/benchmark_lora.py range_bench  --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16   --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    # list_bench: explicit lists of hidden sizes and LoRA ranks.
    list_parser = subparsers.add_parser("list_bench")
    list_parser.add_argument("--hidden-sizes",
                             nargs="+",
                             type=int,
                             default=DEFAULT_HIDDEN_SIZES)
    list_parser.add_argument("--lora-ranks",
                             nargs="+",
                             type=int,
                             default=DEFAULT_LORA_RANKS)
    add_common_command_args(list_parser)
    list_parser.set_defaults(func=run_list_bench)

    # range_bench: hidden sizes and LoRA ranks given as start/end/increment.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--hidden-sizes-start", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-end", type=int, required=True)
    range_parser.add_argument("--hidden-sizes-increment",
                              type=int,
                              required=True)
    range_parser.add_argument("--lora-ranks-start", type=int, required=True)
    range_parser.add_argument("--lora-ranks-end", type=int, required=True)
    range_parser.add_argument("--lora-ranks-increment",
                              type=int,
                              required=True)
    add_common_command_args(range_parser)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: hidden sizes derived from the WEIGHT_SHAPES of the models.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--lora-ranks",
                              nargs="+",
                              type=int,
                              default=DEFAULT_LORA_RANKS)
    add_common_command_args(model_parser)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    # Dispatch to the function registered by the chosen sub-command.
    args.func(args)