# embedding_model_runner.py
import dataclasses
2
from typing import Any, Dict, List, Optional, Tuple, Type, Union
3
4
5

import torch

6
from vllm.config import VllmConfig
7
from vllm.distributed import get_pp_group
8
from vllm.forward_context import set_forward_context
9
10
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
11
from vllm.multimodal import MultiModalInputs
12
from vllm.pooling_params import PoolingParams
13
14
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
15
16
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)
17
18
19
20

# Module-level logger for this runner.
logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """GPU model input extended with pooling metadata.

    Consumed by EmbeddingModelRunner, which fills in ``pooling_metadata``
    via ``dataclasses.replace`` after the base input tensors are prepared.
    """
    # Pooling metadata for the batch; None until it has been attached.
    pooling_metadata: Optional["PoolingMetadata"] = dataclasses.field(
        default=None)


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner for embedding (pooling) models.

    Runs the model forward pass and, on the last pipeline stage, applies
    ``self.model.pooler`` to the hidden states using the batch's
    PoolingMetadata.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ):
        """Forward all configuration to GPUModelRunnerBase unchanged."""
        super().__init__(vllm_config=vllm_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared batch, including attention metadata and
                pooling metadata.
            kv_caches: Not used; it is replaced below by one empty
                placeholder tensor per layer.
            intermediate_tensors: Tensors received from the previous
                pipeline stage, if any.
            num_steps: Must be 1; multi-step execution is not supported.

        Returns:
            On a non-last pipeline stage, the model output, forwarded
            unchanged. When forward timing is enabled, the driver worker
            first adds a "model_forward_time" entry to it.
            On the last stage, an empty list for non-driver workers, or a
            one-element list holding the pooler output for the driver.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently CUDA graphs are only used for decode-only batches (no
        # prefill metadata) whose decode metadata enables them; every other
        # batch runs the eager model.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass it
        # by reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter; `float32` is used as a
        # placeholder because it has wide hardware support.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}

        # Evaluate the timing condition once, so the CUDA events are only
        # created and used under exactly the same guard.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device))

        if collect_forward_time:
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            # isinstance() already rules out None.
            if (self.is_driver_worker
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and collect_forward_time):
                # Add this stage's forward time to any time reported by
                # earlier stages, and pass the sum on to the next stage.
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Uses this runner's attention backend for the reconstruction.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the input tensors for a batch and attach PoolingMetadata.

        Args:
            seq_group_metadata_list: Sequence groups in the batch; must not
                be None.
            virtual_engine: Accepted for interface compatibility; not used
                here.
            finished_requests_ids: Passed through to
                ``_prepare_model_input_tensors``.

        Returns:
            The prepared model input, with ``pooling_metadata`` populated.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: Sequence groups in batch order.
            prompt_lens: Per-sequence lengths, stored unchanged as
                ``prompt_lens``.

        Returns:
            PoolingMetadata with one (seq_ids, pooling_params) entry per
            group, and the merged seq_data of all groups.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        # A single pass collects both the per-group entries and the merged
        # seq_data mapping.
        for seq_group_metadata in seq_group_metadata_list:
            seq_groups.append((list(seq_group_metadata.seq_data.keys()),
                               seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )