# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
                    TypeVar)

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

# Module-level logger for this file.
logger = init_logger(__name__)

# Type variable bound to BroadcastableModelInput; parameterizes the generic
# builder/runner base classes in this module.
T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Merge the broadcastable fields of ``attn_metadata`` into ``tensor_dict``.

    The fields are obtained from ``attn_metadata.asdict_zerocopy()``. When
    ``attn_metadata`` is None, ``tensor_dict`` is left untouched.
    """
    if attn_metadata is None:
        return
    tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.

    Every key of ``tensor_dict`` that names a dataclass field of
    ``attn_backend.get_metadata_cls()`` is passed to
    ``attn_backend.make_metadata``. Those keys are popped from
    ``tensor_dict``, except ``input_positions``, which is read and left in
    place. The resulting metadata object is stored under
    ``tensor_dict["attn_metadata"]``.

    Args:
        attn_backend: Backend providing the metadata class and factory.
        tensor_dict: Broadcast dict; mutated in place.

    Returns:
        The same ``tensor_dict`` object.
    """
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        if field.name not in tensor_dict:
            continue
        if field.name == "input_positions":
            # Read without popping so the entry remains in tensor_dict.
            # NOTE(review): presumably another consumer of tensor_dict still
            # needs input_positions — confirm.
            valid_attn_kwargs[field.name] = tensor_dict[field.name]
        else:
            valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a minimal SamplingMetadata from broadcast tensor_dict fields.

    Pops ``selected_token_indices`` from ``tensor_dict``. If it was present
    and not None, a SamplingMetadata carrying only those indices (no sequence
    groups, no categorized sample indices, zero prompts) is stored under
    ``tensor_dict["sampling_metadata"]``. Otherwise nothing is added.

    Returns:
        The same ``tensor_dict`` object, mutated in place.
    """
    from vllm.model_executor import SamplingMetadata

    token_indices = tensor_dict.pop("selected_token_indices", None)
    if token_indices is None:
        return tensor_dict

    # Stripped-down SamplingMetadata. The original author's note describes it
    # as signalling that this worker should skip sampling.
    tensor_dict["sampling_metadata"] = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=token_indices,
        categorized_sample_indices=None,
        num_prompts=0,
    )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Copy the broadcastable part of ``sampling_metadata`` into ``tensor_dict``.

    Only ``selected_token_indices`` is transferred, under the key of the same
    name. Passing None leaves ``tensor_dict`` unchanged.
    """
    if sampling_metadata is None:
        return
    indices = sampling_metadata.selected_token_indices
    tensor_dict["selected_token_indices"] = indices


95
96
97
def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize a frozen ModelInput based on broadcastable
    fields.

    For every dataclass field of ``frozen_model_input_cls``, the matching key
    is popped from ``tensor_dict``. Keys whose value is None are dropped, so
    the dataclass falls back to that field's default. The constructed
    instance is stored under ``tensor_dict["frozen_model_input"]``.

    Args:
        frozen_model_input_cls: Dataclass type to instantiate.
        tensor_dict: Broadcast dict; mutated in place.

    Returns:
        The same ``tensor_dict`` object.
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        # Pop even when the value is None so every field key is removed
        # from tensor_dict.
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict
110

111
112
113
114

class BroadcastableModelInput(ABC):
    """Interface for model inputs that can be flattened into a dict for
    broadcast between workers and rebuilt from such a dict.
    """

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract the broadcastable fields of this input as a dict. Subclasses
        handle here any fields that need custom serialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        BroadcastableModelInput.

        Args:
            tensor_dict: Dict received from the broadcast.
            attn_backend: Optional attention backend, for subclasses that
                need to rebuild attention metadata.
        """
        raise NotImplementedError


136
137
138
139
140
141
142
143
144
145
146
147
148
149
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Worker-local inputs consumed by a model runner.

    Instances may hold device-specific data. Each worker backend decides how
    to convert the engine-wide ExecuteModelRequest into its own
    ModelRunnerInputBase objects.

    A model runner that supports multi-GPU execution should subclass this,
    declare the fields it needs, and implement how the input is
    serialized/deserialized when broadcast between workers.
    """


150
151
152
153
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.

    Subclasses must implement ``prepare``, ``add_seq_group`` and ``build``.
    NOTE(review): the intended call order appears to be ``prepare`` ->
    ``add_seq_group`` (once per group) -> ``build``; confirm against the
    concrete builders.
    """

    @abstractmethod
    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Prepare the builder for a new set of sequence groups.

        Args:
            finished_requests_ids: IDs of requests that have finished, if
                any. How they are used is up to the subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the input being built.

        NOTE(review): ``seq_group_metadata`` is presumably a
        ``SequenceGroupMetadata``; confirm against concrete builders.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


170
171
172
173
174
175
176
177
178
179
class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    # Map of request_id -> generator used for seeded random sampling.
    # NOTE(review): this is a mutable *class* attribute, so every
    # ModelRunnerBase instance in the process shares the same dict unless a
    # subclass rebinds ``self.generators``. Confirm whether that sharing is
    # intended before relying on it, or make it per-instance.
    generators: Dict[str, torch.Generator] = {}

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store ``vllm_config`` and expose each of its sub-configs as an
        attribute of the same name."""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model(self) -> nn.Module:
        """Return the underlying ``nn.Module``."""
        raise NotImplementedError

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """Return the pooling tasks the model supports.

        Returns an empty list when ``is_pooling_model`` reports that the model
        is not a pooling model; otherwise returns the tasks reported by
        ``model.pooler.get_supported_tasks()``.
        """
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        return list(model.pooler.get_supported_tasks())

    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.

        The base implementation raises NotImplementedError; subclasses
        provide the actual execution.
        """
        raise NotImplementedError

    def get_generators(
        self,
        finished_request_ids: Optional[List[str]] = None,
    ) -> Dict[str, torch.Generator]:
        """
        Return dict of per-request generators used for random sampling.

        Generators belonging to ``finished_request_ids`` are removed first.
        The returned object is the live ``self.generators`` mapping, not a
        copy.
        """
        # Clean up generators from completed requests.
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)

        return self.generators
258
259
260
261
262
263
264
265
266


class ModelRunnerWrapperBase:
    """
    Wrapper around a model runner that forwards every attribute it does not
    define itself to the wrapped ``model_runner``.

    Attributes defined on the wrapper (or a subclass) take precedence; all
    other lookups are delegated through ``__getattr__``.
    NOTE(review): the original docstring says the purpose is to lazily
    initialize the model runner; the code here only delegates attribute
    access — confirm the intended role.
    """

    def __init__(
        self,
        model_runner: "ModelRunnerBase",
    ) -> None:
        self.model_runner: ModelRunnerBase = model_runner

    def __getattr__(self, attr):
        """Delegate lookups not found on the wrapper to ``model_runner``."""
        # __getattr__ only runs after normal lookup fails. If `model_runner`
        # itself is missing (e.g. on an instance created by copy/pickle
        # before its __dict__ is restored), `self.model_runner` would call
        # __getattr__("model_runner") again and recurse until RecursionError.
        # Raise AttributeError instead so hasattr()/copy/pickle work.
        if attr == "model_runner":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                "'model_runner'")
        return getattr(self.model_runner, attr)
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290


class InputProcessingError(Exception):
    """Raised when preparing the inputs of a single sequence group fails.

    Carrying the offending request id is meant to let the engine handle the
    failure for that one sequence group instead of failing the whole batch.
    """

    def __init__(self, request_id, message):
        """Record the offending sequence group's ``request_id`` and the error
        ``message``; ``message`` is also the exception's sole argument."""
        super().__init__(message)
        self.request_id = request_id
        self.message = message

    def __str__(self):
        prefix = "Failed to prepare inputs for sequence group with request id: "
        return f"{prefix}{self.request_id}, Error: {self.message}"