# neuron_model_runner.py (10.6 KB)
1
from dataclasses import dataclass
2
3
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
                    Union)
4
5

import torch
6
from torch import nn
7
8
9
10
11

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
12
from vllm.model_executor.model_loader.neuron import get_neuron_model
13
14
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
15
16
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
17
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
18
19
20
21
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
22
23
24
25

logger = init_logger(__name__)


26
27
28
29
30
31
32
33
34
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
    """
    Used by the NeuronModelRunner.

    Immutable bundle of the inputs built by
    ``NeuronModelRunner.prepare_model_input`` and consumed by
    ``NeuronModelRunner.execute_model``.
    """
    # Padded token ids, one row per sequence (pad value 0).
    input_tokens: Optional[torch.Tensor] = None
    # Padded positions, same layout as ``input_tokens`` (pad value 0).
    input_positions: Optional[torch.Tensor] = None
    # One block id per sequence, taken from its single-entry block table.
    input_block_ids: Optional[torch.Tensor] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Extra keyword arguments forwarded to the model for multi-modal inputs;
    # None when the batch carries no multi-modal data.
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Broadcasting is not supported for Neuron inputs.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForNeuron":
        """Rebuild a ``ModelInputForNeuron`` from a dict of its fields.

        Args:
            tensor_dict: mapping from field name (``input_tokens``,
                ``input_positions``, ``input_block_ids``,
                ``sampling_metadata``, ``multi_modal_kwargs``) to value.
                Missing fields take their ``None`` default.
            attn_backend: must be None; the Neuron runner does not use an
                attention backend here.

        Returns:
            A new ``ModelInputForNeuron`` populated from ``tensor_dict``.

        Raises:
            TypeError: if ``tensor_dict`` contains a key that is not a field.
        """
        assert attn_backend is None
        # Build the dataclass directly. Calling this classmethod again (as
        # the previous implementation did) recursed without end.
        return cls(**tensor_dict)


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
    """Prepare batched inputs for a Neuron model and run it.

    Every sequence must own exactly one block (asserted in both the prompt
    and decode paths). That block id is passed to the model as
    ``input_block_ids``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        # Fall back to a default DeviceConfig when none is supplied.
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

    def load_model(self) -> None:
        """Load the Neuron model and store it on ``self.model``."""
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
            str, BatchedTensors]]:
        """Build the model inputs for a batch of prompt (prefill) groups.

        Each group must hold exactly one sequence whose block table has
        exactly one entry.

        Returns:
            ``(input_tokens, input_positions, input_block_ids, seq_lens,
            multi_modal_kwargs)``. Tokens and positions are padded with 0
            to the longest prompt. Positions run ``0 .. seq_len - 1``.
            ``seq_lens`` holds the unpadded prompt lengths.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)
            seq_lens.append(seq_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(seq_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                # Process multi-modal data
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        max_seq_len = max(seq_lens)
        assert max_seq_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=max_seq_len,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=max_seq_len,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)

        return (input_tokens, input_positions, input_block_ids, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the model inputs for a batch of decode groups.

        Every sequence in every group contributes one row. The row holds its
        last token id at position ``seq_data.get_len() - 1`` and the single
        entry of its block table.

        Returns:
            ``(input_tokens, input_positions, input_block_ids)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            for seq_id in seq_group_metadata.seq_data:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                # The new token sits right after the existing sequence.
                position = seq_data.get_len() - 1
                input_positions.append([position])

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            pad=0,
                                            max_len=1,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               pad=0,
                                               max_len=1,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
        """Delegate to ``ModelInputForNeuron.from_broadcasted_tensor_dict``."""
        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Build a ``ModelInputForNeuron`` for one scheduler step.

        The first group decides whether the whole batch is treated as prompt
        or as decode. ``virtual_engine`` is accepted for interface
        compatibility and is not used here.
        """
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = []
            # Decode steps carry no multi-modal data. Without this assignment
            # the name would be unbound when building the model input below.
            multi_modal_kwargs = None

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass, compute logits and sample the next tokens.

        ``kv_caches`` and ``intermediate_tensors`` are accepted for interface
        compatibility; this method does not read them.

        Returns:
            A single-element list holding the sampler output.

        Raises:
            ValueError: if ``num_steps > 1`` (multi-step is unsupported).
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            input_block_ids=model_input.input_block_ids,
            **(model_input.multi_modal_kwargs or {}),
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()