# neuron_model_runner.py
import os
2
from dataclasses import dataclass
3
from importlib.util import find_spec
4
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5
6

import torch
7
from torch import nn
8
from transformers_neuronx.config import GenerationConfig
9
10
11
12
13

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
14
from vllm.model_executor.layers.sampler import SamplerOutput
15
from vllm.model_executor.model_loader.neuron import get_neuron_model
16
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
17
                             MultiModalInputs)
18
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
19
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
20
21
22
23
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
24
25
26
27

logger = init_logger(__name__)


28
29
30
31
32
33
34
35
36
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.

    Immutable bundle of the inputs the Neuron runner passes to the model and
    sampler: token ids, positions, per-sequence block ids, sampling metadata
    and optional batched multi-modal keyword arguments. Every field defaults
    to ``None``.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Not supported: Neuron model inputs are never broadcast.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Build an instance from a dict of field values.

        Args:
            tensor_dict: mapping from field name to value; every key must be
                one of this dataclass's fields.
            attn_backend: must be ``None`` (Neuron has no attention backend
                object to attach).

        Returns:
            A new ``ModelInputForNeuron`` populated from ``tensor_dict``.
        """
        assert attn_backend is None
        # Construct directly. The previous body re-invoked this same
        # classmethod with the same arguments and recursed without bound.
        return cls(**tensor_dict)


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """Model runner for Neuron devices using transformers-neuronx models.

    Converts scheduler metadata into padded input tensors, runs the loaded
    model and samples the next token. Sampling happens on device by default;
    setting the environment variable ``NEURON_ON_DEVICE_SAMPLING_DISABLED``
    to a non-zero integer computes logits on the host path instead.
    """

    # NEURON has an upper limit on the top_k
    _MAX_NEURON_SAMPLING_TOP_K = 256

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        """Store configs and set up multi-modal and sampling state.

        When on-device sampling is enabled, this also installs a default
        ``GenerationConfig`` on ``model_config.neuron_sampling_params`` with
        one top_k/top_p/temperature slot per scheduler sequence.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
        # turn off on-device sampling.
        self._on_device_sampling_disabled = int(
            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

        # NEURON needs to update sampling parameters when request IDs change
        # across batches. This variable stores the previous batch's request IDs
        # to determine if an update is needed.
        self._previous_batch_request_ids: List[str] = []

        if not self._on_device_sampling_disabled:
            logger.warning(
                "On-device sampling is turned on in Neuron by default, only "
                "top_k, top_p, and temperature are current supported sampling "
                "parameters. To turn off the on-device sampling, please set "
                "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
            )
            # One sampling slot per schedulable sequence, initialised to the
            # least restrictive settings (max top_k, top_p=1, temperature=1).
            self.model_config.neuron_sampling_params = GenerationConfig(
                max_length=self.scheduler_config.max_model_len,
                do_sample=True,
                per_batch_line=True,
                top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
                    * self.scheduler_config.max_num_seqs,
                top_p=[1.0] * self.scheduler_config.max_num_seqs,
                temperature=[1.0] * self.scheduler_config.max_num_seqs,
                dynamic=True,
                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

    def load_model(self) -> None:
        """Load the Neuron model into ``self.model``.

        Raises:
            NotImplementedError: if ``transformers_neuronx`` is not installed.
        """
        if find_spec("transformers_neuronx") is not None:
            self.model = get_neuron_model(
                self.model_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)
        else:
            raise NotImplementedError(
                "Supports only Transformer-NeuronX based models.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        """Build prefill inputs for a non-empty batch of prompt groups.

        Each group must hold exactly one sequence mapped to exactly one block.
        Token ids and positions are right-padded with 0 to the longest prompt.

        Returns:
            (input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(
                    mm_data,
                    mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs,
                )
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build single-token decode inputs for a non-empty batch.

        Every sequence in every group contributes its last token id at
        position ``len - 1`` and its single block id.

        Returns:
            (input_tokens, input_positions, input_block_ids)
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Rebuild a ``ModelInputForNeuron`` from a dict of field values."""
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Turn scheduler metadata into model inputs and sampling metadata.

        The batch is treated as prefill or decode based on its first group.
        With on-device sampling enabled, the device sampling parameters are
        refreshed whenever the ordered list of request ids differs from the
        previous call.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            current_batch_request_ids = [
                seq_group_meta_data.request_id
                for seq_group_meta_data in seq_group_metadata_list
            ]
            if current_batch_request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(sampling_metadata)
                self._previous_batch_request_ids = current_batch_request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    def _update_neuron_sampling_params(self,
                                       sampling_metadata: SamplingMetadata):
        """Copy per-group top_k/top_p/temperature into the device config.

        Slot ``i`` of each list in ``model_config.neuron_sampling_params`` is
        overwritten in place with the parameters of the i-th sequence group,
        and the updated config is pushed to the model.
        """
        # Update Neuron sampling parameters (GenerationConfig in Neuron)
        current_sampling_params = self.model_config.neuron_sampling_params
        assert current_sampling_params is not None, (
            f"Failed to update sampling_params, "
            f"current sampling params is {current_sampling_params}")

        top_k = current_sampling_params.top_k
        top_p = current_sampling_params.top_p
        temperature = current_sampling_params.temperature
        for index, sequence_group_to_sample in enumerate(
                sampling_metadata.seq_groups):
            top_k[index] = self._convert_to_neuron_top_k(
                sequence_group_to_sample.sampling_params.top_k)
            top_p[index] = sequence_group_to_sample.sampling_params.top_p
            temperature[index] = \
                sequence_group_to_sample.sampling_params.temperature

        self.model.model.update_generation_config(current_sampling_params)

    def _convert_to_neuron_top_k(self, top_k: int) -> int:
        """Clamp a request's top_k into the range Neuron accepts.

        Non-positive values (the "consider all tokens" sentinels) and values
        above the Neuron limit both map to ``_MAX_NEURON_SAMPLING_TOP_K``.
        """
        if top_k <= 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
            return self._MAX_NEURON_SAMPLING_TOP_K
        return top_k

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and sample the next token.

        ``kv_caches`` and ``intermediate_tensors`` are accepted for interface
        compatibility and are not used here.

        Returns:
            A single-element list holding the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        )

        # Compute the logits only if the on-device sampling is turned off as
        # on-device sampling outputs the token ids.
        if self._on_device_sampling_disabled:
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)
        else:
            logits = hidden_states

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()