# model_runner_base.py
import dataclasses
2
import pickle
3
from abc import ABC, abstractmethod
4
5
from datetime import datetime
from functools import wraps
6
7
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
                    Optional, Type, TypeVar)
8
9

import torch
10
import torch.nn as nn
11
from torch import is_tensor
12

13
from vllm.config import VllmConfig
14
from vllm.logger import init_logger
15
16
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
17
18
19
20
21
22

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

# Module-level logger (used by dump_input_when_exception).
logger = init_logger(__name__)

# Type variable for concrete BroadcastableModelInput subclasses. The bound is
# a string forward reference because the class is declared later in this
# module.
T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
50
51
        if field.name in tensor_dict:
            valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a minimal SamplingMetadata from broadcast fields.

    Pops ``"selected_token_indices"`` from ``tensor_dict``. If it was present
    and not None, a SamplingMetadata carrying only those indices (no sequence
    groups, no categorized indices, zero prompts) is stored under
    ``"sampling_metadata"``. The dict is mutated in place and returned.
    """
    from vllm.model_executor import SamplingMetadata

    indices = tensor_dict.pop("selected_token_indices", None)
    if indices is None:
        return tensor_dict

    # An empty SamplingMetadata to signal that the worker should skip
    # sampling.
    tensor_dict["sampling_metadata"] = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=indices,
        categorized_sample_indices=None,
        num_prompts=0,
    )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)


91
92
93
def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
94
    """
95
96
97
98
99
100
101
102
103
104
105
    Helper method to initialize a frozen ModelInput based on broadcastable
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict
def _collect_dumpable_inputs(
        args: tuple,
        kwargs: Dict[str, Any],
        exclude_args: Optional[List[int]],
        exclude_kwargs: Optional[List[str]]) -> Dict[str, Any]:
    """Gather a call's non-excluded arguments into a dict for pickling.

    Keyword arguments keep their names; positional argument ``i`` is stored
    as ``"arg_{i}"``. A ``kv_caches`` keyword argument is reduced to
    ``(dtype, shape)`` information.
    """
    skipped_kwargs = set(exclude_kwargs or ())
    skipped_args = set(exclude_args or ())
    dumped_inputs = {
        k: v
        for k, v in kwargs.items() if k not in skipped_kwargs
    }
    for i, arg in enumerate(args):
        if i not in skipped_args:
            dumped_inputs[f"arg_{i}"] = arg

    # Only persist dtype and shape for kvcache tensors (they can be way too
    # big otherwise). A bare tensor is checked first: calling bool() on a
    # multi-element tensor raises, which would mask the original error.
    # NOTE: only a *keyword* `kv_caches` is summarized; one passed
    # positionally is pickled in full.
    kv_caches = dumped_inputs.get("kv_caches")
    if is_tensor(kv_caches):
        dumped_inputs["kv_caches"] = (kv_caches.dtype, kv_caches.shape)
    elif isinstance(kv_caches, Iterable):
        dumped_inputs["kv_caches"] = [(t.dtype, t.shape) for t in kv_caches
                                      if is_tensor(t)]
    return dumped_inputs


def _write_dumped_inputs(filename: str, dumped_inputs: Dict[str, Any]) -> bool:
    """Pickle ``dumped_inputs`` into ``filename``.

    Returns True on success. On any failure (opening the file or pickling),
    logs a warning, removes the possibly truncated file, and returns False.
    """
    import contextlib
    import os

    try:
        with open(filename, "wb") as filep:
            pickle.dump(dumped_inputs, filep)
    except Exception as dump_err:
        logger.warning("Failed to pickle inputs of failed execution: %s",
                       str(dump_err))
        # Don't leave a partially written, unloadable pickle behind.
        with contextlib.suppress(OSError):
            os.remove(filename)
        return False
    return True


def _exception_with_message(err: Exception, message: str) -> Exception:
    """Return a new exception of ``type(err)`` built from ``message``.

    Returns ``err`` itself when that type cannot be constructed from a single
    message argument (e.g. its ``__init__`` needs extra positional
    arguments).
    """
    try:
        return type(err)(message)
    except Exception:
        return err


def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):
    """Decorator factory that pickles a call's inputs when the call raises.

    If the wrapped function raises an ``Exception``, its arguments (minus the
    positional indices in ``exclude_args`` and keyword names in
    ``exclude_kwargs``) are pickled to
    ``/tmp/err_<func name>_input_<YYYYmmdd-HHMMSS>.pkl``. An exception of the
    same type is then raised, chained from the original, whose message names
    the dump file. If the inputs cannot be written, the message omits the
    file. If the exception type cannot be rebuilt from a message, the
    original exception is re-raised unchanged.
    """

    def _inner(func):

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
                logger.info("Writing input of failed execution to %s...",
                            filename)
                dumped_inputs = _collect_dumpable_inputs(
                    args, kwargs, exclude_args, exclude_kwargs)

                if _write_dumped_inputs(filename, dumped_inputs):
                    logger.info(
                        "Completed writing input of failed execution to %s.",
                        filename)
                    message = (f"Error in model execution (input dumped to "
                               f"{filename}): {str(err)}")
                else:
                    message = f"Error in model execution: {str(err)}"

                new_err = _exception_with_message(err, message)
                if new_err is err:
                    # The type can't be rebuilt from a message; re-raise the
                    # original rather than masking it with a TypeError.
                    raise
                raise new_err from err

        return _wrapper

    return _inner


class BroadcastableModelInput(ABC):
    """
    Interface for model inputs that can be flattened into a dict for
    broadcasting between workers (``as_broadcastable_tensor_dict``) and
    rebuilt from such a dict (``from_broadcasted_tensor_dict``).
    """

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom serialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        BroadcastableModelInput.

        Args:
            tensor_dict: Dict produced by ``as_broadcastable_tensor_dict`` on
                the sending side.
            attn_backend: Optional backend, presumably needed by subclasses
                that rebuild attention metadata (confirm per subclass).
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Worker-local input to a model runner; may hold device-specific data.

    The LLM engine produces a global ExecuteModelRequest, and each worker
    backend may convert it into worker-local ModelRunnerInputBase objects in
    its own way.

    Model runners that support multi-GPU execution should define a subclass,
    add the fields they require, and specify how to serialize/deserialize a
    ModelInput for broadcast between workers.
    """


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects."""

    @abstractmethod
    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Prepare the builder before sequence groups are added.

        ``finished_requests_ids`` presumably lists requests that have
        finished since the previous step; confirm against implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the input being built."""
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        # Keep the full config plus direct handles to each sub-config.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

    # Map of request_id -> generator used for seeded random sampling
    # NOTE(review): this is a class attribute, so one dict object is shared by
    # every instance (of this class and all subclasses) that does not rebind
    # `self.generators`; get_generators() mutates it in place. Confirm that
    # sharing across runners is intended.
    generators: Dict[str, torch.Generator] = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model(self) -> nn.Module:
        """Return the underlying model module."""
        raise NotImplementedError

    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.

        Not implemented here; subclasses override it (it is intentionally not
        marked abstract).
        """
        raise NotImplementedError

    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
        """
        Return dict of per-request generators used for random sampling.

        Entries for ``finished_request_ids`` are removed first; ids without a
        generator are ignored.
        """
        # Clean up generators from completed requests
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)

        return self.generators


class ModelRunnerWrapperBase:
    """
    The whole point of this class is to lazily initialize the model_runner.

    Any attribute not found on the wrapper itself is looked up on the wrapped
    ``model_runner``.
    """

    def __init__(
        self,
        model_runner: "ModelRunnerBase",
    ) -> None:
        self.model_runner: "ModelRunnerBase" = model_runner

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup fails. If `model_runner`
        # itself is missing (e.g. an instance created without __init__, as
        # copy/pickle do), forwarding would recurse until RecursionError;
        # raise AttributeError instead so hasattr()/copy behave normally.
        if attr == "model_runner":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                "'model_runner'")
        return getattr(self.model_runner, attr)