# pooling_model_runner.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
6
7
8

import torch

9
from vllm.config import VllmConfig
10
from vllm.distributed import get_pp_group
11
from vllm.forward_context import set_forward_context
12
from vllm.logger import init_logger
13
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
14
from vllm.model_executor.pooling_metadata import PoolingMetadata
15
from vllm.multimodal import MultiModalKwargs
16
from vllm.pooling_params import PoolingParams
17
18
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
19
20
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)
21
22
23
24

# Module-level logger created through vLLM's ``init_logger`` and named after
# this module.
logger = init_logger(__name__)


25
26
27
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """GPU model input extended with pooling metadata.

    This is the model-input type of ``PoolingModelRunner``. It adds a single
    ``pooling_metadata`` field to ``ModelInputForGPU``. Because the dataclass
    is frozen, derive modified copies with ``dataclasses.replace`` instead of
    assigning attributes.
    """
    # Metadata consumed by the model's pooler; defaults to None.
    pooling_metadata: Optional["PoolingMetadata"] = None


33
class PoolingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner for pooling models.

    ``prepare_model_input`` builds the usual GPU input tensors and attaches a
    ``PoolingMetadata``. ``execute_model`` runs one forward pass. On the last
    pipeline-parallel rank, the driver worker then applies
    ``self.model.pooler`` to the resulting hidden states.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Forward all configuration unchanged to ``GPUModelRunnerBase``."""
        super().__init__(vllm_config=vllm_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and, on the last PP rank, pool its output.

        Args:
            model_input: Prepared input tensors plus pooling metadata.
            kv_caches: Not referenced in this method body; kept for
                interface parity with the other runners.
            intermediate_tensors: Activations received from the previous
                pipeline stage, if any.
            num_steps: Must be 1; multi-step execution is not supported.

        Returns:
            On a non-last pipeline rank: the model output, returned for the
            next stage. When it is ``IntermediateTensors`` and forward timing
            is enabled on the driver, a cumulative ``"model_forward_time"``
            entry is added to it. On the last rank: ``[pooler output]`` from
            the driver worker, and ``[]`` from every other worker.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError(
                "PoolingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # Pooling models are (ab-)used also to integrate non-text models that
        # are not autoregressive (PrithviGeospatialMAE).
        # These models might not use attention and do not really have a
        # prefill and decode phase. The model input is processed in one shot
        # and both decode_metadata and prefill_metadata would be None for such
        # models. See the PlaceholderAttentionMetadata class.
        # TODO: Figure out if cuda_graph is of any use for these models and
        #  explore how to leverage it.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            # Graph runners are keyed by (batch size, uses inputs_embeds).
            if model_input.inputs_embeds is None:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                use_inputs_embeds = False
            else:
                graph_batch_size = model_input.inputs_embeds.shape[0]
                use_inputs_embeds = True
            model_executable = self.graph_runners[virtual_engine][(
                graph_batch_size, use_inputs_embeds)]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Only models with inner (sequence-length-agnostic) state receive the
        # request bookkeeping kwargs.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        # Evaluate the observability switch once. The CUDA events below exist
        # only when it is on, and every later use is guarded by this flag.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        cross_enc_kwargs = {}
        if model_input.token_types is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_types

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 virtual_engine):
            # NOTE(review): inputs_embeds selects the CUDA graph above, but it
            # is not passed to model_executable here. Confirm whether pooling
            # models are meant to receive it.
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    multi_modal_kwargs,
                    device=self.device,
                ),
                **cross_enc_kwargs,
                **seqlen_agnostic_kwargs,
            )

        if collect_forward_time:
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker and collect_forward_time
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                # Accumulate the time already spent by earlier stages.
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        The dict is decoded with this runner's attention backend.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build input tensors for the scheduled groups and attach pooling
        metadata.

        Args:
            seq_group_metadata_list: Scheduled sequence groups; must not be
                None.
            virtual_engine: Virtual engine this batch was scheduled on. It is
                recorded on the returned model input.
            finished_requests_ids: Forwarded to
                ``_prepare_model_input_tensors``.

        Returns:
            The prepared model input with ``pooling_metadata`` and
            ``virtual_engine`` set.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # Record virtual_engine as well. execute_model reads
        # model_input.virtual_engine for the forward context and the
        # CUDA-graph lookup; previously this argument was silently dropped.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata,
                                   virtual_engine=virtual_engine)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Each group's ``pooling_params`` is updated in place with the
        pooler-provided updates for its ``task`` before being recorded.

        Args:
            seq_group_metadata_list: Scheduled sequence groups. Each must
                carry ``pooling_params`` with ``task`` set.
            prompt_lens: Passed through unchanged as
                ``PoolingMetadata.prompt_lens``.

        Returns:
            PoolingMetadata with one ``(seq_ids, pooling_params)`` entry per
            group and the merged ``seq_data`` of all groups.

        Raises:
            AssertionError: If a group has no ``pooling_params`` or no
                ``task`` (only when assertions are enabled).
        """
        # self.model does not change inside the loop, so cast it once.
        model = cast(VllmModelForPooling, self.model)

        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            pooling_params = seq_group_metadata.pooling_params
            assert pooling_params is not None

            # Bind `task` outside the assert. Under `python -O` assert
            # statements are removed entirely, which would also drop a walrus
            # binding inside them and leave `task` undefined below.
            task = pooling_params.task
            assert task is not None, "You did not set `task` in the API"

            to_update = model.pooler.get_pooling_updates(task)
            to_update.apply(pooling_params)

            seq_groups.append(
                (list(seq_group_metadata.seq_data.keys()), pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )