embedding_model_runner.py 8.9 KB
Newer Older
1
import dataclasses
2
from typing import Any, Dict, List, Optional, Tuple, Type, Union
3
4
5
6

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
7
8
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
9
from vllm.distributed import get_pp_group
10
from vllm.forward_context import set_forward_context
11
12
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
13
from vllm.multimodal import MultiModalInputs
14
from vllm.pooling_params import PoolingParams
15
16
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
17
18
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)
19
20
21
22

logger = init_logger(__name__)


23
24
25
26
27
28
29
30
31
32
33
34
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Model input used by the EmbeddingModelRunner.

    Extends ModelInputForGPU with a single optional field,
    ``pooling_metadata``. The dataclass is frozen, so the field is not
    assigned in place: EmbeddingModelRunner.prepare_model_input attaches
    it by building a copy with ``dataclasses.replace``.
    """
    # Pooling information for the batch. Defaults to None; it is passed to
    # ``self.model.pooler`` by EmbeddingModelRunner.execute_model.
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner whose output is produced by the model's pooler.

    ``execute_model`` runs one forward pass and, on the driver worker of
    the last pipeline-parallel rank, feeds the resulting hidden states to
    ``self.model.pooler`` together with the ``PoolingMetadata`` built in
    ``prepare_model_input``.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        observability_config: Optional[ObservabilityConfig] = None,
    ):
        """Forward every configuration object unchanged to the base runner."""
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         prompt_adapter_config=prompt_adapter_config,
                         observability_config=observability_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared batch; ``attn_metadata`` must be set.
            kv_caches: Accepted for interface compatibility only. The value
                passed in is discarded and replaced by empty placeholder
                tensors, one per layer, before the model is called.
            intermediate_tensors: Passed through to the model. When forward
                timing is collected, its ``"model_forward_time"`` entry (if
                any) is added to this stage's measured time.
            num_steps: Must be 1.

        Returns:
            * On a rank that is not the last pipeline rank: whatever the
              model returned. If that is an ``IntermediateTensors`` and
              forward-time collection is enabled on the driver worker, a
              ``"model_forward_time"`` tensor is stored in it first.
            * On the last pipeline rank, non-driver worker: an empty list.
            * On the last pipeline rank, driver worker: a one-element list
              holding the output of ``self.model.pooler``.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            # Captured graphs are looked up by virtual engine, then by the
            # first dimension of the input token tensor.
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of ``None`` to force Dynamo to pass
        # it by reference, rather than by specializing on the value ``None``.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}

        # Evaluate the timing switch once so that the CUDA events below are
        # only ever touched on the path that created them.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device))

        if collect_forward_time:
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            # `isinstance` is False for None, so no separate None check.
            if (collect_forward_time and self.is_driver_worker
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)):
                # The end event must have completed before elapsed_time()
                # (milliseconds) can be read.
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                # Accumulate on top of the time carried in from the
                # incoming intermediate tensors, if any.
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict``,
        supplying this runner's ``attn_backend``.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the model input for a batch and attach pooling metadata.

        Args:
            seq_group_metadata_list: Sequence groups in the batch; must not
                be None.
            virtual_engine: Not read by this method.
            finished_requests_ids: Forwarded to
                ``_prepare_model_input_tensors``.

        Returns:
            A copy of the prepared input with ``pooling_metadata`` filled in.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # The model input dataclass is frozen, so build a modified copy.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Collects, in a single pass over the groups, a
        ``(sequence ids, pooling params)`` pair per group and a merged
        mapping of every group's ``seq_data``. On duplicate sequence ids the
        entry from the later group wins (``dict.update`` semantics).
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_groups.append((list(seq_group_metadata.seq_data),
                               seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )