# model_runner_base.py
1
import dataclasses
2
import pickle
3
from abc import ABC, abstractmethod
4
5
from datetime import datetime
from functools import wraps
6
7
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
                    Optional, Type, TypeVar)
8
9

import torch
10
from torch import is_tensor
11

12
from vllm.logger import init_logger
13
from vllm.model_executor.layers.sampler import SamplerOutput
14
from vllm.platforms import current_platform
15
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
16
17
18
19
20
21

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

22
23
logger = init_logger(__name__)

# Type variable bound to BroadcastableModelInput; parameterizes the generic
# builder/runner base classes and the from_broadcasted_tensor_dict factory.
T = TypeVar("T", bound="BroadcastableModelInput")
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
49
50
        if field.name in tensor_dict:
            valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a minimal SamplingMetadata from broadcasted fields.

    ``selected_token_indices`` is popped from ``tensor_dict``. If it was
    present and not None, a SamplingMetadata carrying only those indices
    (``seq_groups`` and ``categorized_sample_indices`` set to None,
    ``num_prompts`` set to 0) is stored under
    ``tensor_dict["sampling_metadata"]``; otherwise no key is added.

    Returns the same, mutated ``tensor_dict``.
    """
    from vllm.model_executor import SamplingMetadata

    indices = tensor_dict.pop("selected_token_indices", None)
    if indices is None:
        return tensor_dict

    # An empty SamplingMetadata to signal that the worker should skip
    # sampling: only the token indices survive the broadcast.
    tensor_dict["sampling_metadata"] = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=indices,
        categorized_sample_indices=None,
        num_prompts=0,
    )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)


90
91
92
def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
93
    """
94
95
96
97
98
99
100
101
102
103
104
    Helper method to initialize a frozen ModelInput based on broadcastable
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict
105

106

107
108
109
110
111
112
113
114
115
116
117
118
def _rewrap_exception(err: Exception, message: str) -> Exception:
    """
    Build a new exception of the same type as ``err`` carrying ``message``.

    Some exception types cannot be constructed from a single string (for
    example ``UnicodeDecodeError`` requires five arguments). In that case
    ``err`` itself is returned, so the caller re-raises the original error
    instead of masking it with a ``TypeError`` from the constructor.
    """
    try:
        return type(err)(message)
    except Exception:
        return err


def _dump_failed_inputs(func_name: str, args: tuple, kwargs: Dict[str, Any],
                        exclude_args: Optional[List[int]],
                        exclude_kwargs: Optional[List[str]]) -> Optional[str]:
    """
    Best-effort pickling of a failed call's inputs to a file under /tmp.

    Keyword arguments are stored under their own names and positional
    arguments as ``arg_<index>``, skipping those listed in ``exclude_kwargs``
    / ``exclude_args``. A ``kv_caches`` entry is reduced to
    ``(dtype, shape)`` pairs.

    Returns:
        The path of the written file, or None if it could not be written.
        Failures are logged and never raised, so they cannot mask the error
        that triggered the dump.
    """
    import contextlib
    import os

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"/tmp/err_{func_name}_input_{timestamp}.pkl"
    logger.info("Writing input of failed execution to %s...", filename)

    skipped_kwargs = set(exclude_kwargs or ())
    skipped_args = set(exclude_args or ())
    dumped_inputs: Dict[str, Any] = {
        k: v
        for k, v in kwargs.items() if k not in skipped_kwargs
    }
    for i, arg in enumerate(args):
        if i not in skipped_args:
            dumped_inputs[f"arg_{i}"] = arg

    # Only persist dtype and shape for kvcache tensors (they can be way too
    # big otherwise). This only catches kv_caches passed by keyword; a
    # positional kv_caches is dumped as-is under arg_<index>.
    kv_caches = dumped_inputs.get("kv_caches")
    if is_tensor(kv_caches):
        # Truth-testing a multi-element tensor raises, so summarize directly.
        dumped_inputs["kv_caches"] = [(kv_caches.dtype, kv_caches.shape)]
    elif kv_caches and isinstance(kv_caches, Iterable):
        dumped_inputs["kv_caches"] = [(t.dtype, t.shape) for t in kv_caches
                                      if is_tensor(t)]

    try:
        filep = open(filename, "wb")
    except OSError as open_err:
        logger.warning(
            "Failed to open %s for dumping inputs of failed execution: %s",
            filename, str(open_err))
        return None

    try:
        with filep:
            pickle.dump(dumped_inputs, filep)
    except Exception as pickle_err:
        logger.warning("Failed to pickle inputs of failed execution: %s",
                       str(pickle_err))
        # Don't leave a truncated dump file behind; we created/truncated it.
        with contextlib.suppress(OSError):
            os.remove(filename)
        return None

    logger.info("Completed writing input of failed execution to %s.",
                filename)
    return filename


def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):
    """
    Decorator factory: on any exception, dump the call's inputs and re-raise.

    When the wrapped function raises, its inputs are pickled to
    ``/tmp/err_<function name>_input_<timestamp>.pkl``. The error is then
    re-raised as a new exception of the same type whose message mentions
    the dump file, chained to the original with ``from``. If the dump could
    not be written, the message omits the filename.

    Args:
        exclude_args: Indices of positional arguments to leave out of the
            dump (e.g. ``0`` to skip ``self`` on a method).
        exclude_kwargs: Names of keyword arguments to leave out of the dump.

    Raises:
        The wrapped function's exception type, as described above. If that
        type cannot be constructed from a single message string, the
        original exception object is re-raised unchanged.
    """

    def _inner(func):
        func_name = getattr(func, "__name__", type(func).__name__)

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                filename = _dump_failed_inputs(func_name, args, kwargs,
                                               exclude_args, exclude_kwargs)
                if filename is None:
                    message = f"Error in model execution: {str(err)}"
                else:
                    message = ("Error in model execution (input dumped to "
                               f"{filename}): {str(err)}")
                wrapped = _rewrap_exception(err, message)
                if wrapped is err:
                    # The exception type could not be rebuilt; keep it as-is.
                    raise
                raise wrapped from err

        return _wrapper

    return _inner


160
161
162
class BroadcastableModelInput(ABC):
    """
    Abstract model input that can be converted to a dict for broadcasting
    between workers and rebuilt from such a dict on the receiving side.
    """

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom serialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        BroadcastableModelInput.

        Args:
            tensor_dict: Dict produced by as_broadcastable_tensor_dict on the
                sending side; consumed fields are popped from it.
            attn_backend: Optional backend, for subclasses that need it to
                rebuild attention metadata.
        """
        raise NotImplementedError


184
185
186
187
188
189
190
191
192
193
194
195
196
197
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Worker-local inputs consumed by a model runner.

    Instances may hold device-specific data. Each worker backend decides how
    to turn the engine-wide ExecuteModelRequest produced by the LLM engine
    into its own ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should subclass this,
    declare the fields they need, and implement how a ModelInput is
    serialized/deserialized for broadcast between workers.
    """


198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """Abstract builder that assembles ModelRunnerInputBase objects."""

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Hook for feeding one sequence group's metadata to the builder.

        NOTE(review): the original docstring was only "TBA"; this purpose is
        inferred from the method name. Confirm the exact contract against
        the concrete builders.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


213
214
215
216
217
218
219
220
221
222
class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    # Map of request_id -> generator used for seeded random sampling.
    # NOTE(review): this dict is created once at class-definition time, so
    # every instance (of every subclass) that does not rebind
    # ``self.generators`` shares the same mapping. Confirm that sharing is
    # intended before changing it.
    generators: Dict[str, torch.Generator] = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @current_platform.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.

        Not declared abstract, but this base implementation only raises
        NotImplementedError, so concrete runners must override it.
        """
        raise NotImplementedError

    def get_generators(
        self,
        finished_request_ids: Optional[List[str]] = None,
    ) -> Dict[str, torch.Generator]:
        """
        Return dict of per-request generators used for random sampling.

        Args:
            finished_request_ids: IDs of completed requests whose generators
                are removed before returning. Unknown IDs are ignored.

        Returns:
            The live ``self.generators`` mapping (not a copy).
        """
        # Clean up generators from completed requests
        for request_id in finished_request_ids or ():
            self.generators.pop(request_id, None)
        return self.generators