import math
import itertools

import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser

tracer = trace.get_tracer(__name__)

BLOCK_SIZE = 16
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None


class CacheManager:
    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.block_size = BLOCK_SIZE

        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size
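        # x packs the head dimension into 16-byte chunks (x * element_size == 16 bytes),
        # matching the key-cache layout expected by the paged attention kernels.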

        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
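        # Row i of `slots` holds the global slot ids covered by block i, so indexing it
        # with a sequence's allocated block ids yields that sequence's flat slot ids.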
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(self, batch: "FlashCausalLMBatch"):
        # Get free blocks indices by finding values in mask that are not set to 0
        free_block_indices = self.free_block_mask.nonzero()
        assert (
            len(free_block_indices) >= batch.blocks
        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"

        # Slice by the number of required blocks
        block_indices = free_block_indices[: batch.blocks]
        block_indices = block_indices.flatten()

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(batch), batch.max_blocks), dtype=torch.int32
        )

        # Allocate paged attention blocks
        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
            # Get allocated blocks for this sequence
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            # Get slots for the allocated blocks
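            # The last block may only be partially used, hence the [:needed_slots] trim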
            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]

            slots.append(allocated_slots)
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        batch.needed_blocks_slots = None
        batch.block_tables = block_tables
        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
        batch.slots = torch.concat(slots).to(batch.input_ids.device)

        # Allocate the required number of blocks by setting the mask to 0
        self.free_block_mask[block_indices] = 0

    def free(self, block_indices: Optional[List[int]]):
        if block_indices is not None and block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    # request id -> idx in list mapping
    requests_idx_mapping: Dict[int, int]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor

    # tensor of length b holding starting offset of each sequence, only used in prefill
    start_seq_prefill: Optional[torch.Tensor]
    # tensor of length b holding ending offset of each sequence, only used in prefill
    end_seq_prefill: Optional[torch.Tensor]

    # Paged Attention values

    # Set when creating the batch
    # CPU tensor of length b indicating the start of each sequence in slots
    start_slots: torch.Tensor
    # tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
    slot_indices: torch.Tensor
    # List of tuple of ints representing the number of blocks and slots needed by each sequence
    needed_blocks_slots: Optional[List[Tuple[int, int]]]

    # Set in prefill by the CacheManager
    # list of length b of list of length s_i // block_size
    block_tables: Optional[List[List[int]]]
    # tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
    block_tables_tensor: Optional[torch.Tensor]
    # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
    slots: Optional[torch.Tensor]

    max_seqlen: int

    # Prefill metadata tensors to efficiently compute logprobs
    prefill_head_indices: Optional[torch.Tensor]
    prefill_next_token_indices: Optional[torch.Tensor]
    prefill_cu_outlens: Optional[List[int]]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: torch.Tensor

    # Lengths of all generations present in the batch
    input_lengths: List[int]
    input_lengths_tensor: torch.Tensor
    prefix_offsets: List[Optional[int]]
    read_offsets: List[Optional[int]]

    # Generation helpers
    next_token_chooser: HeterogeneousNextTokenChooser
    stopping_criterias: List[StoppingCriteria]

    # Number of blocks in this batch
    blocks: int
    # Maximum number of blocks
    max_blocks: int

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
            request_ids=[r.id for r in self.requests],
            size=len(self),
            max_tokens=self.blocks * BLOCK_SIZE,
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        start_seq_prefill = []
        end_seq_prefill = []
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            start_seq_prefill.append(cumulative_length)
            end_seq_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1
            needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
            blocks += needed_blocks
            needed_blocks_slots.append((needed_blocks, total_tokens))
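            # Record where this request's slots start in the flat slots tensor, then give
            # each of its prefill tokens its own slot index.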
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

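            # Prefill logprob bookkeeping: if the request asked for prefill logprobs, every
            # one of its prefill positions is kept in the LM head output; otherwise only the
            # last position is needed to sample the first new token.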
            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]

        start_seq_prefill = torch.tensor(
            start_seq_prefill, device=device, dtype=torch.int32
        )
        end_seq_prefill = torch.tensor(
            end_seq_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

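        # prefill_head_indices selects which prefill positions are fed to the LM head:
        # all of them when every request wants prefill logprobs, only each sequence's last
        # position when none does, and a per-request gather otherwise.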
        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = end_seq_prefill - 1
        elif no_prefill_logprobs:
            prefill_head_indices = end_seq_prefill - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            start_seq_prefill=start_seq_prefill,
            end_seq_prefill=end_seq_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @tracer.start_as_current_span("filter")
    def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        # We assume that if len(requests) == len(self) then the requests are the same
        if len(request_ids) == len(self):
            return self

        device = self.input_ids.device

        # New values after filtering
        requests_idx_mapping = {}

        # Used to index into tensors
        indices = []

        # slots to keep after filtering
        slot_filtering_indices = torch.zeros(
            self.slots.shape[0], dtype=torch.bool, device=device
        )

        # Create on CPU to only move to GPU once instead of at every copy
        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
        max_seqlen = 0

        requests = []
        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        stopping_criterias = []

        blocks = 0
        max_blocks = 0
        # Cumulative length
        cumulative_max_length = 0

        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)
            requests_idx_mapping[request_id] = i

            requests.append(self.requests[idx])

            # Get length
            request_input_length = self.input_lengths[idx]
            max_seqlen = max(max_seqlen, request_input_length)

            all_input_ids.append(self.all_input_ids[idx])

            input_lengths.append(request_input_length)
            prefix_offsets.append(self.prefix_offsets[idx])
            read_offsets.append(self.read_offsets[idx])

            stopping_criteria = self.stopping_criterias[idx]
            stopping_criterias.append(stopping_criteria)

            remaining_tokens = (
                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
            )

            request_block_table = self.block_tables[idx]
            blocks += len(request_block_table)
            block_tables.append(request_block_table)
            start_slots.append(cumulative_max_length)

            # Copy to tensor (CPU)
            slot_indices[i] = cumulative_max_length + request_input_length - 1

            # Set slice
            slot_filtering_indices[
                self.start_slots[idx] : self.start_slots[idx]
                + request_input_length
                + remaining_tokens
                - 1
            ] = True

            cumulative_max_length += request_input_length + remaining_tokens - 1

            max_blocks = max(max_blocks, len(request_block_table))

        global CACHE_MANAGER
        block_indices_to_free = []
        # Iterate on all requests
        for i, r in enumerate(self.requests):
            # Filter requests that are not part of the new batch
            if r.id not in requests_idx_mapping.keys():
                block_indices_to_free.extend(self.block_tables[i])
        # Free blocks
        CACHE_MANAGER.free(block_indices_to_free)
        # Needed to avoid dropping blocks when the batches will go out of scope
        self.block_tables = None

        # Index into tensors
        input_ids = self.input_ids[indices]
        position_ids = self.position_ids[indices]
        all_input_ids_tensor = self.all_input_ids_tensor[indices]
        block_tables_tensor = self.block_tables_tensor[indices]
        input_lengths_tensor = self.input_lengths_tensor[indices]
        slots = self.slots[slot_filtering_indices]
        next_token_chooser = self.next_token_chooser.filter(indices)

        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Move to GPU now that we have the whole tensor
        slot_indices = slot_indices.to(device)

        return FlashCausalLMBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            start_seq_prefill=None,
            end_seq_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        requests_idx_mapping = {}

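        # First pass over the batches to size the concatenated tensors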
        blocks = 0
        total_batch_size = 0
        total_slots = 0
        max_blocks = 0
        max_length = 0
        max_seqlen = 0
        for b in batches:
            total_batch_size += len(b)
            total_slots += len(b.slots)
            blocks += b.blocks
            max_blocks = max(max_blocks, b.max_blocks)
            max_seqlen = max(max_seqlen, b.max_seqlen)
            max_length = max(
                max_length,
                max(
                    input_length
                    + stopping_criteria.max_new_tokens
                    - stopping_criteria.current_tokens
                    for input_length, stopping_criteria in zip(
                        b.input_lengths, b.stopping_criterias
                    )
                ),
            )

        input_ids = batches[0].input_ids.new_empty(total_batch_size)
        position_ids = batches[0].position_ids.new_empty(total_batch_size)
        slots = batches[0].slots.new_empty(total_slots)
        slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
        input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
            total_batch_size
        )
        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
            (total_batch_size, max_blocks)
        )
        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
            (total_batch_size, max_length)
        )

        start_slots = []
        block_tables = []
        all_input_ids = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []

        next_token_chooser_parameters = []
        stopping_criterias = []

        # Cumulative length
        cumulative_batch_size = 0
        cumulative_slots = 0

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)

            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                # We need to offset the mapping for each batch by the cumulative batch size
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size

            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            slots_start_index = cumulative_slots
            slots_end_index = cumulative_slots + len(batch.slots)

            # Copy tensors (GPU)
            input_ids[start_index:end_index] = batch.input_ids
            position_ids[start_index:end_index] = batch.position_ids
            slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
            input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
            slots[slots_start_index:slots_end_index] = batch.slots

            all_input_ids_tensor[
                start_index:end_index, : batch.all_input_ids_tensor.shape[1]
            ] = batch.all_input_ids_tensor[:, :max_length]

            block_tables_tensor[
                start_index:end_index, : batch.block_tables_tensor.shape[1]
            ] = batch.block_tables_tensor[:, :max_blocks]

            start_slots.append(batch.start_slots + cumulative_slots)

            block_tables.extend(batch.block_tables)
            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            prefix_offsets.extend(batch.prefix_offsets)
            read_offsets.extend(batch.read_offsets)

            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
            stopping_criterias.extend(batch.stopping_criterias)

            # Update
            cumulative_batch_size += len(batch)
            cumulative_slots += len(batch.slots)

        start_slots = torch.concat(start_slots)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters,
            dtype=batches[0].next_token_chooser.dtype,
            device=batches[0].next_token_chooser.device,
        )

        # Needed to avoid dropping blocks when the batches will go out of scope
        for b in batches:
            b.block_tables = None

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            start_seq_prefill=None,
            end_seq_prefill=None,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=None,
            block_tables=block_tables,
            block_tables_tensor=block_tables_tensor,
            slots=slots,
            max_seqlen=max_seqlen,
            prefill_head_indices=None,
            prefill_next_token_indices=None,
            prefill_cu_outlens=None,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            blocks=blocks,
            max_blocks=max_blocks,
        )

    def __del__(self):
        if self.block_tables is not None and self.block_tables:
            global CACHE_MANAGER
            # Free blocks
            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))

    def __len__(self):
        return len(self.requests)


class FlashCausalLM(Model):
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        num_layers: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        rank: int = 0,
        world_size: int = 1,
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size

        super(FlashCausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
        global CACHE_MANAGER
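        # Build the global paged KV cache once and run a prefill pass so that an
        # out-of-memory error surfaces here rather than while serving traffic.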

        torch.cuda.empty_cache()
        try:
            CACHE_MANAGER = CacheManager(
                # Adds some wiggle room
                math.ceil(max_total_tokens / BLOCK_SIZE) + 10,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
                self.dtype,
                self.device,
            )
            _, batch = self.generate_token(batch)
        except Exception as e:
            logger.exception(
                f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                f"prefill tokens. "
                f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
            )
            raise e
        del batch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        start_seq_prefill: Optional[torch.Tensor],
        end_seq_prefill: Optional[torch.Tensor],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        global CACHE_MANAGER

        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            start_seq_prefill=start_seq_prefill,
            end_seq_prefill=end_seq_prefill,
            kv_cache=CACHE_MANAGER.kv_cache,
            block_tables=block_tables,
            slots=slots,
            input_lengths=input_lengths,
            max_s=max_s,
            lm_head_indices=lm_head_indices,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        prefill = batch.start_seq_prefill is not None
        prefill_logprobs = batch.prefill_next_token_indices is not None

        if batch.needed_blocks_slots:
            # Allocate blocks to this batch
            CACHE_MANAGER.allocate(batch)

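        # During prefill the batch holds every prompt token flattened together; during
        # decode it holds exactly one input token per sequence, so the same forward
        # call covers both cases.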
        out = self.forward(
            batch.input_ids,
            batch.position_ids,
            batch.start_seq_prefill,
            batch.end_seq_prefill,
            batch.block_tables_tensor,
            batch.slots[batch.slot_indices],
            batch.input_lengths_tensor,
            batch.max_seqlen,
            batch.prefill_head_indices,
        )

        if prefill:
            next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
        else:
            next_token_logits = out

        next_input_ids, next_token_logprobs = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
        )

        if prefill:
            if len(batch) > 1 and prefill_logprobs:
                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
                # When batch == 1, we will just use the batch.input_ids values directly
                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

            next_position_ids = batch.position_ids.new_empty(len(batch))
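            # After prefill, keep only each sequence's last slot index: from now on the
            # batch advances one token (and one slot) per sequence per step.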
            batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
            # We do not need start_seq_prefill and end_seq_prefill anymore
            batch.start_seq_prefill = None
            batch.end_seq_prefill = None
        else:
            prefill_logprobs = None
            next_position_ids = batch.position_ids

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a GPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        for i, (
            input_length,
            all_input_ids,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if prefill:
                # Indexing metadata
                out_start_index = batch.prefill_cu_outlens[i]
                out_end_index = batch.prefill_cu_outlens[i + 1]
                out_length = out_end_index - out_start_index

                # Initialize position_ids
                # In decode, we do not need this as we can just increment position ids
                next_position_ids[i] = batch.position_ids[end_index - 1]

                # Used to gather prefill logprobs
                # Copy batch.input_ids to prefill_token_indices
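                # The logprob of prompt token t is read from the logits at position t - 1,
                # hence the one-token shift in the copy below.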
                if prefill_logprobs:
                    if len(batch) > 1:
                        prefill_tokens_indices[
                            out_start_index : out_end_index - 1
                        ] = batch.input_ids[start_index + 1 : start_index + out_length]
                    else:
                        # Set prefill_tokens_indices to the correct slice
                        prefill_tokens_indices = batch.input_ids[
                            start_index + 1 : start_index + out_length
                        ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]

            cumulative_length += input_length

        # Set values in batch
        batch.input_ids = next_input_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1

        if prefill and prefill_logprobs:
            # Get prefill logprobs
            prefill_logprobs_tensor = torch.log_softmax(out, -1)
            prefill_logprobs = torch.gather(
                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)
            )
            # GPU <-> CPU sync
            prefill_logprobs = prefill_logprobs.view(-1).tolist()

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            next_token_ids,
            next_token_logprobs,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            next_token_id,
            next_token_logprob,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
                all_input_ids,
                prefix_offset,
                read_offset,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
            if i % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text = self.decode(
                        all_input_ids[-stopping_criteria.current_tokens :]
                    )
                    generated_text = GeneratedText(
                        output_text,
                        stopping_criteria.current_tokens,
                        reason,
                        seed if do_sample else None,
                    )
                else:
                    generated_text = None

                # Prefill
                if prefill and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]

                    # Remove generated token to only have prefill and add nan for first prompt token
                    request_prefill_logprobs = [float("nan")] + prefill_logprobs[
                        out_start_index : out_end_index - 1
                    ]
                    prefill_token_ids = all_input_ids[:-1]
                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )
                    prefill_tokens = PrefillTokens(
                        prefill_token_ids, request_prefill_logprobs, prefill_texts
                    )
                else:
                    prefill_tokens = None

                generation = Generation(
                    request.id,
                    prefill_tokens,
                    next_token_id,
                    next_token_logprob,
                    next_token_text,
                    next_token_id in self.all_special_ids,
                    generated_text,
                )

                generations.append(generation)

            # Update values
            batch.input_lengths[i] = input_length + 1
            batch.prefix_offsets[i] = prefix_offset
            batch.read_offsets[i] = read_offset
            batch.all_input_ids[i] = all_input_ids

        if stopped:
            del batch
            # No need to return a batch if we know that all requests stopped
            return generations, None

        batch.prefill_cu_outlens = None
        batch.prefill_head_indices = None
        batch.prefill_next_token_indices = None
        batch.max_seqlen = batch.max_seqlen + 1

        return generations, batch