# SPDX-License-Identifier: Apache-2.0
from threading import Lock

import pytest
import torch

import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform

from .utils import PunicaTensors, assert_close, generate_data_for_nslices


# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
        nslices: int, inputs_tensor: torch.Tensor,
        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
        num_tokens: int, scaling: float) -> None:
    """
    Wrapper around torch_ops.sgmv_shrink that handles any nslices.

    torch_ops.sgmv_shrink is invoked once per slice. Slice ``index`` uses
    ``lora_weights_lst[index]`` as its weights and is handed
    ``out_tensor[index]`` as its output; ``inputs_tensor``, the sequence
    metadata and ``scaling`` are shared by every slice.

    Nothing is returned, so the op is expected to fill ``out_tensor[index]``
    in place. With ``nslices == 0`` no call is made.
    """
    for index in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[index],
            out_tensor[index],
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            scaling,
        )


def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
                            inputs_tensor: torch.Tensor,
                            lora_weights_lst: list[torch.Tensor],
                            out_tensor: torch.Tensor,
                            b_seq_start_loc: torch.Tensor,
                            seq_len_tensor: torch.Tensor,
                            prompt_lora_mapping: torch.Tensor, batches: int,
                            max_seq_length: int, num_tokens: int,
                            add_inputs: bool) -> None:
    """
    Wrapper around torch_ops.sgmv_expand that handles any nslices.

    With a single slice, torch_ops.sgmv_expand is called on
    ``inputs_tensor[0]`` / ``lora_weights_lst[0]`` with the whole
    ``out_tensor`` as its output.

    Otherwise torch_ops.sgmv_expand_slice is called once per slice with
    ``inputs_tensor[index]`` / ``lora_weights_lst[index]``. Every call receives
    the same ``out_tensor`` plus a running ``slice_offset`` that starts at 0
    and advances by ``hidden_size`` after each slice; ``hidden_size`` is also
    passed as the slice size.

    Nothing is returned, so the ops are expected to fill ``out_tensor`` in
    place. With ``nslices == 0`` no call is made.
    """
    if nslices == 1:
        # Single slice: exercise the plain (non-sliced) torch sgmv_expand op.
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            add_inputs=add_inputs,
        )
        return

    slice_offset = 0
    for index in range(nslices):
        torch_ops.sgmv_expand_slice(
            inputs_tensor[index],
            lora_weights_lst[index],
            out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            slice_offset,
            hidden_size,
            add_inputs=add_inputs,
        )
        # The next slice starts right after the one just processed.
        slice_offset += hidden_size


# Held by the kernel checks in this file while they clear the Triton pointer
# cache (_LORA_A_PTR_DICT / _LORA_B_PTR_DICT) and launch the kernel, so the
# clear-then-launch pair is not interleaved with another check.
_dict_lock = Lock()


def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             scaling: float) -> None:
    """
    Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
    kernels.

    Test data is produced by generate_data_for_nslices for the "shrink" op
    from ``batches``, ``hidden_size``, ``num_loras``, ``rank``,
    ``seq_length``, ``nslices``, ``dtype`` and ``device``. The Triton kernel
    writes into a clone of ``data.our_out_tensor``; the reference
    implementation writes into ``data.ref_out_tensor``. ``scaling`` is passed
    to both. The two outputs are then compared with assert_close, which is
    expected to fail the test on a mismatch.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "shrink",
        device,
    )
    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the device
    # requested by the caller (rather than a hard-coded 'cuda') so it lives
    # on the same device as the generated test data.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # Clear the cached pointer table before launching so the kernel cannot
    # reuse an entry cached by an earlier case; the lock keeps the
    # clear-then-launch pair together.
    with _dict_lock:
        # lora_shrink kernel
        _LORA_A_PTR_DICT.clear()
        triton_ops.lora_shrink(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(out_tensor, ref_out_tensor)


def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             add_inputs: bool) -> None:
    """
    Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
    kernels.

    Test data is produced by generate_data_for_nslices for the "expand" op
    from ``batches``, ``hidden_size``, ``num_loras``, ``rank``,
    ``seq_length``, ``nslices``, ``dtype`` and ``device``. The Triton kernel
    writes into a clone of ``data.our_out_tensor``; the reference
    implementation writes into ``data.ref_out_tensor``. ``add_inputs`` is
    passed to both. The two outputs are then compared with assert_close,
    which is expected to fail the test on a mismatch.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "expand",
        device,
    )

    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the device
    # requested by the caller (rather than a hard-coded 'cuda') so it lives
    # on the same device as the generated test data.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors
    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # Clear the cached pointer table before launching so the kernel cannot
    # reuse an entry cached by an earlier case; the lock keeps the
    # clear-then-launch pair together.
    with _dict_lock:
        # lora_expand kernel
        _LORA_B_PTR_DICT.clear()
        triton_ops.lora_expand(data.inputs_tensor,
                               data.lora_weights,
                               out_tensor,
                               *lora_meta.meta_args(token_nums=token_nums),
                               offset_start=0,
                               add_inputs=add_inputs)

    # Reference
    sgmv_expand_for_nslices(nslices,
                            hidden_size,
                            data.inputs_tensor,
                            data.lora_weights,
                            ref_out_tensor,
                            *sgmv_meta_args,
                            add_inputs=add_inputs)

    assert_close(out_tensor, ref_out_tensor)


# Tests
# The punica kernels are tested along two main axes:
# 1. Variations in the hidden_dim size.
# 2. Variations in all other parameters (batch_size, max_rank, num_loras,
#    etc.).

# Hidden sizes collected from the LoRA models currently supported by vLLM.
# Each of them is additionally floor-divided by every tensor-parallel (TP)
# size listed in `divisibility` below, to check that the corresponding Triton
# kernel also runs normally on the sharded shapes.
HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]
# TP sizes used to derive the sharded hidden sizes.
divisibility = [1, 2, 8, 16, 64]

# Every base hidden size floor-divided by every TP size (TP size is the outer
# loop, so all sizes for TP=1 come first, then TP=2, and so on).
all_hidden_size = [
    base_size // tp_size for tp_size in divisibility
    for base_size in HIDDEN_SIZES
]

# Drop duplicates; the de-duplicated list replaces the base list.
HIDDEN_SIZES = list(set(all_hidden_size))

# Parameters for the hidden-size sweep: every entry of HIDDEN_SIZES is tried
# while the remaining dimensions stay fixed at a single value each.
hs_test_params = dict(
    hidden_sizes=HIDDEN_SIZES,
    batches=[4],
    num_loras=[4],
    max_ranks=[32],
)

# Parameters for the general sweep: every dimension except hidden_size is
# varied, with hidden_size pinned to a single value.
test_params = dict(
    hidden_sizes=[2049],
    batches=[1, 4, 16, 32],
    num_loras=[1, 8, 32, 128],
    max_ranks=[1, 4, 8, 16, 32, 64, 128, 256],
)

# Data types, devices and RNG seeds the tests are parametrized over.
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = ["cuda:0"]
SEED = [0]


@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
) -> None:
    """
    Tests LoRA kernels.

    Parametrized over ``test_params`` (batches, num_loras, max_ranks and
    hidden_sizes), the slice counts 1-3, DTYPES, DEVICES, SEED and both op
    types. ``op_type == "shrink"`` runs check_lora_shrink_kernel with
    scaling=0.5; any other value runs check_lora_expand_kernel with
    add_inputs=True. The sequence length is fixed at 128.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Arguments common to the shrink and the expand check.
    common_kwargs = dict(batches=batches,
                         num_loras=num_loras,
                         rank=rank,
                         hidden_size=hidden_size,
                         nslices=nslices,
                         dtype=dtype,
                         device=device,
                         seq_length=128)

    if op_type == "shrink":
        check_lora_shrink_kernel(**common_kwargs, scaling=0.5)
    else:
        check_lora_expand_kernel(**common_kwargs, add_inputs=True)


@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
) -> None:
    """
    Tests SGMV and LoRA kernels.

    Parametrized over ``hs_test_params`` (batches, num_loras, max_ranks and
    hidden_sizes), the slice counts 1-3, DTYPES, DEVICES, SEED and both op
    types. ``op_type == "shrink"`` runs check_lora_shrink_kernel with
    scaling=0.5; any other value runs check_lora_expand_kernel with
    add_inputs=True. The sequence length is fixed at 128.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Arguments common to the shrink and the expand check.
    common_kwargs = dict(batches=batches,
                         num_loras=num_loras,
                         rank=rank,
                         hidden_size=hidden_size,
                         nslices=nslices,
                         dtype=dtype,
                         device=device,
                         seq_length=128)

    if op_type == "shrink":
        check_lora_shrink_kernel(**common_kwargs, scaling=0.5)
    else:
        check_lora_expand_kernel(**common_kwargs, add_inputs=True)