model_runner_base.py 10 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import dataclasses
from abc import ABC, abstractmethod
6
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
7
                    TypeVar)
8
9

import torch
10
import torch.nn as nn
11

12
from vllm.config import VllmConfig
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.sampler import SamplerOutput
15
from vllm.model_executor.models.interfaces import supports_transcription
16
from vllm.model_executor.models.interfaces_base import is_text_generation_model
17
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
18
from vllm.tasks import GenerationTask, SupportedTask
19
20
21
22
23
24

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

25
26
# Type variable for concrete BroadcastableModelInput subclasses; the generic
# runner/builder base classes in this module are parameterized by it.
T = TypeVar('T', bound="BroadcastableModelInput")

# Logger for this module.
logger = init_logger(__name__)


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Merge the broadcastable fields of ``attn_metadata`` into ``tensor_dict``.

    The fields come from ``attn_metadata.asdict_zerocopy()``. When
    ``attn_metadata`` is None the dict is left untouched. Mutates
    ``tensor_dict`` in place and returns None.
    """
    if attn_metadata is None:
        return
    tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Rebuild AttentionMetadata for ``attn_backend`` from broadcast fields.

    Every dataclass field of ``attn_backend.get_metadata_cls()`` that is
    present in ``tensor_dict`` is removed from it and passed to
    ``attn_backend.make_metadata``. The one exception is
    ``input_positions``: it is copied, not removed, so it stays in
    ``tensor_dict`` (presumably because the caller's model input also needs
    it). The resulting metadata is stored under ``"attn_metadata"``.

    Returns the same, mutated, ``tensor_dict``.
    """
    metadata_fields = dataclasses.fields(attn_backend.get_metadata_cls())
    attn_kwargs: Dict[str, Any] = {}
    for name in (f.name for f in metadata_fields):
        if name not in tensor_dict:
            continue
        # Keep input_positions in tensor_dict; consume everything else.
        take = tensor_dict.get if name == "input_positions" else tensor_dict.pop
        attn_kwargs[name] = take(name)

    tensor_dict["attn_metadata"] = attn_backend.make_metadata(**attn_kwargs)
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a minimal SamplingMetadata from broadcast fields.

    ``"selected_token_indices"`` is always popped from ``tensor_dict``. If it
    was present (and not None), a SamplingMetadata holding only those indices
    is stored under ``"sampling_metadata"``. Returns the mutated
    ``tensor_dict``.
    """
    from vllm.model_executor import SamplingMetadata

    token_indices = tensor_dict.pop("selected_token_indices", None)
    if token_indices is None:
        return tensor_dict

    # An otherwise-empty SamplingMetadata (no seq_groups, no categorized
    # indices, zero prompts) signals that the worker should skip sampling.
    tensor_dict["sampling_metadata"] = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=token_indices,
        categorized_sample_indices=None,
        num_prompts=0,
    )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Record the broadcastable part of ``sampling_metadata`` in ``tensor_dict``.

    Only ``selected_token_indices`` is broadcast; it is written under the
    key of the same name. Nothing is written when ``sampling_metadata`` is
    None. Mutates ``tensor_dict`` in place and returns None.
    """
    if sampling_metadata is None:
        return
    token_indices = sampling_metadata.selected_token_indices
    tensor_dict["selected_token_indices"] = token_indices


def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize a frozen ModelInput based on broadcastable
    fields.

    Every dataclass field of ``frozen_model_input_cls`` is popped from
    ``tensor_dict``. Fields whose value is missing or None are not passed to
    the constructor. The new instance is stored under
    ``"frozen_model_input"`` and the mutated ``tensor_dict`` is returned.
    """
    field_names = [f.name for f in dataclasses.fields(frozen_model_input_cls)]
    # Pop every field, even None-valued ones, so none remain in tensor_dict.
    popped = {name: tensor_dict.pop(name, None) for name in field_names}
    ctor_kwargs = {k: v for k, v in popped.items() if v is not None}
    tensor_dict["frozen_model_input"] = frozen_model_input_cls(**ctor_kwargs)
    return tensor_dict


class BroadcastableModelInput(ABC):
    """Interface for model inputs that can be flattened into a dict of
    broadcastable values and rebuilt from such a dict on another worker."""

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Return the fields of this input that should be broadcast, keyed by
        name. Subclasses override this to control how each field is
        serialized.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop this class's fields out of ``tensor_dict`` and return a new
        instance built from them. ``attn_backend``, when provided, is
        presumably needed to rebuild attention metadata (confirm per
        subclass).
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Worker-local inputs for a model runner.

    Instances may hold device-specific data. Each worker backend decides how
    to convert the engine's global ExecuteModelRequest into these objects.

    Model runners that support multi-GPU execution should subclass this,
    declare the fields they need, and implement the serialize/deserialize
    hooks inherited from BroadcastableModelInput so the input can be
    broadcast between workers.
    """


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """Abstract builder that produces model-runner input objects of type T.

    The abstract API suggests this call order (confirm against concrete
    builders): ``prepare()``, then ``add_seq_group()`` per sequence group,
    then ``build()``.
    """

    @abstractmethod
    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Get the builder ready for a new batch. ``finished_requests_ids``
        presumably lists requests completed since the previous batch."""
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the batch being built."""
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    # Map of request_id -> generator used for seeded random sampling.
    # Only declared here: the dict is created per instance in __init__. A
    # class-level ``= {}`` default would be a single dict shared (and mutated
    # through get_generators) by every runner instance in the process.
    generators: Dict[str, torch.Generator]

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Store ``vllm_config`` and its sub-configs as attributes and create
        this runner's own request_id -> generator map."""
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.generators = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model(self) -> nn.Module:
        """Return the model (an ``nn.Module``) this runner executes."""
        raise NotImplementedError

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """
        Return the generation tasks supported by the model from get_model().

        Includes ``"generate"`` if the model is a text-generation model and
        ``"transcription"`` if it supports transcription. A model flagged as
        transcription-only yields exactly ``["transcription"]``.
        """
        model = self.get_model()
        supported_tasks = list[GenerationTask]()

        if is_text_generation_model(model):
            supported_tasks.append("generate")

        if supports_transcription(model):
            if model.supports_transcription_only:
                return ["transcription"]

            supported_tasks.append("transcription")

        return supported_tasks

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """
        Return all tasks supported by this runner. Only runners whose
        ``model_config.runner_type`` is ``"generate"`` report any tasks here;
        for other runner types the result is empty.
        """
        tasks = list[SupportedTask]()

        if self.model_config.runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())

        return tuple(tasks)

    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError

    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
        """
        Return dict of per-request generators used for random sampling.

        Entries for ``finished_request_ids`` are removed first (missing ids
        are ignored). The returned dict is this runner's live mapping, not a
        copy.
        """

        # Clean up generators from completed requests
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)

        return self.generators


class ModelRunnerWrapperBase:
    """
    Wraps a model runner and forwards every attribute lookup the wrapper
    itself cannot resolve to the wrapped ``model_runner``. Subclasses can
    override individual methods and inherit everything else from the
    wrapped runner.
    """

    def __init__(
        self,
        model_runner: "ModelRunnerBase",
    ) -> None:
        self.model_runner: "ModelRunnerBase" = model_runner

    def __getattr__(self, attr: str) -> Any:
        """Forward lookups of attributes missing on the wrapper to the
        wrapped model runner."""
        # __getattr__ only runs after normal lookup fails. If ``model_runner``
        # itself is missing (instance built via __new__, e.g. by copy.copy or
        # unpickling, before __init__ ran), reading self.model_runner below
        # would re-enter __getattr__ until RecursionError. Raise the
        # AttributeError that hasattr()/getattr(default) expect instead.
        if attr == "model_runner":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                "'model_runner'")
        return getattr(self.model_runner, attr)


class InputProcessingError(Exception):
    """This exception is raised when an error occurs preparing the inputs for
    a single sequence group.
    This allows the engine to gracefully handle errors with a single sequence
    group without having to fail the entire batch.

    Attributes:
        request_id: id of the offending sequence group.
        message: description of the failure (also the sole element of
            ``args``).
    """

    def __init__(self, request_id, message):
        """request_id is the id of the offending sequence group"""
        self.request_id = request_id
        self.message = message
        super().__init__(self.message)

    def __reduce__(self):
        # BaseException's default reduce rebuilds the object as
        # ``cls(*self.args)``. Only ``message`` is in ``args``, so that call
        # would lack ``request_id`` and raise TypeError, breaking pickling
        # (e.g. across process boundaries) and copy.copy(). Rebuild from both
        # constructor arguments instead.
        return (type(self), (self.request_id, self.message))

    def __str__(self):
        return ("Failed to prepare inputs for sequence group with request id: "
                f"{self.request_id}, Error: {self.message}")