test_punica_ops.py 12 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
from threading import Lock

import pytest
import torch

7
8
9
import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
10
11
12
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform

13
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    """
    Autouse fixture that pulls in ``reset_default_device`` for every test.

    Because this fixture is ``autouse`` and requests ``reset_default_device``
    as an argument, pytest runs ``reset_default_device`` for every test in
    this module. That fixture is defined outside this module (presumably in
    a conftest.py — verify). This wrapper itself does nothing.
    """


# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
        nslices: int, inputs_tensor: torch.Tensor,
        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
        num_tokens: int, scaling: float) -> None:
    """
    Wrapper around torch_ops.sgmv_shrink that handles any nslices.

    torch_ops.sgmv_shrink is called once per slice ``i`` in
    ``range(nslices)``:
    - the LoRA weights come from ``lora_weights_lst[i]``;
    - the result is written into ``out_tensor[i]``;
    - the same ``inputs_tensor``, sequence metadata and ``scaling`` are
      shared by every slice.
    Results are written in place into ``out_tensor``; nothing is returned.
    """
    for index in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[index],
            out_tensor[index],
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            scaling,
        )


def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
                            inputs_tensor: torch.Tensor,
                            lora_weights_lst: list[torch.Tensor],
                            out_tensor: torch.Tensor,
                            b_seq_start_loc: torch.Tensor,
                            seq_len_tensor: torch.Tensor,
                            prompt_lora_mapping: torch.Tensor, batches: int,
                            max_seq_length: int, num_tokens: int,
                            add_inputs: bool) -> None:
    """
    Wrapper around torch_ops.sgmv_expand that handles any nslices.

    - With ``nslices == 1``, torch_ops.sgmv_expand is called once using
      ``inputs_tensor[0]`` and ``lora_weights_lst[0]``.
    - Otherwise, torch_ops.sgmv_expand_slice is called once per slice ``i``
      with ``inputs_tensor[i]`` and ``lora_weights_lst[i]``. The slice offset
      is ``i * hidden_size`` and the slice size is ``hidden_size``.

    ``add_inputs`` is forwarded unchanged to the torch op. Results are
    written in place into ``out_tensor``; nothing is returned.
    """
    if nslices == 1:
        # Verify the torch's sgmv_expand op
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            b_seq_start_loc,
            seq_len_tensor,
            prompt_lora_mapping,
            batches,
            max_seq_length,
            num_tokens,
            add_inputs=add_inputs,
        )
    else:
        slice_offset = 0
        for index in range(nslices):
            lora_weights = lora_weights_lst[index]
            torch_ops.sgmv_expand_slice(
                inputs_tensor[index],
                lora_weights,
                out_tensor,
                b_seq_start_loc,
                seq_len_tensor,
                prompt_lora_mapping,
                batches,
                max_seq_length,
                num_tokens,
                slice_offset,
                hidden_size,
                add_inputs=add_inputs,
            )
            # Each slice occupies ``hidden_size`` positions of the output.
            slice_offset += hidden_size


# Module-wide lock held by the check_lora_*_kernel helpers while they clear
# the shared _LORA_A_PTR_DICT / _LORA_B_PTR_DICT caches and launch a kernel.
_dict_lock: Lock = Lock()


def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             scaling: float):
    """
    Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
    kernels.

    Random "shrink" test data is generated by generate_data_for_nslices on
    ``device``. Both the Triton kernel and the per-slice torch reference
    (sgmv_shrink_for_nslices) are run with the same ``scaling``. Their
    outputs must match under assert_close.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "shrink",
        device,
    )
    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the same
    # device as the test data rather than a hard-coded 'cuda'.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # _LORA_A_PTR_DICT is presumably a cache of LoRA-A weight pointers kept
    # between kernel calls (verify in triton_ops.utils). Clear it under the
    # lock so this run cannot pick up stale pointers from another test's
    # tensors.
    with _dict_lock:
        # lora_shrink kernel
        _LORA_A_PTR_DICT.clear()
        triton_ops.lora_shrink(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(out_tensor, ref_out_tensor)


def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             add_inputs: bool):
    """
    Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
    kernels.

    Random "expand" test data is generated by generate_data_for_nslices on
    ``device``. Both the Triton kernel and the per-slice torch reference
    (sgmv_expand_for_nslices) are run with the same ``add_inputs`` flag.
    Their outputs must match under assert_close.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "expand",
        device,
    )

    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the same
    # device as the test data rather than a hard-coded 'cuda'.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors
    ref_out_tensor = data.ref_out_tensor
    out_tensor = data.our_out_tensor.clone()

    # _LORA_B_PTR_DICT is presumably a cache of LoRA-B weight pointers kept
    # between kernel calls (verify in triton_ops.utils). Clear it under the
    # lock so this run cannot pick up stale pointers from another test's
    # tensors.
    with _dict_lock:
        # lora_expand kernel
        _LORA_B_PTR_DICT.clear()
        triton_ops.lora_expand(data.inputs_tensor,
                               data.lora_weights,
                               out_tensor,
                               *lora_meta.meta_args(token_nums=token_nums),
                               offset_start=0,
                               add_inputs=add_inputs)

    # Reference
    sgmv_expand_for_nslices(nslices,
                            hidden_size,
                            data.inputs_tensor,
                            data.lora_weights,
                            ref_out_tensor,
                            *sgmv_meta_args,
                            add_inputs=add_inputs)

    assert_close(out_tensor, ref_out_tensor)


# Tests
# We test the punica kernels along two main axes:
# 1. Variations in hidden_size.
# 2. Variations in all other parameters (batch_size, max_rank, num_loras,
#    etc.).

# We have collected the hidden sizes used by the LoRA models currently
# supported by vLLM. These tests check that the corresponding Triton
# kernels run correctly when tensor parallelism is set to any of
# [1, 2, 8, 16, 64] (the `divisibility` values below).
HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]
# Tensor-parallel sizes: each base hidden size above is also tested after
# (floor) division by every entry of this list.
divisibility = [1, 2, 8, 16, 64]

# De-duplicate and sort so the parametrization order (and therefore the
# pytest IDs) is stable and easy to scan. The comprehension keeps the loop
# variables out of the module namespace.
HIDDEN_SIZES = sorted({
    hidden_size // div
    for div in divisibility
    for hidden_size in HIDDEN_SIZES
})

# Test params that focuses on hidden_size variation.
hs_test_params = {
    "hidden_sizes": HIDDEN_SIZES,
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}

# General tests params that tests for variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [1, 4, 16, 32],
    "num_loras": [1, 8, 32, 128],
    "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}

DTYPES = [torch.float16, torch.bfloat16]
DEVICES = ["cuda:0"]
SEED = [0]


@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Tests LoRA kernels over the general parameter sweep in ``test_params``.

    ``op_type == "shrink"`` checks triton_ops.lora_shrink; any other value
    checks triton_ops.lora_expand. Each is compared against the torch_ops
    reference, with ``seq_length=128``.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    if op_type == "shrink":
        check_lora_shrink_kernel(batches=batches,
                                 num_loras=num_loras,
                                 rank=rank,
                                 hidden_size=hidden_size,
                                 nslices=nslices,
                                 dtype=dtype,
                                 device=device,
                                 seq_length=128,
                                 scaling=0.5)
    else:
        check_lora_expand_kernel(batches=batches,
                                 num_loras=num_loras,
                                 rank=rank,
                                 hidden_size=hidden_size,
                                 nslices=nslices,
                                 dtype=dtype,
                                 device=device,
                                 seq_length=128,
                                 add_inputs=True)


@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Tests LoRA kernels across the hidden sizes in ``hs_test_params``.

    ``op_type == "shrink"`` checks triton_ops.lora_shrink; any other value
    checks triton_ops.lora_expand. Each is compared against the torch_ops
    (SGMV) reference, with ``seq_length=128``.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    if op_type == "shrink":
        check_lora_shrink_kernel(batches=batches,
                                 num_loras=num_loras,
                                 rank=rank,
                                 hidden_size=hidden_size,
                                 nslices=nslices,
                                 dtype=dtype,
                                 device=device,
                                 seq_length=128,
                                 scaling=0.5)
    else:
        check_lora_expand_kernel(batches=batches,
                                 num_loras=num_loras,
                                 rank=rank,
                                 hidden_size=hidden_size,
                                 nslices=nslices,
                                 dtype=dtype,
                                 device=device,
                                 seq_length=128,
                                 add_inputs=True)