embedding_model_runner.py 6.92 KB
Newer Older
1
2
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
3
4
5
6

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
7
8
                         ModelConfig, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
9
10
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
11
from vllm.multimodal import MultiModalInputs
12
from vllm.pooling_params import PoolingParams
13
14
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
15
16
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)
17
18
19
20

logger = init_logger(__name__)


21
22
23
24
25
26
27
28
29
30
31
32
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Model input used by the EmbeddingModelRunner.

    Extends ModelInputForGPU with one extra, optional field holding the
    pooling metadata. Instances are frozen, so a populated copy has to be
    produced with ``dataclasses.replace`` rather than by attribute
    assignment.
    """
    # Pooling metadata for the batch; defaults to None when not provided.
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner that produces pooled outputs instead of samples.

    ``execute_model`` runs the forward pass and, on the driver worker only,
    passes the resulting hidden states to ``self.model.pooler`` together
    with the pooling metadata built by ``prepare_model_input``.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        observability_config: Optional[ObservabilityConfig] = None,
    ):
        """Forward every configuration object unchanged to the base runner."""
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         prompt_adapter_config=prompt_adapter_config,
                         multimodal_config=multimodal_config,
                         observability_config=observability_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        """Run one forward pass and pool the hidden states.

        Args:
            model_input: Prepared input; ``attn_metadata`` must be set.
            kv_caches: Accepted for interface compatibility but ignored;
                the model is always called with one ``None`` entry per
                layer instead.
            intermediate_tensors: Accepted for interface compatibility;
                not read by this method.
            num_steps: Must be 1.

        Returns:
            An empty list on non-driver workers; otherwise a one-element
            list holding the result of ``self.model.pooler``.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # NOTE(review): decode_metadata is assumed to be optional like
        # prefill_metadata - confirm. The explicit None check makes that
        # case fall back to eager execution instead of raising
        # AttributeError on `.use_cuda_graph`.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            assert model_input.input_tokens is not None
            # Graph runners are looked up by the batch dimension of the
            # input token tensor.
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        # The caller-supplied kv_caches are deliberately discarded here and
        # replaced by one None placeholder per layer.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict``
        and passes this runner's attention backend along.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the model input tensors and attach pooling metadata.

        Args:
            seq_group_metadata_list: Sequence groups to run; must not be
                ``None``.
            virtual_engine: Accepted for interface compatibility; not read
                by this method.
            finished_requests_ids: Forwarded to
                ``_prepare_model_input_tensors``.

        Returns:
            A copy of the prepared model input with ``pooling_metadata``
            filled in.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # The model input is a frozen dataclass, so produce an updated copy.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: Sequence groups in the batch.
            prompt_lens: Passed through unchanged as ``prompt_lens``.

        Returns:
            A PoolingMetadata holding, per group, its sequence ids paired
            with its pooling params, plus the merged ``seq_data`` mapping
            of all groups.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        # One pass collects both the per-group entry and the merged mapping.
        for seq_group_metadata in seq_group_metadata_list:
            group_seq_data = seq_group_metadata.seq_data
            seq_groups.append((list(group_seq_data.keys()),
                               seq_group_metadata.pooling_params))
            seq_data.update(group_seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )