# embedding_model_runner.py
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)

# Module-level logger, named after this module.
logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Model input used by the EmbeddingModelRunner.

    Extends ModelInputForGPU with the pooling metadata that
    EmbeddingModelRunner.execute_model hands to the model's pooler.
    Instances are frozen (frozen=True).
    """
    # Defaults to None; EmbeddingModelRunner.prepare_model_input fills it in
    # with dataclasses.replace after the input tensors have been prepared.
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner for embedding models.

    Runs the model's forward pass and, on the driver worker, applies the
    model's pooler to the resulting hidden states. Model inputs are
    ModelInputForGPUWithPoolingMetadata instances, i.e. the regular GPU model
    input plus the PoolingMetadata built by ``_prepare_pooling``.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
    ):
        """Forward every configuration object unchanged to the base runner."""
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         prompt_adapter_config=prompt_adapter_config,
                         multimodal_config=multimodal_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        """Run one forward pass and pool the hidden states.

        Args:
            model_input: Prepared input, including attention and pooling
                metadata.
            kv_caches: Accepted for interface compatibility only. The value
                passed in is ignored; the model is always called with one
                ``None`` entry per layer.
            intermediate_tensors: Accepted for interface compatibility; not
                used by this runner.
            num_steps: Must be 1.

        Returns:
            An empty list on non-driver workers; otherwise a one-element list
            holding the output of ``self.model.pooler``.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        # The explicit None check keeps the attribute access below from
        # failing when decode_metadata is absent; in that case the model is
        # simply run without a captured graph.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        # The caller-supplied kv_caches are deliberately replaced: the model
        # receives a list with one None per layer.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **(model_input.multi_modal_kwargs or {}),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to
        ``ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict``,
        passing this runner's attention backend.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the model input, including pooling metadata, for a batch.

        Args:
            seq_group_metadata_list: Sequence groups to run; must not be None.
            virtual_engine: Accepted for interface compatibility; not used in
                this method.
            finished_requests_ids: Forwarded to
                ``_prepare_model_input_tensors``.

        Returns:
            The prepared model input with ``pooling_metadata`` filled in.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # Attach the metadata by building a copy of the prepared input.
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: Sequence groups in the batch.
            prompt_lens: Passed through unchanged as the metadata's
                ``prompt_lens``.

        Returns:
            A PoolingMetadata holding, per sequence group, its sequence ids
            paired with its pooling params, plus the merged ``seq_data`` of
            all groups.
        """
        # One (sequence ids, pooling params) pair per sequence group.
        seq_groups: List[Tuple[List[int], PoolingParams]] = [
            (list(seq_group_metadata.seq_data.keys()),
             seq_group_metadata.pooling_params)
            for seq_group_metadata in seq_group_metadata_list
        ]

        # Merge the per-group sequence data into a single mapping.
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )