neuron_model_runner.py 17.4 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
6
7

import torch
8
from torch import nn
9

10
from vllm.config import DeviceConfig, VllmConfig
11
12
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
13
from vllm.model_executor.layers.sampler import SamplerOutput
14
from vllm.model_executor.model_loader.neuron import get_neuron_model
15
16
17
18
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
19
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
20
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
21
22
23
24
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
25
26
27
28

# Module-level logger; used by NeuronModelRunner for the sliding-window and
# on-device sampling warnings.
logger = init_logger(__name__)


29
30
31
32
33
34
35
36
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.

    Immutable bundle of the tensors and metadata built by
    ``NeuronModelRunner.prepare_model_input`` and consumed by
    ``NeuronModelRunner.execute_model``. Every field defaults to ``None``.
    """
    # Token ids fed to the model (zero-padded to the longest prompt on
    # prefill; one token per sequence on decode).
    input_tokens: Optional[torch.Tensor] = None
    # Positions aligned with ``input_tokens``.
    input_positions: Optional[torch.Tensor] = None
    # One block id per sequence; the runner also uses it as the index of the
    # sequence's on-device sampling slot.
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional[SamplingMetadata] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return every field as a dict keyed by field name.

        The keys are exactly those read back by
        ``from_broadcasted_tensor_dict``.
        """
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "input_block_ids": self.input_block_ids,
            "sampling_metadata": self.sampling_metadata,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Rebuild an instance from ``as_broadcastable_tensor_dict`` output.

        Args:
            tensor_dict: Mapping containing all five field names as keys.
            attn_backend: Unused here.

        Raises:
            KeyError: If any of the expected keys is missing.
        """
        # Use ``cls`` so subclasses round-trip to their own type.
        return cls(
            input_tokens=tensor_dict["input_tokens"],
            input_positions=tensor_dict["input_positions"],
            input_block_ids=tensor_dict["input_block_ids"],
            sampling_metadata=tensor_dict["sampling_metadata"],
            multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
        )
63
64
65


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """A model runner for AWS Neuron hardware.

    Builds ``ModelInputForNeuron`` batches from scheduler metadata and runs
    them through a model loaded with ``get_neuron_model``. The forward call
    is dispatched to either the neuronx-distributed or the
    transformers-neuronx backend, as reported by ``current_platform``.

    On-device sampling is enabled by default. It is disabled by setting the
    environment variable ``NEURON_ON_DEVICE_SAMPLING_DISABLED`` to a non-zero
    integer. While it is enabled, the model's output is passed straight to
    ``sample`` without calling ``compute_logits``.
    """

    # NEURON has an upper limit on the top_k
    _MAX_NEURON_SAMPLING_TOP_K = 256

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        ModelRunnerBase.__init__(self, vllm_config)

        if (self.model_config is not None
                and self.model_config.get_sliding_window()):
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        # Fall back to a default DeviceConfig when none was provided.
        self.device_config = (self.device_config if self.device_config
                              is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
        # turn off on-device sampling.
        self._on_device_sampling_disabled = int(
            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))

        # NEURON needs to update sampling parameters when request IDs change
        # across batches. This variable stores the previous batch's request IDs
        # to determine if an update is needed.
        self._previous_batch_request_ids: List[str] = []

        if not self._on_device_sampling_disabled:
            self._init_neuron_sampling()

    def _init_neuron_sampling(self) -> None:
        """Install default per-slot sampling settings on ``model_config``.

        Each of the ``max_num_seqs`` slots starts with top_k at the Neuron
        maximum, top_p 1.0 and temperature 1.0. Per-request values are
        written into these lists later by ``_update_neuron_sampling_params``.
        """
        # GenerationConfig comes from a different package depending on the
        # active Neuron backend, so it is imported here rather than at the top.
        if current_platform.use_transformers_neuronx():
            from transformers_neuronx.config import GenerationConfig
        else:
            from transformers import GenerationConfig
        logger.warning(
            "On-device sampling is turned on in Neuron by default, only "
            "top_k, top_p, and temperature are current supported sampling "
            "parameters. To turn off the on-device sampling, please set "
            "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
        max_num_seqs = self.scheduler_config.max_num_seqs
        self.model_config.neuron_sampling_params = GenerationConfig(
            max_length=self.scheduler_config.max_model_len,
            do_sample=True,
            per_batch_line=True,
            top_k=[self._MAX_NEURON_SAMPLING_TOP_K] * max_num_seqs,
            top_p=[1.0] * max_num_seqs,
            temperature=[1.0] * max_num_seqs,
            dynamic=True,
            global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

    def load_model(self) -> None:
        """Load the Neuron model into ``self.model``."""
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def get_model(self) -> nn.Module:
        """Return the model set by ``load_model``."""
        return self.model

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        """Build prefill inputs for a non-empty batch of prompt groups.

        Each group must hold exactly one sequence, and that sequence must own
        exactly one block (both are asserted).

        Returns:
            ``(input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)``. Tokens and positions are zero-padded to
            the longest prompt in the batch. ``seq_lens`` holds the unpadded
            prompt lengths.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_kwargs = seq_group_metadata.multi_modal_data
            if mm_kwargs:
                multi_modal_kwargs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build decode inputs: one token and one position per sequence.

        Every sequence in every group contributes its last token id, the
        position ``len - 1``, and its single block id (asserted).

        Returns:
            ``(input_tokens, input_positions, input_block_ids)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Rebuild a ``ModelInputForNeuron`` from a broadcast dict."""
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Build the model input for one scheduler step.

        NOTE: when on-device sampling is enabled, this rewrites ``top_k``,
        ``top_p`` and ``temperature`` on each group's ``sampling_params`` in
        place, using ``_convert_to_neuron_sampling_params``.

        Args:
            seq_group_metadata_list: Groups to run; all prompts or all
                decodes.
            virtual_engine: Unused by this runner.
            finished_requests_ids: Forwarded to ``get_generators``.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None

        if not self._on_device_sampling_disabled:
            for seq_group_metadata in seq_group_metadata_list:
                sampling_params = seq_group_metadata.sampling_params
                top_k, top_p, temperature = (
                    self._convert_to_neuron_sampling_params(sampling_params))
                sampling_params.top_k = top_k
                sampling_params.top_p = top_p
                sampling_params.temperature = temperature

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if current_platform.use_transformers_neuronx(
        ) and not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            current_batch_request_ids = [
                seq_group_meta_data.request_id
                for seq_group_meta_data in seq_group_metadata_list
            ]
            if current_batch_request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(seq_group_metadata_list)
                self._previous_batch_request_ids = current_batch_request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    def _update_neuron_sampling_params(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Write per-sequence sampling values into the Neuron config.

        Each sequence's top_k, top_p and temperature are stored in the
        ``model_config.neuron_sampling_params`` lists, at the index given by
        the sequence's first block id. The model's generation config is
        refreshed only if some value actually changed.
        """
        # Update Neuron sampling parameters (GenerationConfig in Neuron)
        current_sampling_params = self.model_config.neuron_sampling_params
        assert current_sampling_params is not None, (
            f"Failed to update sampling_params, "
            f"current sampling params is {current_sampling_params}")

        is_update_needed = False

        top_k = current_sampling_params.top_k
        top_p = current_sampling_params.top_p
        temperature = current_sampling_params.temperature

        # The index of a sequence's sampling parameters in neuron is equal to
        # its index in `input_block_ids`.
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params

            seq_group_top_k = sampling_params.top_k
            seq_group_top_p = sampling_params.top_p
            seq_group_temperature = sampling_params.temperature

            for seq_id in seq_ids:
                index = seq_group_metadata.block_tables[seq_id][0]
                if (top_k[index] != seq_group_top_k
                        or top_p[index] != seq_group_top_p
                        or temperature[index] != seq_group_temperature):
                    is_update_needed = True

                top_k[index] = seq_group_top_k
                top_p[index] = seq_group_top_p
                temperature[index] = seq_group_temperature

        # update_generation_config is only available in transformers-neuronx
        if is_update_needed and current_platform.use_transformers_neuronx():
            self.model.model.update_generation_config(current_sampling_params)

    def _convert_to_neuron_sampling_params(
            self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
        """Map request sampling settings onto values Neuron accepts.

        A temperature of exactly 0.0 becomes greedy ``(1, 1.0, 1.0)``.
        Otherwise, a top_k below 1 (presumably the "disabled" sentinel) or
        above ``_MAX_NEURON_SAMPLING_TOP_K`` is replaced by that maximum.

        Returns:
            ``(top_k, top_p, temperature)``.
        """
        top_k = sampling_params.top_k
        top_p = sampling_params.top_p
        temperature = sampling_params.temperature

        if temperature == 0.0:
            # Enable greedy sampling on zero temperature
            return (1, 1.0, 1.0)
        if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
            top_k = self._MAX_NEURON_SAMPLING_TOP_K

        return (top_k, top_p, temperature)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and sample the next tokens.

        Args:
            model_input: Batch produced by ``prepare_model_input``.
            kv_caches: Unused by this runner.
            intermediate_tensors: Unused by this runner.
            num_steps: Must be 1.

        Returns:
            A single-element list holding the output of ``self.model.sample``.

        Raises:
            ValueError: If ``num_steps > 1``.
            RuntimeError: If the platform reports neither neuronx-distributed
                nor transformers-neuronx. Previously this case failed later
                with an ``UnboundLocalError``.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        # Both backends accept the same multimodal keyword arguments.
        multi_modal_kwargs = MultiModalKwargs.as_kwargs(
            model_input.multi_modal_kwargs or {},
            dtype=self.model_config.dtype,
            device=self.device,
        )

        if current_platform.use_neuronx_distributed():
            # extract top_k, top_p and temperature from model_input for neuron
            # forward call. Only this backend consumes them, so the tensor is
            # built here rather than unconditionally.
            sampling_params = torch.tensor([[
                seq_group.sampling_params.top_k,
                seq_group.sampling_params.top_p,
                seq_group.sampling_params.temperature,
            ] for seq_group in model_input.sampling_metadata.seq_groups])
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                sampling_params=sampling_params,
                **multi_modal_kwargs,
            )
        elif current_platform.use_transformers_neuronx():
            # [TODO] validate on-device sampling
            # The model signature may need change for on-device sampling
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                **multi_modal_kwargs,
            )
        else:
            raise RuntimeError(
                "NeuronModelRunner requires the neuronx-distributed or "
                "transformers-neuronx backend, but the current platform "
                "reports neither.")

        # Compute the logits only if the on-device sampling is turned off as
        # on-device sampling outputs the token ids.
        if self._on_device_sampling_disabled:
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)
        else:
            logits = hidden_states

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()