tpu_model_runner.py 36.2 KB
Newer Older
1
import time
2
from dataclasses import dataclass
3
4
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
5
from unittest.mock import patch
6
7
8
9
10

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
11
import torch_xla.runtime as xr
12
13

from vllm.attention import AttentionMetadata, get_attn_backend
youkaichao's avatar
youkaichao committed
14
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
15
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
16
                         ParallelConfig, SchedulerConfig)
17
from vllm.logger import init_logger
18
from vllm.model_executor.layers.sampler import SamplerOutput
19
20
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
21
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
22
                           Logprob, SequenceGroupMetadata, SequenceOutput)
23
24
25
26
27
28
29
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
30
31
32

logger = init_logger(__name__)

33
34
35
# Slot id written into `slot_mapping` for padding positions (see
# `_prepare_prompt` and `_prepare_decode`).
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
36
37
# Gate for top-p sampling: while this is False, `_prepare_sample` raises
# NotImplementedError for any request whose `top_p` is not 1.
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
38
39
40
# Value passed as `num_samples` to the model for prompt runs, and the upper
# bound on `sampling_params.n` accepted by `_prepare_sample`.
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
41
42


43
44
45
46
47
48
49
50
51
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Immutable bundle of everything `TPUModelRunner.execute_model` consumes.

    Instances are built by `TPUModelRunner.prepare_model_input` or rebuilt
    from a broadcast dictionary via `from_broadcasted_tensor_dict`.
    """
    # Token ids and their positions, as produced by `_prepare_prompt` /
    # `_prepare_decode`.
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    # Un-padded prompt lengths for prefill; all ones for decode.
    input_lens: torch.Tensor
    # Per-row sampling temperature and top-p.
    t: torch.Tensor
    p: torch.Tensor
    # Forwarded to the model as its `num_samples` argument.
    num_samples: int
    # `sampling_params.n`, repeated once per sequence in each group.
    n: List[int]
    # Sequence ids of every scheduled sequence group.
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    # Not part of the broadcast dictionary.
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Flatten this input into a dictionary suitable for broadcasting.

        Returns:
            A dictionary holding the plain fields listed below plus whatever
            `_add_attn_metadata_broadcastable_dict` adds for `attn_metadata`.
        """
        # Insertion order of the keys is kept identical to the field order.
        broadcast_fields = (
            "token_ids",
            "position_ids",
            "input_lens",
            "t",
            "p",
            "num_samples",
            "n",
            "seq_groups",
            "is_first_multi_step",
            "is_last_step",
            "virtual_engine",
        )
        payload = {name: getattr(self, name) for name in broadcast_fields}
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Rebuild a model input from a broadcast dictionary.

        Args:
            tensor_dict: Dictionary whose keys match the dataclass fields.
            attn_backend: When given, `tensor_dict` is first passed through
                `_init_attn_metadata_from_tensor_dict` with this backend.

        Returns:
            A new `ModelInputForTPU` constructed from the dictionary entries.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        restored = _init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict)
        return cls(**restored)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
90
91
92
93
94
95
96
97
98

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        is_driver_worker: bool = False,
    ):
        """Store the configs and set up host-side state for the runner.

        Args:
            model_config: Model settings (max length, dtype, head size, ...).
            parallel_config: Parallelism settings, forwarded to the loader.
            scheduler_config: Scheduler limits such as `max_num_seqs`.
            device_config: Device settings; `device` is read in `load_model`.
            cache_config: KV-cache settings (block size, cache dtype).
            load_config: Weight-loading settings, forwarded to the loader.
            is_driver_worker: Stored as-is on the instance.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        block_size = cache_config.block_size
        self.block_size = block_size
        self.max_num_blocks_per_seq = model_config.max_model_len // block_size

        # Host-side scratch table reused by `_prepare_decode`: one row per
        # schedulable sequence, one column per block.
        table_shape = (scheduler_config.max_num_seqs,
                       self.max_num_blocks_per_seq)
        self.block_tables = np.zeros(table_shape, dtype=np.int32)

        # NOTE(review): the trailing positional `False` is presumably the
        # block-sparse flag of `get_attn_backend` - confirm against its
        # signature.
        self.attn_backend = get_attn_backend(
            model_config.get_head_size(),
            model_config.get_sliding_window(),
            model_config.dtype,
            cache_config.cache_dtype,
            block_size,
            model_config.is_attention_free,
            False,
        )

        # Device-side outputs of decode steps, drained by `execute_model`.
        self.cached_step_outputs: List[torch.Tensor] = []
125
126
127
128

    def load_model(self) -> None:
        """Load the model, switch it to eval mode and wrap it for compilation.

        Sets `self.device` and `self.model` (a `ModelWrapper`).
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        rank_fn_path = ("vllm.model_executor.layers.vocab_parallel_embedding."
                        "get_tensor_model_parallel_rank")
        with patch(rank_fn_path, return_value=xm_tp_rank):
            loaded = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=self.device_config,
                parallel_config=self.parallel_config,
                cache_config=self.cache_config,
                scheduler_config=self.scheduler_config,
                lora_config=None,
            )
        eval_model = loaded.eval()
        # Block until pending device work has finished before wrapping.
        xm.wait_device_ops()
        self.model = ModelWrapper(eval_model)
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        """Run the model once on placeholder inputs of the requested shape.

        Args:
            batch_size: Number of rows in the placeholder batch.
            seq_len: Tokens per row. Rounded up to a multiple of 16 for
                prompts; must be exactly 1 for decode.
            kv_caches: KV caches handed to the model unchanged.
            is_prompt: Selects the prefill (True) or decode (False) variant.
        """
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
        else:
            assert seq_len == 1

        device = self.device
        tokens_shape = (batch_size, seq_len)
        per_row_shape = (batch_size, )

        # These four tensors are built the same way for both variants.
        token_ids = torch.zeros(tokens_shape, dtype=torch.int32, device=device)
        position_ids = torch.zeros(tokens_shape,
                                   dtype=torch.int32,
                                   device=device)
        slot_mapping = torch.zeros(tokens_shape,
                                   dtype=torch.int64,
                                   device=device)
        input_lens = torch.ones(per_row_shape,
                                dtype=torch.int32,
                                device=device)

        if is_prompt:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
        else:
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=device)
            context_lens = torch.ones(per_row_shape,
                                      dtype=torch.int32,
                                      device=device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )

        t = torch.ones(per_row_shape, dtype=torch.float32, device=device)
        p = torch.ones(per_row_shape, dtype=torch.float32, device=device)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shapes. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs in the disk (VLLM_XLA_CACHE_PATH).
        if is_prompt:
            # Prefill: the sequence dimension (dim 1) varies.
            dynamic_dim = 1
            dynamic_tensors = (
                token_ids,
                position_ids,
                attn_metadata.slot_mapping,
            )
        else:
            # Decode: the batch dimension (dim 0) varies.
            dynamic_dim = 0
            dynamic_tensors = (
                token_ids,
                position_ids,
                input_lens,
                attn_metadata.slot_mapping,
                attn_metadata.context_lens,
                attn_metadata.block_tables,
                t,
                p,
            )
        for tensor in dynamic_tensors:
            torch._dynamo.mark_dynamic(tensor, dynamic_dim)

        # Dummy run.
        self.model(token_ids,
                   position_ids,
                   attn_metadata,
                   input_lens,
                   t,
                   p,
                   num_samples,
                   kv_caches,
                   is_prompt=is_prompt)
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        """Pre-compile the model for the prefill and decode shapes in use.

        Prefill is warmed up with batch size 1 and sequence lengths 16, 32,
        64, ... until the model length or the batched-token budget is
        reached. Decode is warmed up with sequence length 1 and batch sizes
        8, 16, 32, 48, ... until `max_num_seqs` is reached.

        Args:
            kv_caches: KV caches forwarded to every dummy run.
        """
        max_model_len = self.model_config.max_model_len

        # Prefill
        logger.info("Compiling the model with different input shapes...")
        prefill_started = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

                reached_limit = (
                    seq_len >= max_model_len or batch_size * seq_len
                    >= self.scheduler_config.max_num_batched_tokens)
                if reached_limit:
                    break
                seq_len *= 2

        prefill_elapsed = time.time() - prefill_started
        logger.info("Compilation for prefill done in %.2f s.", prefill_elapsed)

        # Decode
        decode_started = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            # Double up to 16, then grow in steps of 16.
            if batch_size >= 16:
                batch_size += 16
            else:
                batch_size *= 2

        decode_elapsed = time.time() - decode_started
        logger.info("Compilation for decode done in %.2f s.", decode_elapsed)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build the CPU-side prefill inputs for a batch of prompts.

        All prompts are concatenated into flat 1-D tensors, each prompt being
        padded individually to `_get_padded_prefill_len(prompt_len)`.

        Args:
            seq_group_metadata_list: Non-empty list of prompt sequence groups,
                each holding exactly one sequence.

        Returns:
            `(input_tokens, input_positions, attn_metadata, prompt_lens)`
            where `prompt_lens` holds the un-padded length of every prompt.
        """
        assert len(seq_group_metadata_list) > 0
        flat_tokens: List[int] = []
        flat_positions: List[int] = []
        lengths: List[int] = []
        flat_slots: List[int] = []
        block_size = self.block_size

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            # Could include output tokens when a request is preempted.
            prompt_tokens = group.seq_data[seq_id].get_token_ids()
            prompt_len = len(prompt_tokens)
            lengths.append(prompt_len)

            flat_tokens.extend(prompt_tokens)
            flat_positions.extend(range(prompt_len))

            # Translate every token position into its KV-cache slot.
            assert group.block_tables is not None
            block_table = group.block_tables[seq_id]
            flat_slots.extend(
                block_table[pos // block_size] * block_size + pos % block_size
                for pos in range(prompt_len))

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            num_paddings = _get_padded_prefill_len(prompt_len) - prompt_len
            flat_tokens.extend([0] * num_paddings)
            flat_positions.extend([0] * num_paddings)
            flat_slots.extend([_PAD_SLOT_ID] * num_paddings)

        assert len(lengths) > 0
        num_prefills = len(lengths)

        input_tokens = torch.tensor(flat_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(flat_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(flat_slots,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(lengths, dtype=torch.int32, device="cpu")

        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens
359
360
361
362

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        """Build the CPU-side inputs for one decode step.

        Every sequence of every group contributes one row. The batch is then
        padded to `_get_padded_batch_size(num_rows)` rows.

        Args:
            seq_group_metadata_list: Non-empty list of decode sequence groups.

        Returns:
            `(input_tokens, input_positions, attn_metadata, input_lens)`;
            the first two have one column per row and `input_lens` is all
            ones.
        """
        assert len(seq_group_metadata_list) > 0
        token_rows: List[List[int]] = []
        position_rows: List[List[int]] = []
        slot_rows: List[List[int]] = []
        context_lengths: List[int] = []
        block_size = self.block_size

        num_rows = 0
        for group in seq_group_metadata_list:
            assert not group.is_prompt
            for seq_id in list(group.seq_data.keys()):
                seq_data = group.seq_data[seq_id]
                token_rows.append([seq_data.get_last_token_id()])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                position_rows.append([position])
                context_lengths.append(seq_len)

                assert group.block_tables is not None
                block_table = group.block_tables[seq_id]
                # Only the leading entries of the reusable row are rewritten.
                self.block_tables[num_rows, :len(block_table)] = block_table
                num_rows += 1

                block_number = block_table[position // block_size]
                slot_rows.append(
                    [block_number * block_size + position % block_size])

        batch_size = _get_padded_batch_size(num_rows)
        num_paddings = batch_size - num_rows
        token_rows = token_rows + [[0]] * num_paddings
        position_rows = position_rows + [[0]] * num_paddings
        slot_rows = slot_rows + [[_PAD_SLOT_ID]] * num_paddings
        context_lengths = context_lengths + [0] * num_paddings

        input_tokens = torch.tensor(token_rows,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(position_rows,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_rows,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lengths,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")

        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens
428
429
430
431
432

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        """Collect per-sequence sampling parameters for the batch.

        Args:
            seq_group_metadata_list: Non-empty list of scheduled groups.
            padded_batch_size: Length that the temperature and top-p tensors
                are padded to (with 1.0).

        Returns:
            `(t, p, n)`: float32 CPU tensors of temperatures and top-p values,
            and the un-padded list of `sampling_params.n`, each repeated once
            per sequence of its group.

        Raises:
            NotImplementedError: If a request uses top-p (while disabled),
                top-k, `n` above `_MAX_NUM_SAMPLES`, logprobs or
                prompt_logprobs.
        """
        assert len(seq_group_metadata_list) > 0
        temperatures = []
        top_ps = []
        n = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            temperatures.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            top_ps.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                # The value being checked is `n`, so the message names `n`
                # rather than "Best of".
                raise NotImplementedError(
                    f"n > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            n.append(sampling_params.n)
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            extra_seqs = len(seq_group_metadata.seq_data) - 1
            temperatures += [temperatures[-1]] * extra_seqs
            top_ps += [top_ps[-1]] * extra_seqs
            n += [n[-1]] * extra_seqs

        # Padding rows get neutral values; `n` is intentionally left unpadded.
        num_paddings = padded_batch_size - len(temperatures)
        temperatures += [1.0] * num_paddings
        top_ps += [1.0] * num_paddings

        t = torch.tensor(temperatures, dtype=torch.float32, device="cpu")
        p = torch.tensor(top_ps, dtype=torch.float32, device="cpu")
        return t, p, n
476

477
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        """Turn the scheduled sequence groups into a `ModelInputForTPU`.

        Args:
            seq_group_metadata_list: Non-empty list of scheduled groups; all
                of them are assumed to be prompts or all decodes.
            virtual_engine: Must be 0.
            finished_requests_ids: Ignored.

        Returns:
            The model input for `execute_model`, with CPU-side tensors.
        """
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0

        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        prepare = self._prepare_prompt if is_prompt else self._prepare_decode
        input_tokens, input_positions, attn_metadata, input_lens = prepare(
            seq_group_metadata_list)

        # For prompts the token tensor is flat, so this is the total padded
        # token count rather than the number of prompts; it only serves as
        # the padded length of `t` and `p`.
        padded_batch_size = input_tokens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(group.seq_data.keys()) for group in seq_group_metadata_list
        ]
        return ModelInputForTPU(
            token_ids=input_tokens,
            position_ids=input_positions,
            attn_metadata=attn_metadata,
            input_lens=input_lens,
            t=t,
            p=p,
            num_samples=num_samples,
            n=n,
            seq_groups=seq_groups,
        )
505
506
507
508
509
510
511

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        """Rebuild a `ModelInputForTPU` from a broadcast dictionary.

        Args:
            tensor_dict: Dictionary received from the broadcast.

        Returns:
            The reconstructed model input, with attention metadata restored
            using this runner's attention backend.
        """
        return ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)

512
    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        """Run the model and turn the sampled tokens into sampler outputs.

        Three paths exist:

        * Not the first multi-step call: nothing is executed. Unless this is
          the last step (in which case the cached decode outputs are drained
          and returned), an empty list is returned.
        * Prompt: every prompt is executed on its own and a single
          `SamplerOutput` covering all groups is returned.
        * Decode: `num_steps` steps are executed back to back and their
          outputs cached. With `num_steps > 1` an empty list is returned;
          otherwise the single cached output is returned.

        Args:
            model_input: Inputs built by `prepare_model_input`.
            kv_caches: KV caches handed to the model unchanged.
            intermediate_tensors: Must be None.
            num_steps: Number of decode steps to run; must be 1 for prompts.

        Returns:
            The sampler outputs that are ready, possibly an empty list.
        """
        assert intermediate_tensors is None
        callback = model_input.async_callback

        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            # Last step of a multi-step run: drain the cached outputs.
            drained = []
            num_cached = len(self.cached_step_outputs)
            for step_idx in range(num_cached):
                step_tokens = self.cached_step_outputs.pop(0).cpu().tolist()
                step_output = _make_decode_output(step_tokens,
                                                  model_input.seq_groups)
                drained.append(step_output)

                # With async output processing, every output except the last
                # one is pushed through the callback right away.
                if callback is not None and step_idx < num_cached - 1:
                    ctx = callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[step_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False,
                        is_first_step_output=step_idx == 0)
                    callback()
            if callback is not None:
                return [drained[-1]]
            return drained

        if model_input.attn_metadata.num_prefills > 0:
            assert num_steps == 1
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            flat_slot_mapping = model_input.attn_metadata.slot_mapping
            num_prompts = model_input.input_lens.shape[0]
            device_outputs = []
            begin = 0
            for row in range(num_prompts):
                row_slice = slice(row, row + 1)
                # The flat tensors store each prompt at its padded length.
                actual_len = model_input.input_lens[row_slice].item()
                end = begin + _get_padded_prefill_len(actual_len)

                token_ids = model_input.token_ids[None, begin:end].to(
                    self.device)
                position_ids = model_input.position_ids[None, begin:end].to(
                    self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = flat_slot_mapping[
                    None, begin:end].to(self.device)
                input_lens = model_input.input_lens[row_slice].to(self.device)
                t = model_input.t[row_slice].to(self.device)
                p = model_input.p[row_slice].to(self.device)
                output_token_ids = self.model(token_ids,
                                              position_ids,
                                              attn_metadata,
                                              input_lens,
                                              t,
                                              p,
                                              model_input.num_samples,
                                              kv_caches,
                                              is_prompt=True)
                device_outputs.append(output_token_ids[0])
                begin = end

            if callback is not None:
                callback()
            # Retrieve the outputs to CPU.
            sampled = [tokens.cpu().tolist() for tokens in device_outputs]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            group_outputs = []
            for group_idx, seq_ids in enumerate(model_input.seq_groups):
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                # Keep only the first `n` samples requested by this group.
                for sample_idx in range(model_input.n[group_idx]):
                    token = sampled[group_idx][sample_idx]
                    seq_outputs.append(
                        SequenceOutput(seq_id, token, {token: zero_logprob}))
                group_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(group_outputs)]

        # Decode path.
        token_ids = model_input.token_ids.to(self.device)
        position_ids = model_input.position_ids.to(self.device)
        attn_metadata = model_input.attn_metadata
        attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
            self.device)
        attn_metadata.block_tables = attn_metadata.block_tables.to(
            self.device)
        attn_metadata.context_lens = attn_metadata.context_lens.to(
            self.device)
        t = model_input.t.to(self.device)
        p = model_input.p.to(self.device)
        input_lens = model_input.input_lens.to(self.device)

        last_step = num_steps - 1
        for step in range(num_steps):
            prev_slot_mapping = attn_metadata.slot_mapping
            output_token_ids = self.model(token_ids,
                                          position_ids,
                                          attn_metadata,
                                          input_lens,
                                          t,
                                          p,
                                          model_input.num_samples,
                                          kv_caches,
                                          is_prompt=False)
            self.cached_step_outputs.append(output_token_ids)
            if step >= last_step:
                continue

            # Prepare the inputs for the next step.
            token_ids = output_token_ids.unsqueeze(dim=1).int()
            position_ids = position_ids + 1
            attn_metadata.context_lens = attn_metadata.context_lens + 1

            block_number = attn_metadata.block_tables.gather(
                1,
                position_ids.long() // self.block_size)
            block_offset = position_ids % self.block_size

            # Rows that were padding stay padding in the new slot mapping.
            is_padding = prev_slot_mapping == _PAD_SLOT_ID
            next_slot_mapping = (block_number * self.block_size +
                                 block_offset).long()
            attn_metadata.slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                                     next_slot_mapping)

        if callback is not None:
            callback()

        if num_steps > 1:
            # Outputs stay cached until the last-step call drains them.
            return []
        # Retrieve the outputs to CPU.
        decoded = self.cached_step_outputs.pop(0).cpu().tolist()
        return [_make_decode_output(decoded, model_input.seq_groups)]
675
676


youkaichao's avatar
youkaichao committed
677
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
    """Bundles a model's forward pass and token sampling into one callable
    compiled with the "openxla" backend.

    Once three compiled codes have been recorded (profiling, prompt, decode)
    and the custom dispatcher is in use, calls are dispatched directly to
    the matching compiled code instead of going through `torch.compile`.
    """

    def __init__(self, model: nn.Module):
        """Store `model` and compile `self.forward` as a static full graph."""
        self.model = model
        compiled_callable = torch.compile(self.forward,
                                          backend="openxla",
                                          fullgraph=True,
                                          dynamic=False)
        super().__init__(compiled_callable)

    def __call__(self, *args, is_prompt: bool, **kwargs):
        """Run the compiled forward pass.

        Args:
            *args: Positional arguments passed through to `forward`.
            is_prompt: Selects the prompt (True) or decode (False) compiled
                code. It is only used for dispatch and is not passed to
                `forward`.
            **kwargs: Keyword arguments passed through to `forward`.

        Returns:
            Whatever `forward` returns (the sampled token IDs).
        """
        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
            # not fully compiled yet, or not using the custom dispatcher,
            # let PyTorch handle it
            return self.compiled_callable(*args, **kwargs)
        # the 3 compiled codes are:
        # 0: for profiling
        # 1: for prompt
        # 2: for decode
        # dispatch to the compiled code directly, skip PyTorch
        if is_prompt:
            with self.dispatch_to_code(1):
                return self.forward(*args, **kwargs)
        else:
            with self.dispatch_to_code(2):
                return self.forward(*args, **kwargs)

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata. Its `slot_mapping`
                is overwritten in place with the per-head flattened mapping
                whenever the KV cache is non-empty.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size]. A value of 0
                selects greedy (argmax) decoding for that row.
            p: The top-p probability of shape [batch_size]. Only used when
                `_ENABLE_TOP_P` is set.
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. The slot-mapping rewrite is
                skipped when the first cache tensor is empty
                (`numel() == 0`), which is the memory-profiling case at
                initialization.

        Returns:
            The next token IDs, of shape [batch_size] when `num_samples` is 1
            and [batch_size, num_samples] otherwise.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from: the last real (unpadded)
        # token of every row in the flattened [batch_size * seq_len] layout.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            # Offset of each head's region in the flattened cache.
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)

        # `t` is [batch_size] while both candidates are
        # [batch_size, num_samples]. Add a trailing axis so the greedy/random
        # choice is made per row instead of being broadcast against the
        # sample axis (which is wrong or fails when num_samples > 1 and
        # batch_size > 1).
        next_token_ids = torch.where(
            t.unsqueeze(dim=1) != 0, sampled_token_ids, argmax_token_ids)
        if num_samples == 1:
            next_token_ids = next_token_ids.squeeze(dim=-1)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
    # multiple of 16. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
806
807
808
809
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
810
811
812
813
814
815
816
817
818
819
820
821
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842


def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Package flat sampled token IDs into a `SamplerOutput`.

    `next_token_ids` is consumed in order: the k-th sequence ID encountered
    while walking `seq_groups` group by group receives `next_token_ids[k]`.
    Every token is reported with a logprob of 0.0, and each group is wrapped
    in a `CompletionSequenceGroupOutput` whose second argument is None.
    """
    # A single Logprob instance is shared by all emitted tokens.
    placeholder_logprob = Logprob(0.0)

    def _to_seq_output(seq_id, token_id):
        return SequenceOutput(seq_id, token_id,
                              {token_id: placeholder_logprob})

    group_outputs = []
    cursor = 0  # Index of the next unconsumed entry in next_token_ids.
    for seq_ids in seq_groups:
        samples = [
            _to_seq_output(seq_id, next_token_ids[cursor + offset])
            for offset, seq_id in enumerate(seq_ids)
        ]
        cursor += len(seq_ids)
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)