# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from threading import Lock

import pytest
import torch

8
9
10
import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
11
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
12
from vllm.utils.torch_utils import set_random_seed
13

14
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
15
16


17
18
19
20
21
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    """Autouse hook that pulls the ``reset_default_device`` fixture in.

    The function body is intentionally empty: declaring
    ``reset_default_device`` as a parameter of an ``autouse`` fixture is
    what makes pytest apply that fixture to every test in this module.
    """


22
23
# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
    nslices: int,
    inputs_tensor: torch.Tensor,
    lora_weights_lst: list[torch.Tensor],
    out_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    prompt_lora_mapping: torch.Tensor,
    batches: int,
    max_seq_length: int,
    num_tokens: int,
    scaling: float,
) -> None:
    """
    Wrapper around torch_ops.sgmv_shrink that handles any nslices.

    The reference shrink op is called once per slice. Slice ``index`` uses
    ``lora_weights_lst[index]`` as its weights and writes into
    ``out_tensor[index]``. Every slice shares the same input tensor and the
    same SGMV metadata.

    Args:
        nslices: Number of slices to process. The first ``nslices`` entries
            of ``lora_weights_lst`` and of ``out_tensor`` (indexed along its
            first dimension) are used.
        inputs_tensor: Input activations shared by all slices.
        lora_weights_lst: Per-slice LoRA weights for the shrink op.
        out_tensor: Output buffer indexed by slice; filled in place by
            torch_ops.sgmv_shrink.
        b_seq_start_loc: SGMV metadata forwarded unchanged.
        seq_len_tensor: SGMV metadata forwarded unchanged.
        prompt_lora_mapping: SGMV metadata forwarded unchanged.
        batches: SGMV metadata forwarded unchanged.
        max_seq_length: SGMV metadata forwarded unchanged.
        num_tokens: SGMV metadata forwarded unchanged.
        scaling: Scale factor forwarded to torch_ops.sgmv_shrink.
    """
    for index in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[index],
            out_tensor[index],
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            scaling,
        )


54
55
56
57
58
59
60
61
62
63
64
65
66
67
def sgmv_expand_for_nslices(
    nslices: int,
    hidden_size: int,
    inputs_tensor: torch.Tensor,
    lora_weights_lst: list[torch.Tensor],
    out_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    prompt_lora_mapping: torch.Tensor,
    batches: int,
    max_seq_length: int,
    num_tokens: int,
    add_inputs: bool,
) -> None:
    """
    Wrapper around torch_ops.sgmv_expand that handles any nslices.

    With a single slice, torch_ops.sgmv_expand is called directly on
    ``inputs_tensor[0]`` and ``lora_weights_lst[0]``. With several slices,
    torch_ops.sgmv_expand_slice is called once per slice. Each call is given
    a slice offset that starts at 0 and grows by ``hidden_size`` per slice,
    so consecutive slices target consecutive ``hidden_size``-wide windows
    of ``out_tensor``.

    Args:
        nslices: Number of slices to process.
        hidden_size: Width of each slice's output window; also the offset
            increment between slices.
        inputs_tensor: Per-slice inputs, indexed along the first dimension.
        lora_weights_lst: Per-slice LoRA weights for the expand op.
        out_tensor: Shared output buffer, filled in place.
        b_seq_start_loc: SGMV metadata forwarded unchanged.
        seq_len_tensor: SGMV metadata forwarded unchanged.
        prompt_lora_mapping: SGMV metadata forwarded unchanged.
        batches: SGMV metadata forwarded unchanged.
        max_seq_length: SGMV metadata forwarded unchanged.
        num_tokens: SGMV metadata forwarded unchanged.
        add_inputs: Forwarded as ``add_inputs`` to the torch reference op.
    """
    if nslices == 1:
        # Verify the torch's sgmv_expand op
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            add_inputs=add_inputs,
        )
    else:
        slice_offset = 0
        for index in range(nslices):
            lora_weights = lora_weights_lst[index]
            torch_ops.sgmv_expand_slice(
                inputs_tensor[index],
                lora_weights,
                out_tensor,
                b_seq_start_loc,
                seq_len_tensor,
                prompt_lora_mapping,
                batches,
                max_seq_length,
                num_tokens,
                slice_offset,
                hidden_size,
                add_inputs=add_inputs,
            )
            # Advance to the next slice's output window.
            slice_offset += hidden_size


# Module-wide lock. check_lora_shrink_kernel and check_lora_expand_kernel hold
# it while they clear the module-level _LORA_A_PTR_DICT / _LORA_B_PTR_DICT and
# then launch the Triton kernel. This keeps the clear-then-launch sequence
# from interleaving between threads that share those dicts.
_dict_lock = Lock()


109
110
111
112
113
114
115
116
117
118
119
def check_lora_shrink_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    scaling: float,
) -> None:
    """
    Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
    kernels.

    Generates random "shrink" test data on ``device`` and runs
    triton_ops.lora_shrink into a clone of ``data.our_out_tensor``. It then
    runs the per-slice torch reference into ``data.ref_out_tensor`` and
    checks both outputs with ``assert_close``.

    Args:
        batches: Number of sequences in the generated batch.
        num_loras: Number of distinct LoRA adapters.
        rank: LoRA rank.
        hidden_size: Hidden dimension of the inputs.
        nslices: Number of LoRA slices.
        dtype: Data type of the generated tensors.
        device: Device for the test data and the LoRA kernel metadata.
        seq_length: Sequence length passed to the data generator.
        scaling: Scale factor applied by both the Triton and reference ops.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "shrink",
        device,
    )
    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
    )

    # Setup metadata information for the LoRA kernel, on the same device as
    # the generated test data (previously hard-coded to "cuda").
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=token_nums, device=device
    )
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # The Triton kernel writes into a clone, so data.our_out_tensor itself is
    # left untouched.
    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    with _dict_lock:
        # Clear the module-level _LORA_A_PTR_DICT before launching, so no
        # entry cached by an earlier test is reused for this test's weights.
        # (Presumably it caches per-weight pointer data; see
        # triton_ops.utils.) The lock keeps the clear and the launch together.
        _LORA_A_PTR_DICT.clear()
        triton_ops.lora_shrink(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(out_tensor, ref_out_tensor)
179
180


181
182
183
184
185
186
187
188
189
190
191
def check_lora_expand_kernel(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seq_length: int,
    add_inputs: bool,
) -> None:
    """
    Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
    kernels.

    Generates random "expand" test data on ``device`` and runs
    triton_ops.lora_expand into a clone of ``data.our_out_tensor``. It then
    runs the per-slice torch reference into ``data.ref_out_tensor`` and
    checks both outputs with ``assert_close``.

    Args:
        batches: Number of sequences in the generated batch.
        num_loras: Number of distinct LoRA adapters.
        rank: LoRA rank.
        hidden_size: Per-slice output width.
        nslices: Number of LoRA slices.
        dtype: Data type of the generated tensors.
        device: Device for the test data and the LoRA kernel metadata.
        seq_length: Sequence length passed to the data generator.
        add_inputs: Forwarded to both the Triton and reference expand ops.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "expand",
        device,
    )

    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
    )

    # Setup metadata information for the LoRA kernel, on the same device as
    # the generated test data (previously hard-coded to "cuda").
    lora_meta = LoRAKernelMeta.make(
        max_loras=num_loras, max_num_tokens=token_nums, device=device
    )
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors. The Triton kernel writes into a clone, so
    # data.our_out_tensor itself is left untouched.
    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    with _dict_lock:
        # lora_expand kernel. Clear the module-level _LORA_B_PTR_DICT first,
        # so no entry cached by an earlier test is reused for this test's
        # weights. (Presumably it caches per-weight pointer data; see
        # triton_ops.utils.)
        _LORA_B_PTR_DICT.clear()
        triton_ops.lora_expand(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums),
            offset_start=0,
            add_inputs=add_inputs,
        )

    # Reference
    sgmv_expand_for_nslices(
        nslices,
        hidden_size,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        add_inputs=add_inputs,
    )

    assert_close(out_tensor, ref_out_tensor)
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341


# Tests
# We test the punica kernels along 2 verticals mainly.
# 1. Variations in hidden_dim size
# 2. Variations in all other parameters like (batch_size, max_rank, num_loras
#  etc.)

# We have collected the hidden_sizes included in the LoRA models
# currently supported by vLLM. Each base size is also divided by every
# tensor-parallel size listed in ``divisibility`` below. This checks that the
# corresponding Triton kernel runs normally at those per-rank hidden sizes.
HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]
# The sizes of TP used to derive per-rank hidden sizes.
divisibility = [1, 2, 8, 16, 64]

# Per-rank hidden sizes for every (TP size, base size) pair. They are
# deduplicated and sorted so the parametrization order is stable and readable.
all_hidden_size = [
    hidden_size // div for div in divisibility for hidden_size in HIDDEN_SIZES
]

HIDDEN_SIZES = sorted(set(all_hidden_size))

# Test params that focuses on hidden_size variation.
hs_test_params = {
    "hidden_sizes": HIDDEN_SIZES,
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}

# General tests params that tests for variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [1, 4, 16, 32],
    "num_loras": [1, 8, 32, 128],
    "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}

DTYPES = [torch.float16, torch.bfloat16]
DEVICES = ["cuda:0"]
SEED = [0]


374
375
376
377
@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
) -> None:
    """
    Tests LoRA kernels.

    Parametrized over ``test_params``, which varies batch size, LoRA count
    and rank at a fixed hidden size. Dispatches to the shrink or expand
    checker according to ``op_type``, with a fixed ``seq_length`` of 128.
    """
    torch.set_default_device(device)
    set_random_seed(seed)

    if op_type == "shrink":
        check_lora_shrink_kernel(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            nslices=nslices,
            dtype=dtype,
            device=device,
            seq_length=128,
            scaling=0.5,
        )
    else:
        check_lora_expand_kernel(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            nslices=nslices,
            dtype=dtype,
            device=device,
            seq_length=128,
            add_inputs=True,
        )


@pytest.mark.parametrize("batches", hs_test_params["batches"])
@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"])
@pytest.mark.parametrize("rank", hs_test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
) -> None:
    """
    Tests SGMV and LoRA kernels.

    Parametrized over ``hs_test_params``, which sweeps the full list of
    per-rank hidden sizes at a fixed batch size, LoRA count and rank.
    Dispatches to the shrink or expand checker according to ``op_type``,
    with a fixed ``seq_length`` of 128.
    """
    torch.set_default_device(device)
    set_random_seed(seed)

    if op_type == "shrink":
        check_lora_shrink_kernel(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            nslices=nslices,
            dtype=dtype,
            device=device,
            seq_length=128,
            scaling=0.5,
        )
    else:
        check_lora_expand_kernel(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            nslices=nslices,
            dtype=dtype,
            device=device,
            seq_length=128,
            add_inputs=True,
        )