# cpu_model_runner.py
1
2
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
3
4

import torch
5
from torch import nn
6
7

from vllm.attention import AttentionMetadata, get_attn_backend
8
9
10
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
11
12
13
14
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
15
from vllm.multimodal import MULTIMODAL_REGISTRY
16
17
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
18
19
20
21
22
23
24
25
26
27
28
29
30
31

# Module-level logger for this runner.
logger = init_logger(__name__)

# Slot-mapping value written for prompt tokens that are masked out of the KV
# cache (tokens before the sliding-window start in `_prepare_prompt`).
_PAD_SLOT_ID = -1


class CPUModelRunner:
    """Prepares model inputs and runs the model for the CPU worker.

    The driver worker builds input tensors from the scheduler's
    ``SequenceGroupMetadata`` list and broadcasts them with
    ``broadcast_tensor_dict(src=0)``. Non-driver workers receive those
    tensors instead of building them. Only the driver worker samples.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Store configs and create the attention backend.

        Args:
            model_config: Model configuration (heads, head size, dtype,
                sliding window).
            parallel_config: Parallel configuration used to compute
                per-worker head counts.
            scheduler_config: Scheduler configuration; chunked prefill must
                be disabled.
            device_config: Device configuration; its ``device`` is where
                input tensors are placed.
            cache_config: KV-cache configuration (provides ``block_size``).
            load_config: Model loading configuration.
            lora_config: Optional LoRA configuration, forwarded to the loader.
            vision_language_config: Optional vision-language configuration.
                When set, a multi-modal input processor is created.
            kv_cache_dtype: KV-cache dtype string passed to the attention
                backend selector.
            is_driver_worker: Whether this worker builds and broadcasts the
                inputs and performs sampling.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model

    def load_model(self) -> None:
        """Load the model from the stored configs into ``self.model``."""
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
            str, torch.Tensor]]:
        """Build prefill inputs for a batch of prompt sequence groups.

        Each group must be a prompt with exactly one sequence.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            seq_lens, multi_modal_kwargs)``:

            - ``input_tokens`` and ``input_positions`` are flat long tensors.
            - ``seq_lens`` holds the prompt length of each group.
            - ``multi_modal_kwargs`` maps each processor output key to the
              concatenation (along dim 0) of all groups' values.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            # NOTE(review): all prompt tokens are appended, but positions and
            # slots start at computed_len. The two only line up when
            # computed_len == 0, which presumably always holds here since
            # chunked prefill is disabled. Confirm before enabling prefix
            # caching on CPU.
            input_positions.extend(list(range(computed_len, seq_len)))

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data is not None:
                # Process multi-modal data
                if self.multi_modal_input_processor is None:
                    raise ValueError(
                        "Multi-modal inputs are only supported by "
                        "vision language models.")

                mm_kwargs = self.multi_modal_input_processor(mm_data)
                for k, v in mm_kwargs.items():
                    multi_modal_kwargs_list[k].append(v)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build single-token decode inputs for a batch of sequence groups.

        Each sequence contributes its last token, at position
        ``seq_len - 1``. With a sliding window, the reported sequence length
        is capped at the window size. The block table is also truncated to
        its last ``sliding_window // block_size`` entries.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                # The slot for the new token is derived from the *full*
                # block table, before any sliding-window truncation below.
                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        # Pad ragged block tables with 0 to a rectangular int tensor.
        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[Dict[str, torch.Tensor]]]:
        """Build (driver) or receive (non-driver) the model inputs.

        On the driver worker, inputs are built with ``_prepare_prompt`` or
        ``_prepare_decode``, depending on the first group's ``is_prompt``.
        They are then broadcast from rank 0. Non-driver workers ignore
        ``seq_group_metadata_list`` and rebuild the attention metadata and a
        minimal ``SamplingMetadata`` from the broadcast dict.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_kwargs)``. ``multi_modal_kwargs``
            is ``None`` for decode batches.
        """
        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # query_lens is not needed if chunked prefill is not
                # supported. Since CPU worker doesn't support chunked prefill
                # just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                # Non-driver workers run the same forward pass, so they need
                # the multi-modal inputs too. Otherwise they would call the
                # model without them. This is a non-tensor entry (a dict or
                # None), so it is presumably sent as pickled metadata by
                # broadcast_tensor_dict — confirm for non-CPU tensors.
                "multi_modal_kwargs": multi_modal_kwargs,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            # Remove before make_metadata, which only takes attention fields.
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs", None)
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                seq_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, on the driver, sample the next tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups (only read by
                the driver worker).
            kv_caches: KV-cache tensors passed straight to the model.

        Returns:
            The sampler output on the driver worker, and ``None`` on
            non-driver workers.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config and multi_modal_input is not None:
            execute_model_kwargs.update(multi_modal_input)

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output