import dataclasses
import itertools
from typing import Any, Dict, List, Optional, Tuple, Type, cast

import torch
import torch.distributed

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                     get_global_forced_attn_backend,
                                     global_force_attn_backend)
from vllm.config import ModelConfig, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.utils import get_architecture_class_name
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
                             MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
                           SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase,
                                      ModelInputForGPUBuilder,
                                      ModelInputForGPUWithSamplingMetadata,
                                      _get_graph_batch_size)
from vllm.worker.model_runner_base import (
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict)
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario

logger = init_logger(__name__)

# The Mllama model has PagedAttention specific logic because of which it
# can only be run with the XFORMERS backend
# TODO Make Mllama model work with Flash Attention backend.
# Architecture class names (as returned by `get_architecture_class_name`)
# for which the runner forces the XFormers attention backend when the user
# has not chosen one explicitly.
_XFORMERS_ONLY_ENCODER_DECODER_ARCHS = ["MllamaForConditionalGeneration"]


@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
    """
    Used by the EncoderDecoderModelRunner.

    Extends the decoder-oriented GPU model input with the encoder's input
    token ids and positions. Instances are frozen; new values are attached
    with `dataclasses.replace`.
    """
    # Encoder input token ids for the whole batch (None when not set).
    encoder_input_tokens: Optional[torch.Tensor] = None
    # Encoder input positions matching `encoder_input_tokens`.
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect the fields of this input into a flat dict for broadcast.

        Attention and sampling metadata are merged into the same dict by
        the shared `_add_*_broadcastable_dict` helpers.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "encoder_input_tokens": self.encoder_input_tokens,
            "encoder_input_positions": self.encoder_input_positions,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInput":
        """Rebuild an `EncoderDecoderModelInput` from a broadcast dict.

        Delegates to the parent implementation (which constructs `cls`) and
        only narrows the static return type.
        """
        return cast(
            EncoderDecoderModelInput,
            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))


class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
    _model_input_cls: Type[EncoderDecoderModelInput] = (
        EncoderDecoderModelInput)
    _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        '''
        EncoderDecoderModelRunner constructor.

        Before delegating to the base-class constructor, forces an attention
        backend supported by encoder/decoder models (see
        `_maybe_force_supported_attention_backend`). After construction,
        validates that the configured scenario is supported.

        Args:
            vllm_config: engine configuration; its `model_config` decides
                whether the XFormers backend must be forced.
            kv_cache_dtype: KV-cache dtype, forwarded to the base class.
            is_driver_worker: whether this runner lives on the driver
                worker, forwarded to the base class.
            input_registry: input registry forwarded to the base class
                (`profile_run` reads `self.input_registry`).
            mm_registry: multi-modal registry forwarded to the base class
                (`profile_run` reads `self.mm_registry`).

        Raises:
            NotImplementedError: if an attention backend that
                encoder/decoder models do not support has been forced.
        '''
        self._maybe_force_supported_attention_backend(vllm_config.model_config)

        super().__init__(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            # Forward the registries rather than silently dropping them so
            # that caller-supplied registries are actually honored.
            input_registry=input_registry,
            mm_registry=mm_registry,
        )

        # Crash for unsupported encoder/scenarios
        assert_enc_dec_mr_supported_scenario(self)

    def _is_xformers_only_encoder_decoder_model(self,
                                                model: ModelConfig) -> bool:
        """Tell whether `model`'s architecture is restricted to XFormers.

        The architecture class name is looked up in
        `_XFORMERS_ONLY_ENCODER_DECODER_ARCHS`.
        """
        architecture = get_architecture_class_name(model)
        return architecture in _XFORMERS_ONLY_ENCODER_DECODER_ARCHS

    def _maybe_force_supported_attention_backend(self, model: ModelConfig):
        '''
        Make sure an attention backend usable by encoder/decoder models is
        selected.

        * If no backend override is in place (neither the global forced
          backend nor the vLLM backend environment variable is set) and the
          model's architecture is XFormers-only, force the XFormers backend
          globally.
        * If an override is in place, it must be XFormers or Flash
          Attention. The global override takes precedence over the
          environment variable, so only the global one is checked when both
          are set.

        Raises:
            NotImplementedError: if the effective override names any other
                backend.
        '''

        def raise_backend_err():
            # The user has specified an attention backend override
            # which is invalid for encoder/decoder models
            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)

        # Backends accepted as user overrides for encoder/decoder models.
        supported_backends = (_Backend.XFORMERS, _Backend.FLASH_ATTN)

        maybe_env_var_forced_backend = get_env_variable_attn_backend()
        maybe_global_forced_backend = get_global_forced_attn_backend()
        is_forced_by_global = maybe_global_forced_backend is not None
        is_forced_by_env_var = maybe_env_var_forced_backend is not None

        if not (is_forced_by_global or is_forced_by_env_var) \
            and self._is_xformers_only_encoder_decoder_model(model):
            # The user has not already specified an attention backend
            # override
            logger.info(
                "Encoder-Decoder Model Architecture %s requires XFormers "
                "backend; overriding backend auto-selection and "
                "forcing XFormers.", get_architecture_class_name(model))
            global_force_attn_backend(_Backend.XFORMERS)
        elif is_forced_by_global:
            # Backend override enforced by global variable takes
            # precedence over vLLM backend environment variable.
            if maybe_global_forced_backend not in supported_backends:
                raise_backend_err()
        elif is_forced_by_env_var:
            # Backend override enforced by vLLM backend
            # environment variable
            if maybe_env_var_forced_backend not in supported_backends:
                raise_backend_err()

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Convert `_list` to a `torch.int32` tensor on `self.device`."""
        target_device = self.device
        return torch.tensor(_list, device=target_device, dtype=torch.int32)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Convert `_list` to a `torch.long` tensor on `self.device`."""
        target_device = self.device
        return torch.tensor(_list, device=target_device, dtype=torch.long)

    def _empty_int32_tensor(self) -> torch.Tensor:
        """Return an empty int32 tensor built by `_list_to_int32_tensor`."""
        no_values: List[int] = []
        return self._list_to_int32_tensor(no_values)

    def _empty_long_tensor(self) -> torch.Tensor:
        """Return an empty long tensor built by `_list_to_long_tensor`."""
        no_values: List[int] = []
        return self._list_to_long_tensor(no_values)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one encoder/decoder forward pass and sample the next tokens.

        Args:
            model_input: prepared decoder and encoder inputs plus attention
                and sampling metadata.
            kv_caches: per-layer KV-cache tensors passed to the model.
            intermediate_tensors: forwarded to the model unchanged.
            num_steps: must be 1; multi-step execution is not supported.

        Returns:
            `[]` on non-driver workers; otherwise a one-element list holding
            the sampler output.

        Raises:
            ValueError: if `num_steps > 1`.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in "
                             "EncoderDecoderModelRunner")

        # Pure-decode batches flagged for CUDA graphs are dispatched to the
        # graph runner registered for this virtual engine and batch size;
        # everything else runs the plain model.
        if (model_input.attn_metadata is not None
                and model_input.attn_metadata.prefill_metadata is None
                and model_input.attn_metadata.decode_metadata.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[
                model_input.virtual_engine][graph_batch_size]
        else:
            model_executable = self.model

        # Request bookkeeping is only passed to models with inner state.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                encoder_input_ids=model_input.encoder_input_tokens,
                encoder_positions=model_input.encoder_input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **seqlen_agnostic_kwargs)

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
        """Rebuild an `EncoderDecoderModelInput` from a broadcast tensor
        dict, using this runner's attention backend."""
        backend = self.attn_backend
        return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=backend)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInput:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        Since chunked prefill is not supported for encoder/decoder models,
        `input_tokens` is assumed to be either entirely prefill tokens or
        entirely decode tokens.

        Args:
            seq_group_metadata_list: sequence groups in the batch.
            virtual_engine: stored on the returned model input.
            finished_requests_ids: forwarded to
                `_prepare_model_input_tensors` and `get_generators`.

        Returns:
            A new `EncoderDecoderModelInput` with decoder inputs, encoder
            inputs, attention metadata and sampling metadata populated.
            `is_prompt` comes from the first sequence group (None for an
            empty batch).
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                       model_input))
        # Inject attn_metadata encoder/cross-attention fields &
        # encoder input tokens/positions into model_input.
        # Frozen dataclass fields cannot be modified, so use
        # dataclasses.replace to construct a new model input
        # instance.
        model_input = dataclasses.replace(
            model_input,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
        )

        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory,
                                                     generators=generators)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run the model once on dummy prompts to profile memory usage.

        Builds `max_num_seqs` dummy prompt sequence groups. Each group
        requests `seq_len` dummy tokens for both the decoder and the encoder,
        and the `seq_len` values sum to `max_num_batched_tokens`. The groups
        are prepared with `prepare_model_input` and executed with empty
        placeholder KV caches, followed by `torch.cuda.synchronize()`.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            logger.info("Starting profile run for multi-modal models.")

        for group_id in range(max_num_seqs):
            # Split max_num_batched_tokens as evenly as possible: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            decoder_dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry,
                                          is_encoder_data=False)
            encoder_dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry,
                                          is_encoder_data=True)

            # Having more tokens is over-conservative but otherwise fine
            assert len(
                decoder_dummy_data.seq_data.prompt_token_ids
            ) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
            )

            assert (decoder_dummy_data.multi_modal_data is None
                    or encoder_dummy_data.multi_modal_data is None), (
                "Multi-modal data can't be provided in both encoder and decoder"
            )

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: decoder_dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                encoder_seq_data=encoder_dummy_data.seq_data,
                cross_block_table=None,
                # At most one side carries multi-modal data (asserted above).
                multi_modal_data=decoder_dummy_data.multi_modal_data
                or encoder_dummy_data.multi_modal_data,
                multi_modal_placeholders=decoder_dummy_data.
                multi_modal_placeholders
                or encoder_dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInput,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Helper method to prepare the encoder- and cross-attn-related
        model inputs based on a given sequence group. These additional inputs
        are used to augment an already-computed `EncoderDecoderModelInput`
        data structure which already has decoder-related model inputs
        populated.

        Sets the following fields on `model_input.attn_metadata` (the object
        is mutated in place and also returned):
        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `encoder_seq_start_loc`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Chunked prefill is not supported, so the whole batch is treated as
        prefill or decode according to the first sequence group's
        `is_prompt` flag.

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
                                   compute inputs
        * model_input: model inputs data structure with decoder-oriented
                       fields already computed; its `attn_metadata` must be
                       set unless `seq_group_metadata_list` is empty.

        Return:

        * A tuple `(attn_metadata, encoder_input_tokens,
          encoder_input_positions)`. For an empty `seq_group_metadata_list`
          the attention metadata is returned untouched and both tensors are
          None; during decode both tensors are empty.
        """

        if len(seq_group_metadata_list) == 0:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            # Prefill uses an empty (num_seq_groups, 0) int32 tensor as the
            # cross-attention block tables.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            encoder_input_tokens: List[int] = []
            encoder_input_positions: List[int] = []
            cross_slot_mapping: List[int] = []
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                seq_len = seq_group_metadata.encoder_seq_data.get_len()
                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping
                is_profile_run = (seq_group_metadata.block_tables is None)
                if is_profile_run:
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                else:
                    # Encoder token i maps to offset (i % block_size) within
                    # block cross_block_table[i // block_size].
                    block_size = self.block_size
                    cross_block_table = seq_group_metadata.cross_block_table
                    cross_slot_mapping.extend(
                        cross_block_table[i // block_size] * block_size +
                        i % block_size for i in range(seq_len))

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(range(seq_len))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()
            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            decode_block_tables: List[List[int]] = []
            for seq_group_metadata in seq_group_metadata_list:
                # One entry per decoder sequence in the group; all of them
                # share the group's encoder length and cross block table.
                for _ in range(len(seq_group_metadata.seq_data)):
                    encoder_seq_lens.append(
                        seq_group_metadata.encoder_seq_data.get_len())
                    cross_block_table = seq_group_metadata.cross_block_table
                    decode_block_tables.append([] if (
                        cross_block_table is None) else cross_block_table)

            if (model_input.attn_metadata is not None
                    and model_input.attn_metadata.use_cuda_graph):
                # We will be using CUDA graph replay for this decode.
                max_len_of_block_table = self.get_max_block_per_batch()
                batch_size = len(encoder_seq_lens)
                graph_batch_size = _get_graph_batch_size(batch_size)
                assert graph_batch_size >= batch_size
                cuda_graph_pad_size = graph_batch_size - batch_size
                # extend the cross_block_tables and encoder_seq_lens to match
                # the graph_batch_size.
                decode_block_tables.extend(
                    [] for _ in range(cuda_graph_pad_size))
                encoder_seq_lens.extend(
                    itertools.repeat(1, cuda_graph_pad_size))

            else:
                max_len_of_block_table = max(
                    len(block_table) for block_table in decode_block_tables)

            cross_block_tables = make_tensor_with_pad(
                decode_block_tables,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths & encoder
        # sequence starting offset tensors
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
        # encoder_seq_start_loc[0] == 0 and encoder_seq_start_loc[k + 1] is
        # the running sum of the first k + 1 encoder sequence lengths.
        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
                                            1,
                                            dtype=torch.int32,
                                            device=self.device)
        torch.cumsum(encoder_seq_lens_tensor,
                     dim=0,
                     dtype=encoder_seq_start_loc.dtype,
                     out=encoder_seq_start_loc[1:])

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        (
            attn_metadata.num_encoder_tokens,
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor,
            attn_metadata.max_encoder_seq_len,
            attn_metadata.encoder_seq_start_loc,
            attn_metadata.cross_slot_mapping,
            attn_metadata.cross_block_tables,
        ) = (
            sum(encoder_seq_lens),
            encoder_seq_lens,
            encoder_seq_lens_tensor,
            max_encoder_seq_len,
            encoder_seq_start_loc,
            cross_slot_mapping_tensor,
            cross_block_tables,
        )

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)