# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from threading import Lock

import pytest
import torch

import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

from .utils import PunicaTensors, assert_close, generate_data_for_nslices

DEVICE_TYPE = current_platform.device_type

@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    """Pull ``reset_default_device`` into every test in this module.

    Because this fixture is ``autouse``, each test in the module implicitly
    requests ``reset_default_device``. This fixture adds no logic of its own.
    """


# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
    nslices: int,
    inputs_tensor: torch.Tensor,
    lora_weights_lst: list[torch.Tensor],
    out_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    prompt_lora_mapping: torch.Tensor,
    batches: int,
    max_seq_length: int,
    num_tokens: int,
    scaling: float,
):
    """
    Reference shrink that handles any number of slices.

    Calls torch_ops.sgmv_shrink once per slice. Slice ``i`` pairs
    ``lora_weights_lst[i]`` with ``out_tensor[i]``. Every call shares the
    same ``inputs_tensor``, sequence metadata and ``scaling`` factor.
    """
    # Arguments that are identical for every per-slice call.
    seq_meta = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
    )
    for slice_idx in range(nslices):
        slice_weights = lora_weights_lst[slice_idx]
        slice_out = out_tensor[slice_idx]
        torch_ops.sgmv_shrink(
            inputs_tensor,
            slice_weights,
            slice_out,
            *seq_meta,
            scaling,
        )


def sgmv_expand_for_nslices(
    nslices: int,
    hidden_size: int,
    inputs_tensor: torch.Tensor,
    lora_weights_lst: list[torch.Tensor],
    out_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    prompt_lora_mapping: torch.Tensor,
    batches: int,
    max_seq_length: int,
    num_tokens: int,
    add_inputs: bool,
) -> None:
    """
    Reference expand that handles any number of slices.

    With exactly one slice, torch_ops.sgmv_expand is called with
    ``inputs_tensor[0]`` and ``lora_weights_lst[0]`` writing into
    ``out_tensor``. Otherwise torch_ops.sgmv_expand_slice is called once per
    slice ``i``, always writing into ``out_tensor``. It receives the offset
    ``i * hidden_size`` and ``hidden_size`` as its slice arguments.
    """
    # Arguments that are identical for every call below.
    seq_meta = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
    )

    if nslices == 1:
        # Single slice: verify the plain (non-slice) torch sgmv_expand op.
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            *seq_meta,
            add_inputs=add_inputs,
        )
        return

    for slice_idx in range(nslices):
        torch_ops.sgmv_expand_slice(
            inputs_tensor[slice_idx],
            lora_weights_lst[slice_idx],
            out_tensor,
            *seq_meta,
            slice_idx * hidden_size,
            hidden_size,
            add_inputs=add_inputs,
        )


_dict_lock = Lock()


def check_lora_shrink_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    scaling: float,
) -> None:
    """
    Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
    kernels.

    Random "shrink" test data is generated with generate_data_for_nslices.
    The Triton kernel writes into a clone of ``data.our_out_tensor``. The
    torch reference writes into ``data.ref_out_tensor``. The two are then
    compared with assert_close.

    Args:
        batches, num_loras, rank, hidden_size, nslices, dtype, seq_length:
            Forwarded to generate_data_for_nslices to build the test data.
        device: Device the test data and the LoRA kernel metadata are
            created on.
        scaling: Scale factor passed to both the Triton kernel and the
            reference implementation.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "shrink",
        device,
    )
    max_seq_length, token_nums = data.meta()

    # Sequence metadata consumed by the torch (SGMV) reference path.
    sgmv_meta_args = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
    )

    # Metadata for the Triton LoRA kernel. Build it on the same device the
    # test data was generated on, rather than a hard-wired device type.
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras,
        max_num_tokens=token_nums,
        device=device,
    )
    lora_meta.prepare_tensors(data.token_lora_mapping)

    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # _LORA_A_PTR_DICT is a module-level cache in the Triton ops, presumably
    # holding weight pointers. Clear it so stale entries from an earlier test
    # are not reused. Hold the lock so the clear and the launch stay together.
    with _dict_lock:
        _LORA_A_PTR_DICT.clear()
        triton_ops.lora_shrink(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(out_tensor, ref_out_tensor)


def check_lora_expand_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    add_inputs: bool,
) -> None:
    """
    Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
    kernels.

    Random "expand" test data is generated with generate_data_for_nslices.
    The Triton kernel writes into a clone of ``data.our_out_tensor``. The
    torch reference writes into ``data.ref_out_tensor``. The two are then
    compared with assert_close.

    Args:
        batches, num_loras, rank, hidden_size, nslices, dtype, seq_length:
            Forwarded to generate_data_for_nslices to build the test data.
        device: Device the test data and the LoRA kernel metadata are
            created on.
        add_inputs: Passed to both the Triton kernel and the reference
            implementation.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "expand",
        device,
    )

    max_seq_length, token_nums = data.meta()

    # Sequence metadata consumed by the torch (SGMV) reference path.
    sgmv_meta_args = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
    )

    # Metadata for the Triton LoRA kernel. Build it on the same device the
    # test data was generated on, rather than a hard-wired device type.
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras,
        max_num_tokens=token_nums,
        device=device,
    )
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors
    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # _LORA_B_PTR_DICT is a module-level cache in the Triton ops, presumably
    # holding weight pointers. Clear it so stale entries from an earlier test
    # are not reused. Hold the lock so the clear and the launch stay together.
    with _dict_lock:
        _LORA_B_PTR_DICT.clear()
        triton_ops.lora_expand(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
            offset_start=0,
            add_inputs=add_inputs,
        )

    # Reference
    sgmv_expand_for_nslices(
        nslices,
        hidden_size,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        add_inputs=add_inputs,
    )

    assert_close(out_tensor, ref_out_tensor)
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348


# Tests
# We test the punica kernels along 2 verticals mainly.
# 1. Variations in hidden_dim size
# 2. Variations in all other parameters like (batch_size, max_rank, num_loras
#  etc.)

# We have collected the hidden_sizes included in the LoRA models
# currently supported by vLLM. Each one is also divided by every tensor
# parallel size listed in `divisibility` below. This tests whether the
# corresponding Triton kernel runs normally on the resulting per-rank hidden
# sizes.
HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]
# Tensor-parallel sizes used to derive per-rank hidden sizes.
divisibility = [1, 2, 8, 16, 64]

# Every (hidden_size // tp_size) combination. The comprehension keeps its
# loop variables out of the module namespace.
all_hidden_size = [
    hidden_size // div for div in divisibility for hidden_size in HIDDEN_SIZES
]

# De-duplicate, and sort so the parametrized test order and IDs are explicit
# and stable (pytest-xdist needs the same collection order on every worker).
HIDDEN_SIZES = sorted(set(all_hidden_size))

# Test params that focuses on hidden_size variation.
hs_test_params = {
    "hidden_sizes": HIDDEN_SIZES,
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}

# General tests params that tests for variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [1, 4, 16, 32],
    "num_loras": [1, 8, 32, 128],
    "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}

DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"{DEVICE_TYPE}:0"]
SEED = [0]


@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Tests LoRA kernels over the general parameter grid (test_params).

    ``op_type`` picks check_lora_shrink_kernel or check_lora_expand_kernel.
    Each compares a Triton LoRA op against its torch reference.
    """
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)
    set_random_seed(seed)

    # Arguments shared by both the shrink and the expand checks.
    shared_kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
    )

    if op_type == "shrink":
        check_lora_shrink_kernel(**shared_kwargs, scaling=0.5)
    else:
        check_lora_expand_kernel(**shared_kwargs, add_inputs=True)


@pytest.mark.parametrize("batches", hs_test_params["batches"])
@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"])
@pytest.mark.parametrize("rank", hs_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Tests LoRA kernels across the hidden-size sweep (hs_test_params).

    ``op_type`` picks check_lora_shrink_kernel or check_lora_expand_kernel.
    Each compares a Triton LoRA op against its torch (SGMV) reference.
    """
    torch.set_default_device(device)
    torch.accelerator.set_device_index(device)
    set_random_seed(seed)

    # Arguments shared by both the shrink and the expand checks.
    shared_kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
    )

    if op_type == "shrink":
        check_lora_shrink_kernel(**shared_kwargs, scaling=0.5)
    else:
        check_lora_expand_kernel(**shared_kwargs, add_inputs=True)