# pooling_model_runner.py
# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
from typing import Any, Dict, List, Optional, Tuple, Type, Union
5
6
7

import torch

8
from vllm.config import VllmConfig
9
from vllm.distributed import get_pp_group
10
from vllm.forward_context import set_forward_context
11
12
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
13
from vllm.multimodal import MultiModalKwargs
14
from vllm.pooling_params import PoolingParams
15
16
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
17
18
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)
19
20
21
22

# Module-level logger, named after this module.
logger = init_logger(__name__)


23
24
25
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Model input consumed by ``PoolingModelRunner``.

    Adds a ``pooling_metadata`` field on top of ``ModelInputForGPU``. The
    field defaults to ``None``; ``PoolingModelRunner.prepare_model_input``
    fills it in via ``dataclasses.replace``.
    """
    # Pooling information handed to the model's pooler at execution time.
    pooling_metadata: Optional["PoolingMetadata"] = dataclasses.field(
        default=None)


31
class PoolingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner whose output is produced by ``self.model.pooler``.

    Runs the model forward pass. On the driver worker of the last pipeline
    stage, it returns the pooler output computed from the final hidden
    states. Other pipeline ranks return the model's intermediate output
    unchanged.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Forward all configuration to ``GPUModelRunnerBase``."""
        super().__init__(vllm_config=vllm_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and, where applicable, pool the result.

        Args:
            model_input: Prepared batch, including ``pooling_metadata``.
            kv_caches: Ignored. Empty per-layer placeholder tensors are
                passed to the model instead.
            intermediate_tensors: Tensors forwarded to the model (presumably
                received from the previous pipeline stage).
            num_steps: Must be 1. Multi-step execution is unsupported.

        Returns:
            On non-last pipeline ranks, whatever the model returned. On
            non-driver workers of the last rank, an empty list. Otherwise,
            a one-element list holding the pooler output.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # Guard against decode_meta being None as well (no prefill and no
        # decode metadata); previously this raised AttributeError.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter; `float32` is used as a
        # placeholder (it has wide hardware support). Note that this
        # replaces the `kv_caches` argument supplied by the caller.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Extra kwargs only for models that report `has_inner_state`.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        # Evaluate the observability switch once; it gates every timing step.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        # Token type ids are forwarded only when the batch carries them.
        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs)

        if collect_forward_time:
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and collect_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                # Add this stage's forward time to any value already carried
                # in `intermediate_tensors`.
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild the model input from a broadcast tensor dict."""
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build model input tensors and attach ``PoolingMetadata``.

        ``virtual_engine`` is accepted for interface compatibility but is
        not used here.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: Sequence groups in batch order.
            prompt_lens: Passed through unchanged as
                ``PoolingMetadata.prompt_lens``.

        Returns:
            A ``PoolingMetadata`` with one ``(seq_ids, pooling_params)``
            entry per group and a merged ``seq_id -> SequenceData`` map.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        # A single pass collects both the per-group entries and the merged
        # sequence data (later groups overwrite duplicate seq ids, as before).
        for seq_group_metadata in seq_group_metadata_list:
            seq_groups.append((list(seq_group_metadata.seq_data.keys()),
                               seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )