"vscode:/vscode.git/clone" did not exist on "ab7165f2c7ea358df969d68a0fb0ce9bb184a083"
test_encoder_decoder_attn.py 42.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Tests:

* E2E test of Encoder attention + Decoder self-attention +
      Encoder/decoder cross-attention (collectively
      "encoder/decoder attention")

"""

from typing import NamedTuple, Optional

import pytest
import torch

from tests.kernels.utils import *
17
from vllm.attention import Attention, AttentionMetadata, AttentionType
18
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
Joe Runde's avatar
Joe Runde committed
19
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
20
                                     global_force_attn_backend_context_manager)
21
from vllm.config import VllmConfig, set_current_vllm_config
22
from vllm.forward_context import set_forward_context
23
from vllm.platforms import current_platform
24

25
26
27
28
29
30
31
32
33
34

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    Pin every test in this module to the V0 engine.

    Encoder-decoder is only supported on V0, so the VLLM_USE_V1
    environment variable is set to '0' (through monkeypatch) for
    each test function.
    """
    env_var, env_value = 'VLLM_USE_V1', '0'
    monkeypatch.setenv(env_var, env_value)


35
# List of supported backends for encoder/decoder models
36
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Values swept by the parametrized tests in this module
HEAD_SIZES = [64, 256]
NUM_HEADS = [1, 16]
BATCH_SIZES = [1, 16]
BLOCK_SIZES = [16]

# Device passed to the tensor-construction helpers
CUDA_DEVICE = "cuda:0"

MAX_DEC_SEQ_LENS = [128]
MAX_ENC_SEQ_LENS = [128]

# Narrow test-cases for unsupported-scenario tests:
# only the first head size is exercised
HEAD_SIZES_FOR_UNSUPP = HEAD_SIZES[:1]


class TestPoint(NamedTuple):
    """
    Encapsulates the attributes which define a single invocation
    of the test_e2e_enc_dec_attn() test

    Field order is part of the contract: a TestPoint is a tuple, so any
    code that unpacks it positionally depends on the order below. Append
    new fields at the end.

    Attributes:
        num_heads: The number of heads in the model.
        head_size: Head dimension
        backend_name: Name of the backend framework used.
        batch_size: Number of samples per batch.
        block_size: Size of each block of data processed.
        max_dec_seq_len: Maximum sequence length for the decoder.
        max_enc_seq_len: Maximum sequence length for the encoder.
        num_blocks: Number of blocks in the model.
        attn_type: Attention type (AttentionType value) exercised by
                   this test point.
    """

    num_heads: int
    head_size: int
    backend_name: str
    batch_size: int
    block_size: int
    max_dec_seq_len: int
    max_enc_seq_len: int
    num_blocks: int
    attn_type: AttentionType
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


class TestResources(NamedTuple):
    '''
    Encapsulates key components for performing an
    encoder/decoder attention test

    Note that attn automatically selects an attention backend
    based on platform info & a set of canned heuristics.

    NOTE(review): earlier text of this docstring described an
    attn_backend attribute (a separate instance of the same backend
    class, intended for building backend-compatible attention
    metadata). No such field is declared on this tuple.

    Attributes:

    * scale: 1/sqrt(d) scale factor for attn
    * attn: Attention layer instance
    * kv_cache: shared key/value cache for all attention
    '''

    scale: float
    attn: Attention
    kv_cache: torch.Tensor


def _make_test_resources(test_pt: TestPoint) -> TestResources:
    '''
    Build key components for performing encoder/decoder attention test.

    Note that
    (1) The Attention instance constructed here automatically selects
        an attention backend class based on platform info & a set of canned
        heuristics; therefore,
    (2) This function requires that test_pt.backend_name matches the backend
        class that Attention will automatically select when it is constructed.

    A KV cache is allocated (via make_kv_cache) only for decoder
    self-attention and encoder/decoder cross-attention test points; any
    other attention type gets an empty tensor. In both of those cases the
    cache is also attached to the layer as attn.kv_cache.

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: num_heads, head_size, num_blocks,
               block_size, backend_name, attn_type

    Returns:

    * TestResources data structure.
    '''

    scale = float(1.0 / (test_pt.head_size**0.5))
    attn = Attention(
        test_pt.num_heads,
        test_pt.head_size,
        scale=scale,
        prefix=f"{test_pt.attn_type}",
        attn_type=test_pt.attn_type,
    )
    # NOTE(review): num_heads has already been handed to Attention above,
    # so the num_heads half of this check only matters if Attention
    # accepts None - confirm.
    if test_pt.num_blocks is None or test_pt.num_heads is None:
        # Caller does not require a KV cache. Note that attn.kv_cache is
        # left untouched on this path.
        return TestResources(
            scale, attn,
            torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

    # Construct KV cache
    if test_pt.attn_type in (AttentionType.DECODER,
                             AttentionType.ENCODER_DECODER):
        kv_cache = make_kv_cache(test_pt.num_blocks,
                                 test_pt.num_heads,
                                 test_pt.head_size,
                                 test_pt.block_size,
                                 device=CUDA_DEVICE,
                                 backend=test_pt.backend_name)
    else:
        # Placeholder: no device/dtype is specified here, unlike the
        # early-return path above.
        kv_cache = torch.tensor([])

    attn.kv_cache = [kv_cache]
    return TestResources(scale, attn, kv_cache)
169
170
171
172
173
174
175
176
177


def _encoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
) -> PhaseTestParameters:
    '''
    Set up test vectors & data structures for encoder attention test.

    A triplet of synthetic query/key/value tensors are constructed.
    Given this is an encoder attention test, the key & value
    sequences will have the same length as the corresponding queries.

    The query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate an ideal output tensor.

    Encoder inference does not populate the KV cache, therefore
    no KV cache memory mapping is constructed

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               max_enc_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field

    Returns:

    * PhaseTestParameters data structure comprising (1) packed query/key/value
      tensors, (2) the ideal output of attention computed using a naive
      implementation, and (3) KVCache field set to None
    '''

    # Read fields by name instead of unpacking the tuple positionally,
    # so this function does not depend on TestPoint's field count/order.
    num_heads = test_pt.num_heads
    head_size = test_pt.head_size
    batch_size = test_pt.batch_size
    # For encoder attention the query length is the encoder length
    max_q_seq_len = test_pt.max_enc_seq_len

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Make test tensors

    qkv_in, _, _ = make_qkv(batch_size,
                            max_q_seq_len,
                            max_kv_seq_len,
                            num_heads,
                            head_size,
                            attn_type=AttentionType.ENCODER,
                            device=CUDA_DEVICE)

    # Compute correct answer using naive non-causal attention
    # implementation

    ideal_output = ref_masked_attention(qkv_in.query,
                                        qkv_in.key,
                                        qkv_in.value,
                                        scale=scale,
                                        q_seq_lens=qkv_in.q_seq_lens,
                                        kv_seq_lens=qkv_in.kv_seq_lens)

    packed_ideal_output, _ = pack_tensor(ideal_output,
                                         qkv_in.q_seq_lens,
                                         device=CUDA_DEVICE)

    packed_qkv = pack_qkv(qkv_in, device=CUDA_DEVICE)

    return PhaseTestParameters(
        PackedQKVO(packed_qkv, packed_ideal_output),
        None  # No KV cache
    )


def _decoder_attn_setup(
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> tuple[QKVInputs, PhaseTestParameters, PhaseTestParameters, int]:
    '''
    Set up test vectors & data structures for self-attention test.

    A triplet of synthetic query/key/value tensors are constructed ("baseline"
    query/key/value). Given this is a self-attention test, the key & value
    sequences will have the same length as the corresponding queries.

    "Prefill" query/key/value tensors are derived by masking out the last value
    in each baseline query/key/value. These tensors are used to test prefill &
    populate KV cache for a subsequent decode test.

    "Decode" query/key/value tensors are derived by extracting *only* the last
    value from each baseline query/key/value (i.e. complement of the prefill
    tensors.) These tensors are used to test decode, conditional on the kv cache
    being populated during the prefill test.

    The baseline query/key/value tensors are passed to an ideal reference
    self-attention implementation to generate a "Baseline" ideal output tensor.
    This tensor is split into the "Prefill" ideal output tensor (all but the
    last element of each output sequence) and the "Decode" ideal output tensor
    (*only* the last element of each output sequence); the "Prefill" and
    "Decode" ideal output tensors can be used to validate the prefill and decode
    test results, respectively.

    This function also constructs the self-attention KV cache memory mapping
    (slot mapping and block table), ensuring that the block table starts at
    block_base_addr

    Arguments:

    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_dec_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: decoder self-attention block-table base address

    Returns:
    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x
           head_size) query/key/value tensors
    * Prefill-phase decoder self-attention PhaseTestParameters data structure,
      including (1) packed (number_of_tokens x num_heads x head_size)
      query/key/value tensors along with (2) ideal attention output
      computed using a naive implementation, and (3) memory-mapping data
      structures appropriate for prefill phase.
    * Decode-phase decoder self-attention PhaseTestParameters data structure,
      including (1) packed (number_of_tokens x num_heads x head_size)
      query/key/value tensors along with (2) ideal attention output
      computed using a naive implementation, and (3) memory-mapping data
      structures appropriate for decode phase.
    * max_block_idx: max physical address in decoder self-attention block-table
                     (intended to be used as the base address for the encoder/
                      decoder cross-attention block-table, which is not
                      constructed in this function)
    '''

    # Read fields by name instead of unpacking the tuple positionally,
    # so this function does not depend on TestPoint's field count/order.
    num_heads = test_pt.num_heads
    head_size = test_pt.head_size
    batch_size = test_pt.batch_size
    block_size = test_pt.block_size
    # For decoder self-attention the query length is the decoder length
    max_q_seq_len = test_pt.max_dec_seq_len

    scale = test_rsrcs.scale

    max_kv_seq_len = max_q_seq_len

    # Build test tensors

    (
        qkv,
        prefill_qkv,
        decode_qkv,
    ) = make_qkv(batch_size,
                 max_q_seq_len,
                 max_kv_seq_len,
                 num_heads,
                 head_size,
                 attn_type=AttentionType.DECODER,
                 device=CUDA_DEVICE)

    # Compute correct answer using naive attention implementation
    # with causal attention mask

    causal_mask = make_causal_mask(max_q_seq_len,
                                   max_kv_seq_len).to(CUDA_DEVICE)

    ideal_output = ref_masked_attention(qkv.query,
                                        qkv.key,
                                        qkv.value,
                                        scale=scale,
                                        custom_mask=causal_mask,
                                        q_seq_lens=qkv.q_seq_lens,
                                        kv_seq_lens=qkv.kv_seq_lens)

    # Split out the prefill- & decode-phase ideal answers & pack them.
    # For each sequence, positions [0, prefill_len) are the prefill
    # answer and position prefill_len is the single decode answer.

    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_qkv.q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_qkv.q_seq_lens,
                                                 device=CUDA_DEVICE)
    # Exactly one decoded token per sequence
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1] * batch_size,
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for decoder self-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with entries for prompt tokens
    #
    # Decode-phase:
    # * Block-tables tensor with minimum number of blocks
    #   required by total num. tokens in the entirety of all sequences
    #   (including both prefill & decode)
    # * Slot-mapping with entries for tokens that will be decoded in the
    #   current decode iteration
    #
    #  Note: the format described above is simply mirroring what ModelRunner
    #        produces

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        slot_mapping_list,
        max_block_idx,
    ) = make_block_tables_slot_mapping(block_size,
                                       qkv.q_seq_lens,
                                       device=CUDA_DEVICE,
                                       block_base_addr=block_base_addr)

    (
        prefill_slot_mapping,
        decode_slot_mapping,
    ) = split_slot_mapping(slot_mapping_list,
                           qkv.q_seq_lens,
                           device=CUDA_DEVICE)

    prefill_pckd_qkv = pack_qkv(prefill_qkv, device=CUDA_DEVICE)

    decode_pckd_qkv = pack_qkv(decode_qkv, device=CUDA_DEVICE)

    return (
        qkv,
        PhaseTestParameters(  # Prefill test params
            PackedQKVO(prefill_pckd_qkv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode test params
            PackedQKVO(decode_pckd_qkv, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)),
        max_block_idx)


def _enc_dec_cross_attn_setup_reuses_query(
    decoder_qkv: QKVInputs,
    encoder_test_params: PhaseTestParameters,
    prefill_decoder_phase_test_params: PhaseTestParameters,
    test_pt: TestPoint,
    test_rsrcs: TestResources,
    block_base_addr: int = 0,
) -> tuple[PhaseTestParameters, PhaseTestParameters]:
    '''
    Set up test vectors & data structures for cross-attention test.

    Synthetic cross-attention key/value tensors are constructed
    ("baseline" key/value). Given this is a cross-attention test, we assume
    query tensors were already synthesized for a prior self-attention test and
    will be reused for cross-attention. The key & value sequences generated here
    may have a different length than the corresponding queries (as is often
    the case for cross-attention between decoder and encoder sequences.)

    Cross attention key & value tensors do not grow during autoregressive
    inference; thus this function obtains a single key/value pair suitable for
    both prefill and decode.

    The "baseline" query tensor is received as an argument. The "baseline"
    query/key/value tensors are passed to an ideal reference cross-attention
    implementation to generate a "baseline" ideal output tensor. This tensor is
    split into the "Prefill" ideal output tensor (all but the last element of
    each output sequence) and the "Decode" ideal output tensor (*only* the last
    element of each output sequence); the "Prefill" and "Decode" ideal output
    tensors can be used to validate the prefill and decode test results,
    respectively.

    This function also constructs the cross-attention KV cache memory mapping
    (slot mapping and block table), ensuring that the block table starts at
    block_base_addr.

    Arguments:

    * decoder_qkv: pre-existing unpacked (batch_size x padded_seq_len x
                   num_heads x head_size) decoder self-attention inputs;
                   this function relies on the query and q_seq_lens
                   fields
    * encoder_test_params: PhaseTestParameters data structure which was
                           used for encoder inference; KV cache field
                           is not used by this function
    * prefill_decoder_phase_test_params: PhaseTestParameters data structure
                                         used for prefill-phase decoder
                                         self-attention; this function reads
                                         the packed q_seq_lens from it
    * test_pt: TestPoint data structure; this function relies on the
               following fields: batch_size, num_heads, head_size,
               block_size, max_dec_seq_len, max_enc_seq_len
    * test_rsrcs: TestResources data structure; this function relies on the
                  scale field
    * block_base_addr: cross-attention block-table base address

    Returns:

    * Prefill-phase encoder/decoder cross-attention PhaseTestParameters data
      structure, including (1) packed
      (number_of_tokens x num_heads x head_size) query/key/value tensors
      along with (2) ideal attention output computed using a
      naive implementation, and (3) memory-mapping data structures appropriate
      for prefill phase.
    * Decode-phase encoder/decoder cross-attention PhaseTestParameters data
      structure, with (1) the packed query/key/value entry set to None,
      (2) ideal attention output computed using a naive implementation,
      and (3) memory-mapping data structures appropriate for decode phase.
    '''

    assert encoder_test_params.packed_qkvo.packed_qkv is not None
    assert prefill_decoder_phase_test_params.packed_qkvo.packed_qkv is not None

    # Read fields by name instead of unpacking the tuple positionally,
    # so this function does not depend on TestPoint's field count/order.
    num_heads = test_pt.num_heads
    head_size = test_pt.head_size
    batch_size = test_pt.batch_size
    block_size = test_pt.block_size
    max_decoder_seq_len = test_pt.max_dec_seq_len
    max_encoder_seq_len = test_pt.max_enc_seq_len

    scale = test_rsrcs.scale

    decoder_query = decoder_qkv.query
    decoder_seq_lens = decoder_qkv.q_seq_lens
    encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
    prefill_q_seq_lens = (
        prefill_decoder_phase_test_params.packed_qkvo.packed_qkv.q_seq_lens)

    assert prefill_q_seq_lens is not None

    # Only the key/value of the generated tensors are used below; the
    # KV sequence lengths are forced to match the encoder sequences.
    (
        cross_kv,
        _,
        _,
    ) = make_qkv(batch_size,
                 max_decoder_seq_len,
                 max_encoder_seq_len,
                 num_heads,
                 head_size,
                 force_kv_seq_lens=encoder_seq_lens,
                 attn_type=AttentionType.ENCODER_DECODER,
                 device=CUDA_DEVICE)

    ideal_output = ref_masked_attention(decoder_query,
                                        cross_kv.key,
                                        cross_kv.value,
                                        scale=scale,
                                        q_seq_lens=decoder_seq_lens,
                                        kv_seq_lens=cross_kv.kv_seq_lens)

    # Split the ideal answer: positions [0, prefill_len) of each sequence
    # are the prefill answer, position prefill_len is the decode answer.
    prefill_ideal_output = torch.zeros_like(ideal_output)
    decode_ideal_output = torch.zeros_like(ideal_output[:, 0:1])
    for bdx, prefill_q_seq_len in enumerate(prefill_q_seq_lens):
        prefill_ideal_output[bdx, :prefill_q_seq_len] = ideal_output[
            bdx, :prefill_q_seq_len]
        decode_ideal_output[bdx, :] = ideal_output[bdx, prefill_q_seq_len:(
            prefill_q_seq_len + 1)]

    prefill_packed_ideal_output, _ = pack_tensor(prefill_ideal_output,
                                                 prefill_q_seq_lens,
                                                 device=CUDA_DEVICE)
    # Exactly one decoded token per sequence
    decode_packed_ideal_output, _ = pack_tensor(decode_ideal_output,
                                                [1] * batch_size,
                                                device=CUDA_DEVICE)

    # Build prefill- & decode-phase data structures
    # for encoder/decoder cross-attention. Block tables and
    # slot mapping must be in a format compatible
    # with KV caching & attention kernels
    #
    # Whereas decoder self-attention extracts relationships between
    # equal-length Q/K/V sequences, which mutually grow in length
    # with each decoded token, cross-attention relates the Q sequence
    # - which grows with each new decoded token - to fixed-length
    # K and V sequences derived from the encoder hidden states.
    #
    # Prefill-phase:
    #
    # * Empty block-tables tensor
    # * Slot-mapping with as many entries as there are tokens in the encoder
    #   prompt.
    #
    # Decode-phase:
    # * Block-tables tensor with minimum number of blocks to
    #   accommodate K & V tensors which are equal in length
    #   to the encoder prompt length
    # * Empty slot-mapping tensor (since K & V are fixed in size,
    #   new decoded tokens are not KV-cached and require no slot-
    #   mapping)
    #
    # Note: the format above is simply an extension of what ModelRunner
    #       produces for decoder-only models

    prefill_block_tables = make_empty_block_tables_tensor(device=CUDA_DEVICE)
    decode_slot_mapping = make_empty_slot_mapping_tensor(device=CUDA_DEVICE)

    (
        decode_block_tables,
        prefill_slot_mapping_list,
        _,
    ) = make_block_tables_slot_mapping(block_size,
                                       cross_kv.kv_seq_lens,
                                       block_base_addr=block_base_addr,
                                       device=CUDA_DEVICE)

    prefill_slot_mapping = maybe_make_long_tensor(prefill_slot_mapping_list,
                                                  device=CUDA_DEVICE)

    # Packed key/value (query is already provided)
    packed_cross_kv = pack_qkv(cross_kv, device=CUDA_DEVICE)

    return (
        PhaseTestParameters(  # Prefill-phase test params
            PackedQKVO(packed_cross_kv, prefill_packed_ideal_output),
            KVMemoryMap(prefill_block_tables, prefill_slot_mapping)),
        PhaseTestParameters(  # Decode-phase test params
            PackedQKVO(None, decode_packed_ideal_output),
            KVMemoryMap(decode_block_tables, decode_slot_mapping)))


def _run_encoder_attention_test(
    attn: Attention,
    encoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
    test_pt: TestPoint,
    vllm_config: VllmConfig,
) -> torch.Tensor:
    '''
    Run encoder attention.

    attn.forward() is called with only the query/key/value tensors; no
    attention type is passed at call time, so attn must already be
    configured for encoder attention.

    Requires attn_metadata.num_decode_tokens == 0
    (There is no encoder execution in the decode-phase)

    Arguments:

    * attn: Attention wrapper instance
    * encoder_test_params: encoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for encoder/decoder-self attention
    * test_pt: The TestPoint object containing test details like number of
               model heads, head size, name of the backend being used etc.
               Only num_heads and head_size are read here.
    * vllm_config: VllmConfig handed to set_forward_context() together
                   with attn_metadata

    Returns:
    * Attention.forward() applied to packed {query,key,value}, run inside
      the forward context built from attn_metadata
    '''
    assert attn_metadata.num_decode_tokens == 0
    packed_qkv = encoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    hidden_size = test_pt.num_heads * test_pt.head_size
    with set_forward_context(attn_metadata, vllm_config):
        # The test setup does not build the query in the
        # [num_tokens, hidden_size] layout that the attention backend
        # expects, so reshape it before invoking the forward method.
        # TODO - Update the way we construct the query so that it
        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
        reshaped_query = packed_qkv.query.view(-1, hidden_size)
        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
658
659
660
661
662
663


def _run_decoder_self_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    attn_metadata: AttentionMetadata,
    test_pt: TestPoint,
    vllm_config: VllmConfig,
) -> torch.Tensor:
    '''
    Run decoder self-attention test.

    attn.forward() is called with only the query/key/value tensors; no
    attention type is passed at call time, so test_rsrcs.attn must already
    be configured for decoder self-attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
                  attn (Attention wrapper instance) field
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query/key/value fields
    * attn_metadata: attention metadata for decoder-self attention
                     (contains KV cache memory-mapping)
    * test_pt: The TestPoint object containing test details like number of
               model heads, head size, name of the backend being used etc.
               Only num_heads and head_size are read here.
    * vllm_config: VllmConfig handed to set_forward_context() together
                   with attn_metadata

    Returns:
    * Attention.forward() applied to packed_{query,key,value}, run inside
      the forward context built from attn_metadata
    '''
    attn = test_rsrcs.attn
    packed_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert packed_qkv is not None
    hidden_size = test_pt.num_heads * test_pt.head_size
    with set_forward_context(attn_metadata, vllm_config):
        # The test setup does not build the query in the
        # [num_tokens, hidden_size] layout that the attention backend
        # expects, so reshape it before invoking the forward method.
        # TODO - Update the way we construct the query so that it
        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
        reshaped_query = packed_qkv.query.view(-1, hidden_size)
        return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value)
704
705
706
707
708
709
710


def _run_encoder_decoder_cross_attention_test(
    test_rsrcs: TestResources,
    decoder_test_params: PhaseTestParameters,
    cross_test_params: Optional[PhaseTestParameters],
    attn_metadata: AttentionMetadata,
    test_pt: TestPoint,
    vllm_config: VllmConfig,
) -> torch.Tensor:
    '''
    Run encoder/decoder cross-attention test.

    Via PhaseTestParameters data structures, consumes the same query utilized
    for decoder self-attention, plus a key/value specific to cross-attention.

    if cross_test_params is None or cross_test_params.packed_qkvo.packed_qkv
    is None, this reflects that in decode-phase cross attention there
    is no growth in the key and value tensors; key and value are then
    passed to attn.forward() as None.

    attn.forward() is called with only the query/key/value tensors; no
    attention type is passed at call time, so test_rsrcs.attn must already
    be configured for encoder/decoder cross-attention.

    Arguments:

    * test_rsrcs: TestResources instance; this function relies on the
                  attn (Attention wrapper instance) field
    * decoder_test_params: decoder PhaseTestParameters data structure;
                           this function relies on the packed
                           (number_of_tokens x num_heads x head_size)
                           query field
    * cross_test_params: encoder/decoder PhaseTestParameters data structure;
                         this function relies on the packed
                         (number_of_tokens x num_heads x head_size)
                         key/value fields
    * attn_metadata: attention metadata for encoder/decoder-self attention
    * test_pt: The TestPoint object containing test details like number of
               model heads, head size, name of the backend being used etc.
               Only num_heads and head_size are read here.
    * vllm_config: VllmConfig handed to set_forward_context() together
                   with attn_metadata

    Returns:
    * Attention.forward() applied to packed_{query,key,value}, run inside
      the forward context built from attn_metadata
    '''
    decoder_pckd_qkv = decoder_test_params.packed_qkvo.packed_qkv
    assert decoder_pckd_qkv is not None

    attn = test_rsrcs.attn

    # Resolve the cross-attention K/V once; either level may be absent.
    cross_pckd_qkv = (None if cross_test_params is None else
                      cross_test_params.packed_qkvo.packed_qkv)
    if cross_pckd_qkv is None:
        key = value = None
    else:
        key, value = cross_pckd_qkv.key, cross_pckd_qkv.value

    hidden_size = test_pt.num_heads * test_pt.head_size
    with set_forward_context(attn_metadata, vllm_config):
        # The test setup does not build the query in the
        # [num_tokens, hidden_size] layout that the attention backend
        # expects, so reshape it before invoking the forward method.
        # TODO - Update the way we construct the query so that it
        # is shaped as [num_tokens, hidden_size] and we can skip the reshape.
        reshaped_query = decoder_pckd_qkv.query.view(-1, hidden_size)
        return attn.forward(reshaped_query, key, value)
769
770
771
772
773
774
775
776
777
778


@pytest.fixture(autouse=True)
def set_reset_environment(attn_backend):
    """
    Prepare global torch/backend-selection state around a single test.

    Before the test: switch the default torch dtype to bfloat16 when the
    backend under test is Flash Attention, and clear the cached backend
    selection. After the test: restore the default dtype that was in
    effect beforehand so the remaining tests are not affected.

    Arguments:

    * attn_backend: the backend the requesting test is parametrized
      with; only its `name` attribute is read here.
    """
    # Remember the dtype currently in force so it can be put back later.
    default_dtype = torch.get_default_dtype()
    # The default dtype is switched to bfloat16 to enable testing of the
    # Flash Attention backend.
    if attn_backend.name == 'FLASH_ATTN':
        torch.set_default_dtype(torch.bfloat16)
    try:
        # Drop any cached backend choice left over from a previous test.
        _cached_get_attn_backend.cache_clear()
        yield
    finally:
        # Reset the torch datatype to what it was before the test
        # so as not to impact the remaining tests. Done in `finally` so
        # the reset also happens if the setup above raises.
        torch.set_default_dtype(default_dtype)
784
785


786
787
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_encoder_only(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
) -> None:
    '''
    End-to-end encoder-only attention test:

    * Construct fake test vectors for encoder attention
    * Construct a single prefill-phase attention metadata structure that
      carries encoder test parameters only (no decoder or cross-attention
      parameters are supplied)
    * Test & validate encoder attention against ideal output

    No KV cache is required for encoder-only attention.

    Note on ROCm/HIP: currently encoder/decoder models are not supported on
    AMD GPUs, therefore this test simply is skipped if
    current_platform.is_rocm().

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):
        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test
        test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                            batch_size, block_size, max_dec_seq_len,
                            max_enc_seq_len, 4096, AttentionType.ENCODER)

        # Attention scale factor, attention backend instance, attention wrapper
        # instance, KV cache init
        vllm_config = VllmConfig()
        with set_current_vllm_config(vllm_config):
            test_rsrcs = _make_test_resources(test_pt)

        # Construct encoder attention test params (only used
        # during prefill)
        enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs)

        # Shared prefill metadata structure. Only the encoder test params
        # are populated; decoder and cross-attention params are None.
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            attn_backend,
            True,
            None,
            decoder_test_params=None,
            encoder_test_params=enc_test_params,
            cross_test_params=None,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention
        enc_pckd_act_out: torch.Tensor = (_run_encoder_attention_test(
            test_rsrcs.attn,
            enc_test_params,
            prephase_attn_metadata,
            test_pt=test_pt,
            vllm_config=vllm_config))

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
                                    attn_backend.name)
876
877


878
879
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS)
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS)
def test_e2e_enc_dec_attn(
    num_heads: int,
    head_size: int,
    attn_backend: _Backend,
    batch_size: int,
    block_size: int,
    max_dec_seq_len: int,
    max_enc_seq_len: int,
) -> None:
    '''
    End-to-end encoder/decoder test:

    * Construct fake test vectors for (1) encoder attention,
      (2) decoder self-attention, and (3) encoder/decoder cross-attention
    * Construct (1) attention metadata structure with self- and cross-attention
      attributes for prefill-phase, and (2) an analogous attention metadata
      structure but for decode-phase
    * Test attention steps in the following order

        * Encoder attention
        * Prefill self-attention
        * Prefill cross-attention
        * Decode self-attention
        * Decode cross-attention

      Besides being reflective of realistic use-cases, this order would
      exacerbate any accidental overlap in the self-/cross-attention
      block tables, which one hopes to avoid

    * Validate output correctness against ideal reference attention
      implementation

    Block tables are constructed such that cross-attention KV cache is in a
    higher, non-intersecting address-space than self-attention KV cache.

    Self- and cross-attention share the same query tensor but not the K/V
    tensors. Self-attention K/Vs must have the same seq len as Q while
    cross-attention K/Vs are allowed to differ in seq len, as is often the case
    for cross-attention.

    This test globally forces an override of the usual backend
    auto-selection process, forcing the specific backend-under-test
    to be utilized.

    Note on ROCm/HIP: currently encoder/decoder models are not supported on
    AMD GPUs, therefore this test simply is skipped if
    current_platform.is_rocm().

    Note on metadata: there is a single attention metadata structure shared by
    all prefill-phase attention operations (encoder, decoder, enc/dec cross),
    and a single one shared by all decode-phase attention operations
    (decoder & enc/dec cross.) This is intended to reflect the behavior
    of EncoderDecoderModelRunner, which constructs a single attention metadata
    structure for each prefill or decode run. A realistic scenario would rely
    on the attention backend to utilize the appropriate attention metadata
    fields according to the value of attn_metadata.attention_type. Thus,
    this test is organized so as to confirm that the backend-under-test can
    handle a shared prefill attention metadata structure & a shared decode
    attention metadata structure.

    Arguments:

    * num_heads
    * head_size
    * attn_backend: The attention backend to employ for testing
    * batch_size
    * block_size: KV cache block size
    * max_dec_seq_len: max length of decoder input sequences
    * max_enc_seq_len: max length of encoder input sequences
    '''
    # Force Attention wrapper backend
    with global_force_attn_backend_context_manager(attn_backend):
        # Note: KV cache size of 4096 is arbitrary & chosen intentionally
        # to be more than necessary, since exceeding the kv cache size
        # is not part of this test.
        # One TestPoint per attention type; they differ only in the
        # final AttentionType argument.
        enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                                batch_size, block_size, max_dec_seq_len,
                                max_enc_seq_len, 4096, AttentionType.ENCODER)
        enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                                    batch_size, block_size, max_dec_seq_len,
                                    max_enc_seq_len, 4096,
                                    AttentionType.ENCODER_DECODER)
        dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name,
                                batch_size, block_size, max_dec_seq_len,
                                max_enc_seq_len, 4096, AttentionType.DECODER)

        # Attention scale factor, attention backend instance, attention wrapper
        # instance, KV cache init
        vllm_config = VllmConfig()
        with set_current_vllm_config(vllm_config):
            enc_test_rsrcs = _make_test_resources(enc_test_pt)
            enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt)
            dec_test_rsrcs = _make_test_resources(dec_test_pt)

        # Construct encoder attention test params (only used
        # during prefill)

        enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs)

        # Construct Decoder self-attention prefill-phase & decode-phase
        # test params, including query/key/value tensors, decoder self-attention
        # memory-mapping. cross_block_base_addr is the uppermost address in the
        # decoder self-attention block-table, i.e. a base address which the
        # encoder/decoder cross-attention block-table may build downward toward.

        (
            dec_qkv,
            prephase_dec_test_params,
            decphase_dec_test_params,
            cross_block_base_addr,
        ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs)

        # Construct encoder/decoder cross-attention prefill-phase
        # & decode-phase test params, including key/value tensors,
        # cross-attention memory-mapping

        (
            prephase_cross_test_params,
            decphase_cross_test_params,
        ) = _enc_dec_cross_attn_setup_reuses_query(
            dec_qkv,
            enc_test_params,
            prephase_dec_test_params,
            enc_dec_test_pt,
            enc_dec_test_rsrcs,
            block_base_addr=cross_block_base_addr)

        # Shared prefill metadata structure
        assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None
        prephase_attn_metadata: AttentionMetadata = make_test_metadata(
            attn_backend,
            True,
            prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens,
            decoder_test_params=prephase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=prephase_cross_test_params,
            device=CUDA_DEVICE)

        # PREFILL: encoder attention

        enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn,
                                                       enc_test_params,
                                                       prephase_attn_metadata,
                                                       test_pt=enc_test_pt,
                                                       vllm_config=vllm_config)

        # - Is encoder attention result correct?
        assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out,
                                    attn_backend.name)

        # PREFILL: decoder self-attention test

        prephase_dec_pckd_act_out = _run_decoder_self_attention_test(
            dec_test_rsrcs,
            prephase_dec_test_params,
            prephase_attn_metadata,
            test_pt=dec_test_pt,
            vllm_config=vllm_config)

        # - Is prefill decoder self-attention correct?
        assert_actual_matches_ideal(prephase_dec_test_params,
                                    prephase_dec_pckd_act_out,
                                    attn_backend.name)

        # PREFILL: encoder/decoder cross-attention test

        prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
            enc_dec_test_rsrcs,
            prephase_dec_test_params,
            prephase_cross_test_params,
            prephase_attn_metadata,
            test_pt=enc_dec_test_pt,
            vllm_config=vllm_config)

        # - Is prefill encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(prephase_cross_test_params,
                                    prephase_cross_pckd_act_out,
                                    attn_backend.name)

        # DECODE: build decode-phase attention metadata

        decphase_attn_metadata: AttentionMetadata = make_test_metadata(
            attn_backend,
            False,
            dec_qkv.q_seq_lens,
            decoder_test_params=decphase_dec_test_params,
            encoder_test_params=enc_test_params,
            cross_test_params=decphase_cross_test_params,
            device=CUDA_DEVICE)

        # DECODE: decoder self-attention test

        decphase_dec_pckd_act_out = _run_decoder_self_attention_test(
            dec_test_rsrcs,
            decphase_dec_test_params,
            decphase_attn_metadata,
            test_pt=dec_test_pt,
            vllm_config=vllm_config)

        # - Is decode-phase decoder self-attention correct?
        assert_actual_matches_ideal(decphase_dec_test_params,
                                    decphase_dec_pckd_act_out,
                                    attn_backend.name)

        # DECODE: encoder/decoder cross-attention test
        # None is passed for the cross test params, so the runner calls
        # Attention.forward() with key=None and value=None.
        # NOTE(review): presumably the cross-attention K/V written during
        # prefill are reused from the KV cache here - confirm.

        decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test(
            enc_dec_test_rsrcs,
            decphase_dec_test_params,
            None,
            decphase_attn_metadata,
            test_pt=enc_dec_test_pt,
            vllm_config=vllm_config)

        # - Is decode-phase encoder/decoder cross-attention correct?
        assert_actual_matches_ideal(decphase_cross_test_params,
                                    decphase_cross_pckd_act_out,
                                    attn_backend.name)