neuron_model_runner.py 10.8 KB
Newer Older
1
from dataclasses import dataclass
2
from importlib.util import find_spec
3
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
4
5

import torch
6
from torch import nn
7
8
9
10
11

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
12
from vllm.model_executor.layers.sampler import SamplerOutput
13
from vllm.model_executor.model_loader.neuron import get_neuron_model
14
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
15
                             MultiModalInputs)
16
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
17
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
18
19
20
21
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
22
23
24
25

# Module-level logger named after this module (used e.g. by
# NeuronModelRunner.__init__ to warn about unsupported sliding windows).
logger = init_logger(__name__)


26
27
28
29
30
31
32
33
34
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.

    Immutable bundle of the inputs built by
    ``NeuronModelRunner.prepare_model_input`` and consumed by
    ``NeuronModelRunner.execute_model``.
    """
    # Token ids, one padded row per sequence.
    input_tokens: Optional[torch.Tensor] = None
    # Positions matching ``input_tokens``, padded the same way.
    input_positions: Optional[torch.Tensor] = None
    # One block id per sequence, taken from its block table.
    input_block_ids: Optional[torch.Tensor] = None
    # Sampling parameters/metadata for the batch.
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Batched multi-modal inputs (only populated for prompt batches).
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Not supported: Neuron model inputs are never broadcast.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Build an instance from a mapping of field name to value.

        Args:
            tensor_dict: Field values keyed by dataclass field name. Every
                key must name a field of this dataclass.
            attn_backend: Must be ``None``; this input type has no attention
                metadata field to rebuild.

        Returns:
            A new ``ModelInputForNeuron``.

        Raises:
            AssertionError: if ``attn_backend`` is not ``None``.
            TypeError: if ``tensor_dict`` contains a key that is not a field.
        """
        assert attn_backend is None
        # Construct the dataclass directly. Delegating to this same
        # classmethod with the same arguments would recurse indefinitely.
        return cls(**tensor_dict)


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """Model runner for Neuron devices using transformers-neuronx models.

    Turns scheduler ``SequenceGroupMetadata`` into padded input tensors
    (``prepare_model_input``) and runs the model's forward, logits and
    sampling steps on them (``execute_model``).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        # Fall back to a default DeviceConfig when none is supplied.
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

    def load_model(self) -> None:
        """Load the Neuron model into ``self.model``.

        Raises:
            NotImplementedError: if the ``transformers_neuronx`` package is
                not importable.
        """
        if find_spec("transformers_neuronx") is not None:
            self.model = get_neuron_model(
                self.model_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)
        else:
            raise NotImplementedError(
                "Supports only Transformer-NeuronX based models.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        """Build input tensors for a batch of prompt (prefill) groups.

        Each group must contain exactly one sequence whose block table holds
        exactly one block id.

        Returns:
            ``(input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)``. Tokens and positions are right-padded with
            0 to the longest prompt in the batch (presumably shaped
            ``(num_seqs, max_seq_len)`` by ``make_tensor_with_pad`` — TODO
            confirm); ``seq_lens`` holds the unpadded prompt lengths.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            # Prefill expects exactly one sequence per group.
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            # Each sequence's block table must hold a single block id.
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build input tensors for a batch of decode groups.

        Every sequence in every group contributes one row: its last token id
        and that token's position (``seq_len - 1``). Each sequence's block
        table must hold exactly one block id.

        Returns:
            ``(input_tokens, input_positions, input_block_ids)``, with tokens
            and positions padded to length 1 per row.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                # The last token sits at index seq_len - 1.
                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Rebuild a ``ModelInputForNeuron`` from a tensor dict."""
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Prepare tensors and sampling metadata for one model step.

        Args:
            seq_group_metadata_list: Groups to run; all must be prompts or
                all must be decodes (only the first group is inspected).
            virtual_engine: Unused here.
            finished_requests_ids: Passed to ``get_generators``.

        Returns:
            The assembled ``ModelInputForNeuron``.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = []
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run forward, logits computation and sampling for one step.

        Args:
            model_input: Inputs from ``prepare_model_input``.
            kv_caches: Accepted for interface compatibility; not used here.
            intermediate_tensors: Accepted for interface compatibility; not
                used here.
            num_steps: Must be 1.

        Returns:
            A single-element list with the sampler output.

        Raises:
            ValueError: if ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()