# embedding_model_runner.py
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
3
4
5
6
7
8
9
10
11
12

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
13
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
14
15
16
17

logger = init_logger(__name__)


18
19
20
21
22
23
24
25
26
27
28
29
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Model input consumed by ``EmbeddingModelRunner``.

    Extends ``ModelInputForGPU`` with one extra, optional field carrying the
    pooling metadata for the batch. Instances are frozen, so the field is
    attached by building a new instance (``dataclasses.replace``).
    """

    # Pooling information for the batch; ``None`` unless explicitly supplied.
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner for embedding (pooling) models.

    Builds ``ModelInputForGPUWithPoolingMetadata`` inputs that carry a
    ``PoolingMetadata``. ``execute_model`` returns the result of
    ``self.model.pooler`` rather than sampled tokens.
    """
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        # Pure pass-through: all configuration is stored by the base runner.
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         vision_language_config=vision_language_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared inputs (tokens, positions, attention
                metadata, optional LoRA / multi-modal data) plus the
                pooling metadata built by ``prepare_model_input``.
            kv_caches: Accepted for interface compatibility but not used;
                the model is called with one ``None`` per layer instead.

        Returns:
            The pooler output on the driver worker; ``None`` on every other
            worker.
        """
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # NOTE(review): decode_meta is checked for None so that a batch with
        # neither prefill nor decode metadata falls back to the eager model
        # instead of raising AttributeError.
        use_cuda_graph = (prefill_meta is None and decode_meta is not None
                          and decode_meta.use_cuda_graph)
        if use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        # The caller-supplied kv_caches are deliberately ignored: the model
        # gets one None per layer (presumably because embedding runs do not
        # read or write a KV cache — confirm against the attention backend).
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        layer_kv_caches: List[Optional[torch.Tensor]] = [None] * num_layers

        execute_model_kwargs: Dict[str, Any] = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": layer_kv_caches,
            "attn_metadata": model_input.attn_metadata,
        }
        if self.vision_language_config:
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            execute_model_kwargs.update({"image_input": multi_modal_kwargs})
        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Pooling always goes through self.model, even when a CUDA-graph
        # runner produced the hidden states above.
        return self.model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=model_input.pooling_metadata)

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to ``ModelInputForGPUWithPoolingMetadata
        .from_broadcasted_tensor_dict`` using this runner's attention backend.
        """
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the model input tensors and attach pooling metadata.

        Args:
            seq_group_metadata_list: Scheduled sequence groups; must not be
                ``None`` (asserted).

        Returns:
            A copy of the base model input with ``pooling_metadata`` set.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # The input dataclass is frozen, so attach the field via replace().
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: One entry per scheduled sequence group.
            prompt_lens: Lengths forwarded unchanged into PoolingMetadata
                (``prepare_model_input`` passes ``model_input.seq_lens``).

        Returns:
            PoolingMetadata holding, per group, the group's sequence ids and
            pooling params, plus the merged ``seq_data`` of all groups.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_groups.append((list(seq_group_metadata.seq_data.keys()),
                               seq_group_metadata.pooling_params))
            # On a duplicate sequence id, the later group's entry wins.
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )