test_punica_ops.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
from threading import Lock

import pytest
import torch

8
9
10
import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
11
12
13
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform

14
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
15
16


17
18
19
20
21
@pytest.fixture(autouse=True)
def reset_device(reset_default_device):
    """
    Autouse wrapper that requests the ``reset_default_device`` fixture for
    every test in this module. It does no work of its own; the fixture it
    depends on is defined outside this file.
    """


22
23
24
# Utility shrink and expand operations used as reference implementations.
def sgmv_shrink_for_nslices(
        nslices: int, inputs_tensor: torch.Tensor,
        lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor,
        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
        num_tokens: int, scaling: float):
    """
    Invoke torch_ops.sgmv_shrink once for each of the ``nslices`` slices.

    Slice ``i`` is called with ``lora_weights_lst[i]`` and ``out_tensor[i]``;
    ``inputs_tensor`` and all remaining arguments are forwarded unchanged to
    every call.
    """
    # Arguments that are the same for every slice.
    shared_args = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
        scaling,
    )
    for slice_idx in range(nslices):
        torch_ops.sgmv_shrink(
            inputs_tensor,
            lora_weights_lst[slice_idx],
            out_tensor[slice_idx],
            *shared_args,
        )


def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
                            inputs_tensor: torch.Tensor,
                            lora_weights_lst: list[torch.Tensor],
                            out_tensor: torch.Tensor,
                            b_seq_start_loc: torch.Tensor,
                            seq_len_tensor: torch.Tensor,
                            prompt_lora_mapping: torch.Tensor, batches: int,
                            max_seq_length: int, num_tokens: int,
                            add_inputs: bool) -> None:
    """
    Dispatch to the torch reference expand op for any number of slices.

    With exactly one slice, torch_ops.sgmv_expand is called on
    ``inputs_tensor[0]`` / ``lora_weights_lst[0]``. Otherwise
    torch_ops.sgmv_expand_slice is called once per slice, each with a slice
    offset of ``index * hidden_size`` and a slice size of ``hidden_size``.
    All calls receive the same ``out_tensor``.
    """
    # Metadata arguments common to every call below.
    shared_args = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
    )

    if nslices == 1:
        # Single slice: exercise the plain (non-sliced) torch sgmv_expand op.
        torch_ops.sgmv_expand(
            inputs_tensor[0],
            lora_weights_lst[0],
            out_tensor,
            *shared_args,
            add_inputs=add_inputs,
        )
        return

    for slice_idx in range(nslices):
        # Consecutive slices are placed hidden_size apart.
        torch_ops.sgmv_expand_slice(
            inputs_tensor[slice_idx],
            lora_weights_lst[slice_idx],
            out_tensor,
            *shared_args,
            slice_idx * hidden_size,
            hidden_size,
            add_inputs=add_inputs,
        )


# Held by the check_* helpers while they clear the shared
# _LORA_A_PTR_DICT / _LORA_B_PTR_DICT dictionaries and launch the Triton
# kernel, so that clear-and-launch sequence is never interleaved.
_dict_lock = Lock()


97
98
99
100
def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             scaling: float) -> None:
    """
    Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
    kernels.

    Test tensors come from generate_data_for_nslices(..., "shrink", device).
    The Triton kernel writes into a clone of ``data.our_out_tensor`` and the
    torch reference writes into ``data.ref_out_tensor``; the two results are
    then compared with assert_close.

    Args:
        batches: Number of sequences; forwarded to the data generator and to
            the reference op.
        num_loras: Number of LoRA adapters; also used as ``max_loras`` for
            the kernel metadata.
        rank: LoRA rank forwarded to the data generator.
        hidden_size: Hidden dimension forwarded to the data generator.
        nslices: Number of slices to generate and run.
        dtype: Data type forwarded to the data generator.
        device: Device the test data and the kernel metadata are created on.
        seq_length: Sequence length forwarded to the data generator.
        scaling: Scaling factor passed to both kernels.

    Raises:
        Whatever assert_close raises when the two outputs differ.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "shrink",
        device,
    )
    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the device
    # under test (not a hard-coded 'cuda') so the metadata lives on the same
    # device as the generated data.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    ref_out_tensor = data.ref_out_tensor
    # Work on a copy so the Triton kernel does not touch the generated tensor.
    out_tensor = data.our_out_tensor.clone()

    # Clear the shared LoRA-A pointer dict and launch the kernel under the
    # lock so no stale cached entry is picked up.
    with _dict_lock:
        # lora_shrink kernel
        _LORA_A_PTR_DICT.clear()
        triton_ops.lora_shrink(
            data.inputs_tensor,
            data.lora_weights,
            out_tensor,
            *lora_meta.meta_args(token_nums=token_nums),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(out_tensor, ref_out_tensor)
155
156


157
158
159
160
def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
                             hidden_size: int, nslices: int,
                             dtype: torch.dtype, device: str, seq_length: int,
                             add_inputs: bool) -> None:
    """
    Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
    kernels.

    Test tensors come from generate_data_for_nslices(..., "expand", device).
    The Triton kernel writes into a clone of ``data.our_out_tensor`` and the
    torch reference writes into ``data.ref_out_tensor``; the two results are
    then compared with assert_close.

    Args:
        batches: Number of sequences; forwarded to the data generator and to
            the reference op.
        num_loras: Number of LoRA adapters; also used as ``max_loras`` for
            the kernel metadata.
        rank: LoRA rank forwarded to the data generator.
        hidden_size: Hidden dimension; also the per-slice size used by the
            reference implementation.
        nslices: Number of slices to generate and run.
        dtype: Data type forwarded to the data generator.
        device: Device the test data and the kernel metadata are created on.
        seq_length: Sequence length forwarded to the data generator.
        add_inputs: Forwarded as ``add_inputs`` to both kernels.

    Raises:
        Whatever assert_close raises when the two outputs differ.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        "expand",
        device,
    )

    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the LoRA kernel. Build it on the device
    # under test (not a hard-coded 'cuda') so the metadata lives on the same
    # device as the generated data.
    lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
                                    max_num_tokens=token_nums,
                                    device=device)
    lora_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors
    ref_out_tensor = data.ref_out_tensor
    # Work on a copy so the Triton kernel does not touch the generated tensor.
    out_tensor = data.our_out_tensor.clone()

    # Clear the shared LoRA-B pointer dict and launch the kernel under the
    # lock so no stale cached entry is picked up.
    with _dict_lock:
        # lora_expand kernel
        _LORA_B_PTR_DICT.clear()
        triton_ops.lora_expand(data.inputs_tensor,
                               data.lora_weights,
                               out_tensor,
                               *lora_meta.meta_args(token_nums=token_nums),
                               offset_start=0,
                               add_inputs=add_inputs)

    # Reference
    sgmv_expand_for_nslices(nslices,
                            hidden_size,
                            data.inputs_tensor,
                            data.lora_weights,
                            ref_out_tensor,
                            *sgmv_meta_args,
                            add_inputs=add_inputs)

    assert_close(out_tensor, ref_out_tensor)
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342


# Tests
# We test the punica kernels along two main axes:
# 1. Variations in the hidden_dim size.
# 2. Variations in all other parameters (batch_size, max_rank, num_loras,
#    etc.).

# We have collected the hidden_sizes included in the LoRA models
# currently supported by vLLM. Each size is divided by the tensor-parallel
# sizes listed in `divisibility` below to check that the corresponding
# Triton kernel runs correctly on the resulting per-rank sizes.
# Un-sharded hidden sizes; replaced further down by the per-rank sizes
# obtained after dividing by every entry of `divisibility`.
HIDDEN_SIZES = [
    128,
    256,
    512,
    896,
    1024,
    1152,
    1216,
    1280,
    1536,
    1664,
    2048,
    2240,
    2304,
    2368,
    2432,
    2560,
    2752,
    3072,
    3328,
    3456,
    3584,
    3712,
    4096,
    4480,
    4608,
    4736,
    4864,
    5120,
    5504,
    5632,
    5888,
    6144,
    6400,
    6848,
    6912,
    7168,
    7424,
    8192,
    8960,
    9216,
    9472,
    10240,
    11008,
    11264,
    13824,
    14336,
    14784,
    14848,
    15360,
    18944,
    22016,
    22528,
    24576,
    27392,
    27648,
    29568,
    29696,
    32000,
    32256,
    32512,
    32768,
    33024,
    36864,
    43264,
    49152,
    49408,
    60544,
    60672,
    64000,
    64256,
    102400,
    102656,
    128000,
    128256,
]
# Tensor-parallel sizes each hidden size is divided by.
divisibility = [1, 2, 8, 16, 64]

# Every hidden size floor-divided by every TP size (duplicates included).
all_hidden_size = [
    hidden_size // div for div in divisibility
    for hidden_size in HIDDEN_SIZES
]

# De-duplicate and sort so the parametrization order is deterministic and
# does not depend on set iteration order.
HIDDEN_SIZES = sorted(set(all_hidden_size))

# Test params that focus on hidden_size variation.
hs_test_params = {
    "hidden_sizes": HIDDEN_SIZES,
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}

# General test params that cover variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [1, 4, 16, 32],
    "num_loras": [1, 8, 32, 128],
    "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}

DTYPES = [torch.float16, torch.bfloat16]
DEVICES = ["cuda:0"]
SEED = [0]


@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Check the LoRA shrink/expand kernels over the ``test_params`` grid.

    ``op_type == "shrink"`` runs check_lora_shrink_kernel with scaling 0.5;
    any other value runs check_lora_expand_kernel with add_inputs=True.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Keyword arguments shared by both checkers.
    common_kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
    )

    if op_type == "shrink":
        check_lora_shrink_kernel(**common_kwargs, scaling=0.5)
        return

    check_lora_expand_kernel(**common_kwargs, add_inputs=True)
380
381
382
383
384
385
386
387
388
389
390


@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """
    Check the LoRA shrink/expand kernels over the ``hs_test_params`` grid,
    which varies ``hidden_size``.

    ``op_type == "shrink"`` runs check_lora_shrink_kernel with scaling 0.5;
    any other value runs check_lora_expand_kernel with add_inputs=True.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Choose the checker and its one op-specific keyword argument.
    if op_type == "shrink":
        run_check = check_lora_shrink_kernel
        op_kwargs = {"scaling": 0.5}
    else:
        run_check = check_lora_expand_kernel
        op_kwargs = {"add_inputs": True}

    run_check(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
        **op_kwargs,
    )