"csrc/vscode:/vscode.git/clone" did not exist on "c9a9db0e023a03f7e2c40e43be302fb07512946a"
cpu_pooling_model_runner.py 5.14 KB
Newer Older
1
2
3
4
5
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

youkaichao's avatar
youkaichao committed
6
from vllm.forward_context import set_forward_context
7
8
9
10
11
12
13
14
15
16
17
18
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
                                          ModelInputForCPUBuilder)


@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
    """
    Used by the CPUPoolingModelRunner.

    Extends ModelInputForCPU with one extra, optional field that carries
    the pooling metadata for the batch.
    """
    # Pooling metadata for the batch; None until one is attached.
    pooling_metadata: Optional["PoolingMetadata"] = None


24
class CPUPoolingModelRunner(
        CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
    """Model runner that executes the model and pools its hidden states.

    The forward pass is run with per-layer placeholder KV caches, and the
    resulting hidden states are handed to ``self.model.pooler`` on the
    driver worker only.
    """
    _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
        ModelInputForCPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared inputs, including the pooling metadata.
            kv_caches: Ignored; replaced below by one empty placeholder
                tensor per layer.
            intermediate_tensors: Forwarded to the model unchanged.
            num_steps: Must be 1.

        Returns:
            An empty list on non-driver workers; otherwise a one-element
            list holding the pooler output.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than by specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        # `token_type_ids` is only forwarded when the input carries it.
        cross_enc_kwargs = {}
        if model_input.token_type_ids is not None:
            cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids

        # Key order matters: later entries override earlier ones, so the
        # explicit `intermediate_tensors` entry always wins.
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            model_input.attn_metadata,
            **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            **cross_enc_kwargs,
            "intermediate_tensors":
            intermediate_tensors,
        }

        with set_forward_context(model_input.attn_metadata, self.vllm_config):
            hidden_states = self.model(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForCPUWithPoolingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to ``from_broadcasted_tensor_dict`` on the model input
        class, passing this runner's attention backend.
        """
        return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithPoolingMetadata:
        """Build the model input tensors and attach pooling metadata.

        Args:
            seq_group_metadata_list: Sequence groups to run; must not be
                None.
            virtual_engine: Stored on the returned model input.
            finished_requests_ids: Forwarded to the tensor preparation
                step.

        Returns:
            A copy of the prepared model input with ``virtual_engine`` and
            ``pooling_metadata`` set.
        """
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        # The model input is a frozen dataclass, so produce an updated copy.
        return dataclasses.replace(model_input,
                                   virtual_engine=virtual_engine,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Collects, in a single pass, the ``(seq_ids, pooling_params)`` pair
        of every group and a merged ``seq_id -> SequenceData`` mapping.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            group_seq_data = seq_group_metadata.seq_data
            seq_groups.append(
                (list(group_seq_data), seq_group_metadata.pooling_params))
            seq_data.update(group_seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )