neuron_model_runner.py 14.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from dataclasses import dataclass
5
from importlib.util import find_spec
6
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
7
8

import torch
9
from torch import nn
10
from transformers_neuronx.config import GenerationConfig
11

12
from vllm.config import VllmConfig
13
from vllm.forward_context import set_forward_context
14
15
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
16
from vllm.model_executor.layers.sampler import SamplerOutput
17
from vllm.model_executor.model_loader.neuron import get_neuron_model
18
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
19
                             MultiModalKwargs)
20
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
21
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
22
23
24
25
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
26
27
28
29

logger = init_logger(__name__)


30
31
32
33
34
35
36
37
38
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.

    Immutable bundle of the per-step inputs that the runner assembles in
    ``prepare_model_input`` and consumes in ``execute_model``. Every field
    defaults to ``None``.
    """
    # Token ids; passed to the model as ``input_ids``.
    input_tokens: Optional[torch.Tensor] = None
    # Position indices; passed to the model as ``positions``.
    input_positions: Optional[torch.Tensor] = None
    # Block ids, one per sequence; passed to the model as ``input_block_ids``.
    input_block_ids: Optional[torch.Tensor] = None
    # Sampling state handed to ``compute_logits`` / ``sample``.
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Batched multi-modal keyword arguments (only populated for prompts).
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Always raise: this input type does not support broadcasting.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Build an instance from a dict keyed by this class's field names.

        Args:
            tensor_dict: Mapping whose keys are field names of this
                dataclass; missing fields keep their ``None`` default.
            attn_backend: Must be ``None``; no attention backend is used
                here.

        Returns:
            A new ``ModelInputForNeuron`` populated from ``tensor_dict``.

        Raises:
            AssertionError: if ``attn_backend`` is not ``None``.
            TypeError: if ``tensor_dict`` contains a key that is not a field.
        """
        assert attn_backend is None
        # Construct directly. The previous implementation called this same
        # classmethod again, which recursed until RecursionError.
        return cls(**tensor_dict)


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
56

57
58
59
    # NEURON has an upper limit on the top_k
    _MAX_NEURON_SAMPLING_TOP_K = 256

60
61
    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        """Initialise runner state from ``vllm_config``.

        Besides the base-class initialisation this records the target
        device, the pin-memory capability, the multi-modal input mapper and
        the on-device sampling switch. The switch is read from the
        ``NEURON_ON_DEVICE_SAMPLING_DISABLED`` environment variable, parsed
        with ``int`` (default ``"0"``). When it is zero, a default
        ``GenerationConfig`` is stored on
        ``self.model_config.neuron_sampling_params``.
        """
        ModelRunnerBase.__init__(self, vllm_config)

        cfg = self.model_config
        if cfg is not None and cfg.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = (
            self.mm_registry.create_input_mapper(self.model_config))

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
        # turn off on-device sampling.
        env_flag = os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")
        self._on_device_sampling_disabled = int(env_flag)

        # NEURON needs to update sampling parameters when request IDs change
        # across batches. This variable stores the previous batch's request IDs
        # to determine if an update is needed.
        self._previous_batch_request_ids: List[str] = []

        # Nothing more to configure when sampling happens off-device.
        if self._on_device_sampling_disabled:
            return

        logger.warning(
            "On-device sampling is turned on in Neuron by default, only "
            "top_k, top_p, and temperature are current supported sampling "
            "parameters. To turn off the on-device sampling, please set "
            "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
        )
        # One slot per schedulable sequence, each starting from the neutral
        # defaults (maximum top_k, top_p=1.0, temperature=1.0).
        self.model_config.neuron_sampling_params = GenerationConfig(
            max_length=self.scheduler_config.max_model_len,
            do_sample=True,
            per_batch_line=True,
            top_k=[self._MAX_NEURON_SAMPLING_TOP_K] *
            self.scheduler_config.max_num_seqs,
            top_p=[1.0] * self.scheduler_config.max_num_seqs,
            temperature=[1.0] * self.scheduler_config.max_num_seqs,
            dynamic=True,
            global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)

108
    def load_model(self) -> None:
109
110
111
112
113
114
115
116
        if find_spec("transformers_neuronx") is not None:
            self.model = get_neuron_model(
                self.model_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config)
        else:
            raise NotImplementedError(
                "Supports only Transformer-NeuronX based models.")
117

118
119
120
    def get_model(self) -> nn.Module:
        return self.model

121
122
123
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
               BatchedTensorInputs]:
        """Build padded model inputs for a batch of prompt sequence groups.

        Each group must be a prompt holding exactly one sequence whose block
        table has exactly one entry.

        Returns:
            ``(input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)``: the first two are padded with 0 up to the
            longest prompt in the batch, ``input_block_ids`` has one entry
            per group, ``seq_lens`` lists the unpadded prompt lengths, and
            ``multi_modal_kwargs`` is the batched multi-modal input.
        """
        assert len(seq_group_metadata_list) > 0
        token_rows: List[List[int]] = []
        position_rows: List[List[int]] = []
        block_ids: List[int] = []
        seq_lens: List[int] = []
        mm_kwargs_per_seq: List[MultiModalKwargs] = []

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            prompt_tokens = group.seq_data[seq_id].get_token_ids()
            prompt_len = len(prompt_tokens)
            seq_lens.append(prompt_len)
            token_rows.append(prompt_tokens)
            # Prompt positions are simply 0 .. prompt_len - 1.
            position_rows.append(list(range(prompt_len)))

            assert group.block_tables is not None
            block_table = group.block_tables[seq_id]
            assert len(block_table) == 1
            block_ids.append(block_table[0])

            mm_data = group.multi_modal_data
            if not mm_data:
                continue
            # With a registered processor the data is used as-is; otherwise
            # it goes through the input mapper first.
            if self.mm_registry.has_processor(self.model_config):
                mm_kwargs = mm_data
            else:
                mm_kwargs = self.multi_modal_input_mapper(
                    mm_data,
                    group.mm_processor_kwargs,
                )
            mm_kwargs_per_seq.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0

        input_tokens = make_tensor_with_pad(token_rows,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(position_rows,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(block_ids,
                                       dtype=torch.long,
                                       device=self.device)
        multi_modal_kwargs = MultiModalKwargs.batch(mm_kwargs_per_seq)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build model inputs for a batch of decode-phase sequence groups.

        For every sequence in every group this collects the last generated
        token, its position (``seq_len - 1``) and the single block id from
        the sequence's block table.

        Args:
            seq_group_metadata_list: Non-empty list of groups, none of which
                may be a prompt.

        Returns:
            ``(input_tokens, input_positions, input_block_ids)``. The first
            two are built with ``make_tensor_with_pad(..., max_len=1)`` from
            one-element rows (presumably shape ``(num_seqs, 1)``; confirm
            against ``make_tensor_with_pad``); ``input_block_ids`` has one
            entry per sequence.

        Raises:
            AssertionError: if the list is empty, a group is a prompt, a
                group has no block tables, or a block table does not hold
                exactly one entry.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                input_tokens.append([seq_data.get_last_token_id()])
                # The token being fed sits at the last position so far.
                input_positions.append([seq_data.get_len() - 1])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        # NOTE: a per-sequence ``context_lens`` tensor used to be built here
        # but was never read or returned, so it has been dropped.
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

234
235
236
237
238
    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Delegate to ``ModelInputForNeuron.from_broadcasted_tensor_dict``.

        ``tensor_dict`` is forwarded unchanged and no attention backend is
        supplied.
        """
        build_input = ModelInputForNeuron.from_broadcasted_tensor_dict
        return build_input(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Turn scheduled sequence groups into a ``ModelInputForNeuron``.

        The first group decides whether the whole batch is handled as
        prompts or as decodes. When on-device sampling is enabled and the
        batch's request ids differ from the previous call, the Neuron
        sampling parameters are refreshed before returning.

        ``virtual_engine`` is accepted but not used here.
        """
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        if seq_group_metadata_list[0].is_prompt:
            prompt_inputs = self._prepare_prompt(seq_group_metadata_list)
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs) = prompt_inputs
        else:
            decode_inputs = self._prepare_decode(seq_group_metadata_list)
            input_tokens, input_positions, input_block_ids = decode_inputs
            seq_lens = None
            multi_modal_kwargs = None

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            request_ids = [
                metadata.request_id for metadata in seq_group_metadata_list
            ]
            if request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(sampling_metadata)
                self._previous_batch_request_ids = request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)
284

285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
    def _update_neuron_sampling_params(self,
                                       sampling_metadata: SamplingMetadata):
        """Copy per-request sampling settings into the Neuron config.

        For the i-th entry of ``sampling_metadata.seq_groups`` the i-th slot
        of the stored config's ``top_k`` / ``top_p`` / ``temperature`` lists
        is overwritten in place (``top_k`` is first passed through
        ``_convert_to_neuron_top_k``). The config is then pushed to the model
        via ``self.model.model.update_generation_config``.
        """
        # Update Neuron sampling parameters (GenerationConfig in Neuron)
        generation_config = self.model_config.neuron_sampling_params
        assert generation_config is not None, (
            "Failed to update sampling_params, "
            f"current sampling params is {generation_config}")

        top_k_slots = generation_config.top_k
        top_p_slots = generation_config.top_p
        temperature_slots = generation_config.temperature
        for slot, seq_group in enumerate(sampling_metadata.seq_groups):
            requested = seq_group.sampling_params
            top_k_slots[slot] = self._convert_to_neuron_top_k(requested.top_k)
            top_p_slots[slot] = requested.top_p
            temperature_slots[slot] = requested.temperature

        self.model.model.update_generation_config(generation_config)

    def _convert_to_neuron_top_k(self, top_k: int) -> int:
        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
            return self._MAX_NEURON_SAMPLING_TOP_K
        return top_k

311
312
313
    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and sample the next token(s).

        ``kv_caches`` and ``intermediate_tensors`` are accepted but not used
        by this method.

        Returns:
            A one-element list holding the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        with set_forward_context(None, self.vllm_config, 0):
            mm_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {}, device=self.device)
            hidden_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                input_block_ids=model_input.input_block_ids,
                **mm_kwargs,
            )

        # Compute the logits only if the on-device sampling is turned off as
        # on-device sampling outputs the token ids.
        if not self._on_device_sampling_disabled:
            logits = hidden_states
        else:
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

        # Sample the next token.
        sampled = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [sampled]
347
348
349
350

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()