utils.py 24.5 KB
Newer Older
1
"""Attention backend utils"""
2
from collections import defaultdict
3
from contextlib import contextmanager
4
from itertools import accumulate
5
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
6

7
import numpy as np
8
9
import torch

10
11
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                            AttentionState)
12
from vllm.attention.backends.abstract import AttentionType
13
from vllm.multimodal import MultiModalPlaceholderMap
14
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
15

16
17
18
if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerBase

19
20
21
22
# Error string(s) for encoder/decoder
# unsupported attention scenarios.
# This one reports that ROCm/HIP cannot be combined with
# encoder/decoder models.
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
                                 "with encoder/decoder models.")
23
24
25

# Sentinel written into slot mappings for positions that are not mapped to a
# real slot. In this file it is used for profiling runs, for prompt tokens
# masked out by the sliding window, and for CUDA-graph batch padding.
PAD_SLOT_ID = -1

26
27
28
29
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
# Below this threshold the pure-Python loop is used instead.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256

30
if TYPE_CHECKING:
31
    from vllm.worker.model_runner import ModelInputForGPUBuilder
32
33
34
35
36
37
38
39


def is_block_tables_empty(block_tables: Union[None, Dict]) -> bool:
    """
    Check if block_tables is None or a dictionary with all None values.

    Args:
        block_tables: Either None or a mapping whose values are block tables
            (or None for entries without one).

    Returns:
        True when ``block_tables`` is None, or when it is a dict in which
        every value is None (an empty dict also qualifies). False otherwise,
        including for non-dict inputs.
    """
    if block_tables is None:
        return True
    return (isinstance(block_tables, dict)
            and all(value is None for value in block_tables.values()))
42
43
44


def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
                                   context_len: int,
                                   sliding_window: Union[int, None]) -> int:
    """
    Compute the start index of slot mapping.

    Args:
        is_prompt: Whether the sequence is being processed as a prompt.
        query_len: Number of query tokens for the sequence.
        context_len: Accepted for interface compatibility; not used by the
            computation.
        sliding_window: Sliding window size, or None when no sliding window
            is configured.

    Returns:
        ``max(0, query_len - sliding_window)`` for prompts with a sliding
        window, otherwise 0.
    """
    if is_prompt and sliding_window is not None:
        # Only the last `sliding_window` prompt tokens receive real slots.
        return max(0, query_len - sliding_window)
    return 0


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def _compute_slot_mapping_python(slot_mapping: List[int],
                                 block_table: List[int], range_start: int,
                                 range_end: int, block_size: int):
    for i in range(range_start, range_end):
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot = block_number * block_size + block_offset
        slot_mapping.append(slot)


def _compute_slot_mapping_numpy(slot_mapping: List[int],
                                block_table: List[int], range_start: int,
                                range_end: int, block_size: int):
    block_table_array = np.array(block_table)
    idx = np.arange(range_start, range_end)
    block_offset = idx % block_size
    idx //= block_size
    seq_slot_mapping_array = block_table_array[idx]
    seq_slot_mapping_array *= block_size
    seq_slot_mapping_array += block_offset
    slot_mapping.extend(seq_slot_mapping_array)


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
                         seq_id: int, seq_len: int, context_len: int,
                         start_idx: int, block_size: int,
                         block_tables: Dict[int, List[int]]) -> None:
    """
    Compute slot mapping.

    Appends the slots for one sequence to ``slot_mapping`` in place.

    Args:
        is_profile_run: When True, no block table is consulted and
            ``seq_len`` padding slots are appended instead.
        slot_mapping: Output list that is extended in place.
        seq_id: Key used to look up this sequence in ``block_tables``.
        seq_len: End (exclusive) of the token range to map.
        context_len: Positions below this index are not mapped here.
        start_idx: First prompt position that receives a real slot;
            positions in [context_len, start_idx) are padded.
        block_size: Number of slots per block.
        block_tables: Mapping from sequence id to its block table.
    """
    if is_profile_run:
        # During memory profiling, the block tables are not
        # initialized yet. In this case, we just use a dummy
        # slot mapping.
        # In embeddings, the block tables are {seq_id: None}.
        slot_mapping.extend([PAD_SLOT_ID] * seq_len)
        return

    # Mask the [0, start_idx) tokens of the prompt with
    # PAD_SLOT_ID, where start_idx is max(0, seq_len -
    # sliding_window). For example, if the prompt len is 10,
    # sliding window is 8, and block size is 4, the first two
    # tokens are masked and the slot mapping will be
    # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
    padding_mask_len = max(0, start_idx - context_len)
    slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)

    range_start = max(start_idx, context_len)
    range_end = seq_len
    numel = range_end - range_start
    block_table = block_tables[seq_id]

    # numpy implementation will be faster than python if we have
    # many elements, otherwise it will be slower.
    if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
        _compute_slot_mapping_python(slot_mapping, block_table, range_start,
                                     range_end, block_size)
    else:
        _compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
                                    range_end, block_size)
115

116
117
118
119
120
121
122
123
124

# Type variable for the metadata class produced by a metadata builder;
# bounded by AttentionMetadata.
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')


class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
    """Accumulate per-sequence attention inputs and build metadata from them.

    ``build`` walks the input builder's ``inter_data_list``, collects slot
    mappings, block tables and length statistics, converts them to tensors
    on the runner's device and passes them to ``_metadata_cls``.
    """

    # Metadata class instantiated by build(). It is only annotated here,
    # not assigned -- presumably concrete subclasses provide it.
    _metadata_cls: Type[TAttentionMetadata]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def prepare(self):
        """Reset every accumulator that ``build`` fills."""
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Fold one sequence group's data into the builder's accumulators.

        For every sequence in the group this records its context length,
        prefill or decode counters, block table and slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        # The same block tables are used for every sequence of the group,
        # so the profile-run check is evaluated once.
        is_profile_run = is_block_tables_empty(block_tables)

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.

        Returns:
            An instance of ``_metadata_cls`` populated from the accumulated
            state.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded entry. `[] * n` is always
            # `[]` and would add nothing. Empty tables are skipped by the
            # copy loop below, so the resulting tensor is unaffected.
            self.block_tables.extend(
                [] for _ in range(cuda_graph_pad_size))
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return self._metadata_cls(  # type: ignore
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311


class CommonAttentionState(AttentionState):
    """Attention state shared by backends for CUDA graph capture and replay.

    Holds the buffers allocated while a graph capture is in progress and
    provides the metadata and input-buffer plumbing used for capture and
    replay, including the extra fields needed by encoder/decoder models.
    """

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Allocate capture buffers for the duration of the ``with`` block.

        The buffers are released and the capturing flag is cleared when the
        block exits, including when it exits with an exception.
        """
        device = self.runner.device
        # Allocate everything before touching instance state, so a failed
        # allocation cannot leave the object half-initialised.
        graph_slot_mapping = torch.full((max_batch_size, ),
                                        PAD_SLOT_ID,
                                        dtype=torch.long,
                                        device=device)
        graph_seq_lens = torch.ones(max_batch_size,
                                    dtype=torch.int32,
                                    device=device)
        graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=device)

        self._graph_slot_mapping = graph_slot_mapping
        self._graph_seq_lens = graph_seq_lens
        self._graph_block_tables = graph_block_tables
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False
            del self._graph_slot_mapping
            del self._graph_seq_lens
            del self._graph_block_tables

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        """Return a new state object of the same class bound to the runner."""
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(
            self, batch_size: int, is_encoder_decoder_model: bool = False):
        """Create decode-only metadata backed by the capture buffers.

        Args:
            batch_size: Number of entries sliced from each capture buffer.
            is_encoder_decoder_model: When True, encoder/decoder fields are
                also set on the returned metadata.
        """
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=1,
            max_decode_query_len=1,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

        return attn_metadata

    def get_graph_input_buffers(
            self,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Collect the tensors that act as graph input buffers.

        Returns:
            A dict with ``slot_mapping``, ``seq_lens_tensor`` and
            ``block_tables``, plus the encoder/decoder buffers when
            ``is_encoder_decoder_model`` is True.
        """
        input_buffers = {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._add_additonal_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
        return input_buffers

    def prepare_graph_input_buffers(
            self,
            input_buffers,
            attn_metadata,
            is_encoder_decoder_model: bool = False) -> None:
        """Copy the current decode metadata into the graph input buffers."""
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if is_encoder_decoder_model:
            self._assert_enc_dec_backend_supported()
            self._prepare_input_buffers_for_enc_dec_model(
                attn_metadata, input_buffers)

    def begin_forward(self, model_input) -> None:
        """No-op in this implementation."""
        return

    def _assert_enc_dec_backend_supported(self) -> None:
        """Assert that the runner's backend is usable for encoder/decoder."""
        # The encoder decoder model works only with XFormers and
        # Flash Attention backend. Assert the same.
        backend_name = self.runner.attn_backend.get_name()
        assert backend_name in ["XFORMERS", "FLASH_ATTN"], (
            "Expected attn_backend name to be either 'XFORMERS' or "
            f"'FLASH_ATTN', but got '{backend_name}'")

    def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
                                                    attn_metadata):
        """
        Updates the attention metadata parameters for CUDA graph capture in an
        encoder-decoder model.

        This method modifies attention-related tensors and metadata required
        for CUDA graph capture in encoder-decoder models. Specifically, it
        updates the cross-attention and encoder sequence tensors in the
        AttentionMetadata object.
        """
        # During decode phase the cross_slot_mapping will be empty. Hence set
        # an empty tensor for CUDA Graph capture.
        attn_metadata.cross_slot_mapping = torch.tensor(
            [], dtype=torch.int).cuda()
        attn_metadata.cross_block_tables = torch.full(
            (batch_size, self.runner.get_max_block_per_batch()),
            1,
            dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
                                                    1,
                                                    dtype=torch.int).cuda()
        attn_metadata.encoder_seq_lens_tensor = torch.full(
            (batch_size, ), 1, dtype=torch.int).cuda()
        attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
        attn_metadata.num_encoder_tokens = 0

    def _add_additonal_input_buffers_for_enc_dec_model(
            self, attn_metadata, input_buffers: Dict[str, Any]):
        """
        Saves additional input buffers specific to the encoder-decoder model
        from the attention metadata.

        This method extracts and stores encoder-decoder related input buffers
        from the `attn_metadata` into the `input_buffers` dictionary. The
        buffers include encoder sequence lengths, cross-slot mappings, and
        cross-block tables, which are essential for the encoder-decoder model
        during CUDA graph replay.
        """
        input_buffers["encoder_seq_lens_tensor"] = (
            attn_metadata.decode_metadata.encoder_seq_lens_tensor)
        input_buffers["cross_slot_mapping"] = (
            attn_metadata.decode_metadata.cross_slot_mapping)
        input_buffers["cross_block_tables"] = (
            attn_metadata.decode_metadata.cross_block_tables)

    def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
                                                 input_buffers: Dict[str,
                                                                     Any]):
        """
        Populates input buffers with data from the encoder-decoder model's
        attention metadata.

        This method copies the encoder sequence lengths, cross-slot mapping
        and cross-block tables from `attn_metadata.decode_metadata` into the
        corresponding entries of the `input_buffers` dictionary.
        """
        input_buffers["encoder_seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.encoder_seq_lens_tensor,
            non_blocking=True)
        input_buffers["cross_slot_mapping"].copy_(
            attn_metadata.decode_metadata.cross_slot_mapping,
            non_blocking=True)
        input_buffers["cross_block_tables"].copy_(
            attn_metadata.decode_metadata.cross_block_tables,
            non_blocking=True)
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483


def is_all_encoder_attn_metadata_set(attn_metadata):
    '''
    Report whether every field needed for encoder attention is populated.

    The fields checked, in order, are encoder_seq_lens,
    encoder_seq_lens_tensor and max_encoder_seq_len; each must be
    something other than None.
    '''
    required_fields = (
        "encoder_seq_lens",
        "encoder_seq_lens_tensor",
        "max_encoder_seq_len",
    )
    return all(
        getattr(attn_metadata, field_name) is not None
        for field_name in required_fields)


def is_all_cross_attn_metadata_set(attn_metadata):
    '''
    Report whether the metadata needed for enc/dec cross-attention is set.

    This requires everything encoder attention requires (as reported by the
    metadata's own is_all_encoder_attn_metadata_set attribute) plus the
    cross-attention slot mapping and block tables.
    '''
    encoder_ready = attn_metadata.is_all_encoder_attn_metadata_set
    if not encoder_ready:
        # Hand back the falsy value itself, as a chained `and` would.
        return encoder_ready
    if attn_metadata.cross_slot_mapping is None:
        return False
    return attn_metadata.cross_block_tables is not None


def get_seq_len_block_table_args(
    attn_metadata,
    is_prompt: bool,
    attn_type: str,
) -> tuple:
    '''
    The particular choice of sequence-length- and block-table-related
    attributes which should be extracted from attn_metadata is dependent
    on the type of attention operation.

    Decoder attn -> select entirely decoder self-attention-related fields
    Encoder/decoder cross-attn -> select encoder sequence lengths &
                                  cross-attn block-tables fields
    Encoder attn -> select encoder sequence lengths fields & no block tables

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * is_prompt: True if prefill, False otherwise
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention

    Returns:

    * Appropriate sequence-lengths tensor
    * Appropriate max sequence-length scalar
    * Appropriate block tables (or None)

    Raises:

    * AttributeError: if attn_type is none of the three types handled here
    '''

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention
        # Choose max_seq_len based on whether we are in prompt_run
        if is_prompt:
            max_seq_len = attn_metadata.max_prefill_seq_len
        else:
            max_seq_len = attn_metadata.max_decode_seq_len
        return (attn_metadata.seq_lens_tensor, max_seq_len,
                attn_metadata.block_tables)
    elif attn_type == AttentionType.ENCODER_DECODER:
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len,
                attn_metadata.cross_block_tables)
    elif attn_type == AttentionType.ENCODER:
        # No block tables associated with encoder attention
        return (attn_metadata.encoder_seq_lens_tensor,
                attn_metadata.max_encoder_seq_len, None)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


def get_num_prefill_decode_query_kv_tokens(
    attn_metadata,
    attn_type: str,
) -> Tuple[int, int, int]:
    """
    Calculate the number of prefill and decode tokens for query, key/value
    based on the attention metadata and the specified attention type.

    Args:
        attn_metadata (FlashAttentionMetadata): Attention Metadata object.
        attn_type (AttentionType): The type of attention being used.
    Returns:
        Tuple[int, int, int]: A tuple containing three integers:
            - The number of prefill query tokens.
            - The number of prefill key/value tokens.
            - The number of decode query tokens.

    Raises:
        AssertionError: If the number of encoder tokens in `attn_metadata`
        is `None` when required for the calculations.
    """
    if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during prefill phase.
        # The same input serves as both query and key.
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_encoder_tokens
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = 0
    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.num_encoder_tokens is not None
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key is the encoder/cross-attention.
        num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens
    else:  # attn_type == AttentionType.DECODER or
        # attn_type == AttentionType.ENCODER_ONLY
        num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
        num_decode_query_tokens = attn_metadata.num_decode_tokens

    return (num_prefill_query_tokens, num_prefill_kv_tokens,
            num_decode_query_tokens)